1use std::collections::HashSet;
2use std::path::Path;
3
4use regex::Regex;
5use zer_core::{error::ZerError, schema::FieldKind};
6
/// Built-in name-heuristics TOML, embedded at compile time from `../heuristics_name.toml`.
const DEFAULT_NAME_HEURISTICS: &str = include_str!("../heuristics_name.toml");
/// Built-in value-pattern TOML, embedded at compile time from `../heuristics_values.toml`.
const DEFAULT_VALUE_PATTERNS: &str = include_str!("../heuristics_values.toml");
9
/// One column-name rule: the [`FieldKind`] to report plus the patterns that trigger it.
///
/// Every pattern list defaults to empty when it is omitted from the serialized form.
/// `NameHeuristics::infer_kind` treats the rule as matching when at least one pattern
/// from any of the four lists matches the column name.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct NameRule {
    /// Kind returned when this rule matches.
    pub kind: FieldKind,
    /// Substrings: the rule matches if the column name contains any of them.
    #[serde(default)]
    pub contains: Vec<String>,
    /// Whole names: the rule matches if the column name equals any of them.
    #[serde(default)]
    pub exact: Vec<String>,
    /// Prefixes: the rule matches if the column name starts with any of them.
    #[serde(default)]
    pub starts_with: Vec<String>,
    /// Suffixes: the rule matches if the column name ends with any of them.
    #[serde(default)]
    pub ends_with: Vec<String>,
}
25
/// Ordered collection of [`NameRule`]s used to guess a [`FieldKind`] from a column name.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct NameHeuristics {
    /// Rules in priority order: `infer_kind` returns the kind of the first rule that matches.
    pub rules: Vec<NameRule>,
}
31
32impl NameHeuristics {
33 pub fn from_toml_str(s: &str) -> Result<Self, ZerError> {
35 toml::from_str(s).map_err(|e| ZerError::Config(e.to_string()))
36 }
37
38 pub fn from_file(path: &Path) -> Result<Self, ZerError> {
40 let content = std::fs::read_to_string(path)?;
41 Self::from_toml_str(&content)
42 }
43
44 pub fn load_default() -> Self {
49 if let Ok(path) = std::env::var("ZER_NAME_HEURISTICS") {
50 match Self::from_file(Path::new(&path)) {
51 Ok(h) => return h,
52 Err(e) => tracing::warn!(
53 "ZER_NAME_HEURISTICS={path:?}: failed to load ({e}), using embedded default"
54 ),
55 }
56 }
57 Self::from_toml_str(DEFAULT_NAME_HEURISTICS)
58 .expect("embedded heuristics_name.toml is always valid")
59 }
60
61 pub fn infer_kind(&self, name: &str) -> Option<FieldKind> {
64 let n = name.to_ascii_lowercase();
65 for rule in &self.rules {
66 if rule.exact.iter().any(|p| n == p.as_str())
67 || rule.contains.iter().any(|p| n.contains(p.as_str()))
68 || rule.starts_with.iter().any(|p| n.starts_with(p.as_str()))
69 || rule.ends_with.iter().any(|p| n.ends_with(p.as_str()))
70 {
71 return Some(rule.kind);
72 }
73 }
74 None
75 }
76}
77
/// Deserialized form of one value pattern, before its regex is compiled.
#[derive(Debug, serde::Deserialize)]
struct RawValuePattern {
    /// Kind reported when the pattern matches.
    kind: FieldKind,
    /// Regex source text; `ValuePatterns::from_raw` turns an empty string into "no regex".
    regex: String,
    /// Minimum fraction of samples that must match the regex; `0.0` when omitted.
    #[serde(default)]
    threshold: f32,
    /// Optional lower bound on the unique-value rate of the samples.
    unique_rate_min: Option<f32>,
    /// Optional upper bound on the unique-value rate of the samples.
    unique_rate_max: Option<f32>,
    /// Optional lower bound on the average sample length.
    avg_len_min: Option<f32>,
    /// Optional upper bound on the average sample length.
    avg_len_max: Option<f32>,
}
91
/// Deserialized `fallback` section of the value-pattern configuration.
#[derive(Debug, serde::Deserialize)]
struct RawFallback {
    /// Kind that `ValuePatterns::from_raw` stores as `fallback_kind`.
    default_kind: FieldKind,
}
96
/// Deserialized value-pattern configuration: the pattern list plus the fallback section.
#[derive(Debug, serde::Deserialize)]
struct RawValuePatterns {
    /// Patterns in the order they appear in the configuration.
    patterns: Vec<RawValuePattern>,
    /// Required fallback section.
    fallback: RawFallback,
}
102
/// A value pattern whose regex has been compiled and is ready for matching.
#[derive(Debug)]
pub struct CompiledValuePattern {
    /// Kind reported when the pattern matches.
    pub kind: FieldKind,
    /// Compiled regex; `None` makes `ValuePatterns::infer_kind` treat every sample as matching.
    pub regex: Option<Regex>,
    /// Minimum fraction of samples that must match `regex`.
    pub threshold: f32,
    /// Optional lower bound on the unique-value rate of the samples.
    pub unique_rate_min: Option<f32>,
    /// Optional upper bound on the unique-value rate of the samples.
    pub unique_rate_max: Option<f32>,
    /// Optional lower bound on the average sample length.
    pub avg_len_min: Option<f32>,
    /// Optional upper bound on the average sample length.
    pub avg_len_max: Option<f32>,
}
115
/// Compiled value patterns plus the kind to report when none of them matches.
#[derive(Debug)]
pub struct ValuePatterns {
    /// Patterns in priority order: `infer_kind` returns the kind of the first one that matches.
    pub patterns: Vec<CompiledValuePattern>,
    /// Kind returned by `infer_kind` for empty input or when no pattern matches.
    pub fallback_kind: FieldKind,
}
122
123impl ValuePatterns {
124 fn from_raw(raw: RawValuePatterns) -> Result<Self, ZerError> {
125 let mut patterns = Vec::with_capacity(raw.patterns.len());
126 for p in raw.patterns {
127 let regex = if p.regex.is_empty() {
128 None
129 } else {
130 Some(Regex::new(&p.regex).map_err(|e| {
131 ZerError::Config(format!("invalid regex {:?}: {e}", p.regex))
132 })?)
133 };
134 patterns.push(CompiledValuePattern {
135 kind: p.kind,
136 regex,
137 threshold: p.threshold,
138 unique_rate_min: p.unique_rate_min,
139 unique_rate_max: p.unique_rate_max,
140 avg_len_min: p.avg_len_min,
141 avg_len_max: p.avg_len_max,
142 });
143 }
144 Ok(Self { patterns, fallback_kind: raw.fallback.default_kind })
145 }
146
147 pub fn from_toml_str(s: &str) -> Result<Self, ZerError> {
149 let raw: RawValuePatterns =
150 toml::from_str(s).map_err(|e| ZerError::Config(e.to_string()))?;
151 Self::from_raw(raw)
152 }
153
154 pub fn from_file(path: &Path) -> Result<Self, ZerError> {
156 let content = std::fs::read_to_string(path)?;
157 Self::from_toml_str(&content)
158 }
159
160 pub fn load_default() -> Self {
165 if let Ok(path) = std::env::var("ZER_VALUE_PATTERNS") {
166 match Self::from_file(Path::new(&path)) {
167 Ok(p) => return p,
168 Err(e) => tracing::warn!(
169 "ZER_VALUE_PATTERNS={path:?}: failed to load ({e}), using embedded default"
170 ),
171 }
172 }
173 Self::from_toml_str(DEFAULT_VALUE_PATTERNS)
174 .expect("embedded heuristics_values.toml is always valid")
175 }
176
177 pub fn infer_kind(&self, samples: &[&str]) -> FieldKind {
182 if samples.is_empty() {
183 return self.fallback_kind;
184 }
185 let total = samples.len() as f32;
186 let unique_rate = samples.iter().collect::<HashSet<_>>().len() as f32 / total;
187 let avg_len = samples.iter().map(|s| s.len() as f32).sum::<f32>() / total;
188
189 for pat in &self.patterns {
190 let match_frac = match &pat.regex {
191 Some(re) => samples.iter().filter(|s| re.is_match(s)).count() as f32 / total,
192 None => 1.0,
193 };
194 if match_frac >= pat.threshold
195 && pat.unique_rate_min.map_or(true, |min| unique_rate >= min)
196 && pat.unique_rate_max.map_or(true, |max| unique_rate <= max)
197 && pat.avg_len_max.map_or(true, |max| avg_len <= max)
198 && pat.avg_len_min.map_or(true, |min| avg_len >= min)
199 {
200 return pat.kind;
201 }
202 }
203 self.fallback_kind
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the embedded name heuristics directly.
    ///
    /// `NameHeuristics::load_default` honours the `ZER_NAME_HEURISTICS` environment
    /// variable, so going through it would make these tests depend on the environment
    /// of whoever runs them.
    fn embedded_name_heuristics() -> NameHeuristics {
        NameHeuristics::from_toml_str(DEFAULT_NAME_HEURISTICS)
            .expect("embedded heuristics_name.toml is always valid")
    }

    /// Parses the embedded value patterns directly, bypassing the `ZER_VALUE_PATTERNS`
    /// override for the same reason as [`embedded_name_heuristics`].
    fn embedded_value_patterns() -> ValuePatterns {
        ValuePatterns::from_toml_str(DEFAULT_VALUE_PATTERNS)
            .expect("embedded heuristics_values.toml is always valid")
    }

    #[test]
    fn name_heuristics_embedded_default_loads() {
        let h = embedded_name_heuristics();
        assert!(!h.rules.is_empty());
    }

    #[test]
    fn name_heuristics_matches_known_patterns() {
        let h = embedded_name_heuristics();
        assert_eq!(h.infer_kind("first_name"), Some(FieldKind::Name));
        assert_eq!(h.infer_kind("geboortedatum"), Some(FieldKind::Date));
        assert_eq!(h.infer_kind("msisdn"), Some(FieldKind::Phone));
        assert_eq!(h.infer_kind("postcode"), Some(FieldKind::Address));
        assert_eq!(h.infer_kind("bsn"), Some(FieldKind::Id));
    }

    #[test]
    fn name_heuristics_returns_none_for_unknown() {
        let h = embedded_name_heuristics();
        assert_eq!(h.infer_kind("xyzzy_col"), None);
    }

    #[test]
    fn value_patterns_embedded_default_loads() {
        let p = embedded_value_patterns();
        assert!(!p.patterns.is_empty());
    }

    #[test]
    fn value_patterns_date_detection() {
        let p = embedded_value_patterns();
        let samples: Vec<&str> = (0..20).map(|_| "2024-03-15").collect();
        assert_eq!(p.infer_kind(&samples), FieldKind::Date);
    }

    #[test]
    fn value_patterns_fallback_on_empty() {
        let p = embedded_value_patterns();
        assert_eq!(p.infer_kind(&[]), FieldKind::FreeText);
    }

    #[test]
    fn custom_name_heuristics_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("custom_name.toml");
        std::fs::write(
            &path,
            r#"
[[rules]]
kind = "Id"
exact = ["mijnkolom"]
"#,
        )
        .unwrap();

        let h = NameHeuristics::from_file(&path).unwrap();
        assert_eq!(h.infer_kind("mijnkolom"), Some(FieldKind::Id));
        // Column names are matched without regard to ASCII case.
        assert_eq!(h.infer_kind("MijnKolom"), Some(FieldKind::Id));
        assert_eq!(h.infer_kind("other"), None);
    }

    #[test]
    fn custom_value_patterns_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("custom_values.toml");
        std::fs::write(
            &path,
            r#"
[[patterns]]
kind = "Phone"
regex = '^\+31\d{9}$'
threshold = 0.8

[fallback]
default_kind = "FreeText"
"#,
        )
        .unwrap();

        let p = ValuePatterns::from_file(&path).unwrap();
        let samples: Vec<&str> = (0..20).map(|_| "+31612345678").collect();
        assert_eq!(p.infer_kind(&samples), FieldKind::Phone);
    }

    #[test]
    fn invalid_toml_returns_error() {
        let result = NameHeuristics::from_toml_str("this is not toml ][");
        assert!(matches!(result, Err(ZerError::Config(_))));
    }

    #[test]
    fn invalid_regex_returns_error() {
        let result = ValuePatterns::from_toml_str(
            r#"
[[patterns]]
kind = "Date"
regex = '[invalid'
threshold = 0.8

[fallback]
default_kind = "FreeText"
"#,
        );
        assert!(matches!(result, Err(ZerError::Config(_))));
    }
}