1use std::collections::HashSet;
2use std::path::Path;
3
4use regex::Regex;
5use zer_core::{error::ZerError, schema::FieldKind};
6
7const DEFAULT_NAME_HEURISTICS: &str = include_str!("../heuristics_name.toml");
8const DEFAULT_VALUE_PATTERNS: &str = include_str!("../heuristics_values.toml");
9
/// A single column-name rule from the `[[rules]]` array of the name-heuristics TOML.
///
/// The rule matches when the ASCII-lowercased column name satisfies at least one
/// pattern in any of the four pattern lists (see `NameHeuristics::infer_kind`).
/// Every pattern list may be omitted in TOML and then defaults to empty.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct NameRule {
    /// Field kind reported when this rule matches.
    pub kind: FieldKind,
    /// Matches when the lowercased name contains any of these substrings.
    #[serde(default)]
    pub contains: Vec<String>,
    /// Matches when the lowercased name equals any of these strings.
    #[serde(default)]
    pub exact: Vec<String>,
    /// Matches when the lowercased name starts with any of these prefixes.
    #[serde(default)]
    pub starts_with: Vec<String>,
    /// Matches when the lowercased name ends with any of these suffixes.
    #[serde(default)]
    pub ends_with: Vec<String>,
}
25
/// Column-name heuristics: an ordered list of [`NameRule`]s, deserialized from a
/// TOML document that holds a `[[rules]]` array of tables.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct NameHeuristics {
    /// Rules in file order; `infer_kind` returns the kind of the first rule that matches.
    pub rules: Vec<NameRule>,
}
31
32impl NameHeuristics {
33 pub fn from_toml_str(s: &str) -> Result<Self, ZerError> {
35 toml::from_str(s).map_err(|e| ZerError::Config(e.to_string()))
36 }
37
38 pub fn from_file(path: &Path) -> Result<Self, ZerError> {
40 let content = std::fs::read_to_string(path)?;
41 Self::from_toml_str(&content)
42 }
43
44 pub fn load_default() -> Self {
49 if let Ok(path) = std::env::var("ZER_NAME_HEURISTICS") {
50 match Self::from_file(Path::new(&path)) {
51 Ok(h) => return h,
52 Err(e) => tracing::warn!(
53 "ZER_NAME_HEURISTICS={path:?}: failed to load ({e}), using embedded default"
54 ),
55 }
56 }
57 Self::from_toml_str(DEFAULT_NAME_HEURISTICS)
58 .expect("embedded heuristics_name.toml is always valid")
59 }
60
61 pub fn infer_kind(&self, name: &str) -> Option<FieldKind> {
64 let n = name.to_ascii_lowercase();
65 for rule in &self.rules {
66 if rule.exact.iter().any(|p| n == p.as_str())
67 || rule.contains.iter().any(|p| n.contains(p.as_str()))
68 || rule.starts_with.iter().any(|p| n.starts_with(p.as_str()))
69 || rule.ends_with.iter().any(|p| n.ends_with(p.as_str()))
70 {
71 return Some(rule.kind);
72 }
73 }
74 None
75 }
76}
77
/// One `[[patterns]]` entry exactly as written in the value-patterns TOML, before
/// its regex is compiled by `ValuePatterns::from_raw`.
#[derive(Debug, serde::Deserialize)]
struct RawValuePattern {
    /// Field kind reported when this pattern matches.
    kind: FieldKind,
    /// Regex applied to each sample. The key is required; an empty string disables
    /// the regex check entirely.
    regex: String,
    /// Minimum fraction of samples (0.0..=1.0) that must match `regex`.
    // NOTE(review): when omitted this defaults to 0.0, so any match fraction passes
    // and a non-empty `regex` has no effect — confirm this is intended for config authors.
    #[serde(default)]
    threshold: f32,
    /// Optional lower bound on (distinct samples / total samples).
    unique_rate_min: Option<f32>,
    /// Optional upper bound on (distinct samples / total samples).
    unique_rate_max: Option<f32>,
    /// Optional lower bound on the mean sample length in bytes.
    avg_len_min: Option<f32>,
    /// Optional upper bound on the mean sample length in bytes.
    avg_len_max: Option<f32>,
}
91
/// The `[fallback]` table of the value-patterns TOML.
#[derive(Debug, serde::Deserialize)]
struct RawFallback {
    /// Kind returned when no pattern matches or when there are no samples.
    default_kind: FieldKind,
}
96
/// The whole value-patterns TOML document: a `[[patterns]]` array plus a required
/// `[fallback]` table.
#[derive(Debug, serde::Deserialize)]
struct RawValuePatterns {
    /// Patterns in file order; the first match wins at inference time.
    patterns: Vec<RawValuePattern>,
    /// Required fallback section.
    fallback: RawFallback,
}
102
/// A value pattern ready for inference: a `RawValuePattern` whose regex has been compiled.
#[derive(Debug)]
pub struct CompiledValuePattern {
    /// Field kind reported when this pattern matches.
    pub kind: FieldKind,
    /// Compiled regex, or `None` when the TOML `regex` was empty. With `None`, every
    /// sample counts as matching.
    pub regex: Option<Regex>,
    /// Minimum fraction of samples that must match `regex`.
    pub threshold: f32,
    /// Optional lower bound on (distinct samples / total samples).
    pub unique_rate_min: Option<f32>,
    /// Optional upper bound on (distinct samples / total samples).
    pub unique_rate_max: Option<f32>,
    /// Optional lower bound on the mean sample length in bytes.
    pub avg_len_min: Option<f32>,
    /// Optional upper bound on the mean sample length in bytes.
    pub avg_len_max: Option<f32>,
}
115
/// Sample-value heuristics: ordered compiled patterns plus the kind to use when none match.
#[derive(Debug)]
pub struct ValuePatterns {
    /// Patterns in file order; `infer_kind` returns the kind of the first that matches.
    pub patterns: Vec<CompiledValuePattern>,
    /// Kind returned for empty input or when no pattern matches.
    pub fallback_kind: FieldKind,
}
122
123impl ValuePatterns {
124 fn from_raw(raw: RawValuePatterns) -> Result<Self, ZerError> {
125 let mut patterns = Vec::with_capacity(raw.patterns.len());
126 for p in raw.patterns {
127 let regex =
128 if p.regex.is_empty() {
129 None
130 } else {
131 Some(Regex::new(&p.regex).map_err(|e| {
132 ZerError::Config(format!("invalid regex {:?}: {e}", p.regex))
133 })?)
134 };
135 patterns.push(CompiledValuePattern {
136 kind: p.kind,
137 regex,
138 threshold: p.threshold,
139 unique_rate_min: p.unique_rate_min,
140 unique_rate_max: p.unique_rate_max,
141 avg_len_min: p.avg_len_min,
142 avg_len_max: p.avg_len_max,
143 });
144 }
145 Ok(Self {
146 patterns,
147 fallback_kind: raw.fallback.default_kind,
148 })
149 }
150
151 pub fn from_toml_str(s: &str) -> Result<Self, ZerError> {
153 let raw: RawValuePatterns =
154 toml::from_str(s).map_err(|e| ZerError::Config(e.to_string()))?;
155 Self::from_raw(raw)
156 }
157
158 pub fn from_file(path: &Path) -> Result<Self, ZerError> {
160 let content = std::fs::read_to_string(path)?;
161 Self::from_toml_str(&content)
162 }
163
164 pub fn load_default() -> Self {
169 if let Ok(path) = std::env::var("ZER_VALUE_PATTERNS") {
170 match Self::from_file(Path::new(&path)) {
171 Ok(p) => return p,
172 Err(e) => tracing::warn!(
173 "ZER_VALUE_PATTERNS={path:?}: failed to load ({e}), using embedded default"
174 ),
175 }
176 }
177 Self::from_toml_str(DEFAULT_VALUE_PATTERNS)
178 .expect("embedded heuristics_values.toml is always valid")
179 }
180
181 pub fn infer_kind(&self, samples: &[&str]) -> FieldKind {
186 if samples.is_empty() {
187 return self.fallback_kind;
188 }
189 let total = samples.len() as f32;
190 let unique_rate = samples.iter().collect::<HashSet<_>>().len() as f32 / total;
191 let avg_len = samples.iter().map(|s| s.len() as f32).sum::<f32>() / total;
192
193 for pat in &self.patterns {
194 let match_frac = match &pat.regex {
195 Some(re) => samples.iter().filter(|s| re.is_match(s)).count() as f32 / total,
196 None => 1.0,
197 };
198 if match_frac >= pat.threshold
199 && pat.unique_rate_min.is_none_or(|min| unique_rate >= min)
200 && pat.unique_rate_max.is_none_or(|max| unique_rate <= max)
201 && pat.avg_len_max.is_none_or(|max| avg_len <= max)
202 && pat.avg_len_min.is_none_or(|min| avg_len >= min)
203 {
204 return pat.kind;
205 }
206 }
207 self.fallback_kind
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to `file_name` inside a fresh temporary directory.
    ///
    /// Returns the directory guard together with the file path. The guard must stay
    /// alive while the path is used, because dropping it deletes the directory.
    fn write_temp(file_name: &str, contents: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join(file_name);
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    #[test]
    fn name_heuristics_embedded_default_loads() {
        assert!(!NameHeuristics::load_default().rules.is_empty());
    }

    #[test]
    fn name_heuristics_matches_known_patterns() {
        let heuristics = NameHeuristics::load_default();
        let expectations = [
            ("first_name", FieldKind::Name),
            ("geboortedatum", FieldKind::Date),
            ("msisdn", FieldKind::Phone),
            ("postcode", FieldKind::Address),
            ("bsn", FieldKind::Id),
        ];
        for (column, expected) in expectations {
            assert_eq!(heuristics.infer_kind(column), Some(expected), "column {column:?}");
        }
    }

    #[test]
    fn name_heuristics_returns_none_for_unknown() {
        assert_eq!(NameHeuristics::load_default().infer_kind("xyzzy_col"), None);
    }

    #[test]
    fn value_patterns_embedded_default_loads() {
        assert!(!ValuePatterns::load_default().patterns.is_empty());
    }

    #[test]
    fn value_patterns_date_detection() {
        let patterns = ValuePatterns::load_default();
        let samples = vec!["2024-03-15"; 20];
        assert_eq!(patterns.infer_kind(&samples), FieldKind::Date);
    }

    #[test]
    fn value_patterns_fallback_on_empty() {
        assert_eq!(ValuePatterns::load_default().infer_kind(&[]), FieldKind::FreeText);
    }

    #[test]
    fn custom_name_heuristics_from_file() {
        let (_dir, path) = write_temp(
            "custom_name.toml",
            r#"
[[rules]]
kind = "Id"
exact = ["mijnkolom"]
"#,
        );

        let heuristics = NameHeuristics::from_file(&path).unwrap();
        assert_eq!(heuristics.infer_kind("mijnkolom"), Some(FieldKind::Id));
        assert_eq!(heuristics.infer_kind("other"), None);
    }

    #[test]
    fn custom_value_patterns_from_file() {
        let (_dir, path) = write_temp(
            "custom_values.toml",
            r#"
[[patterns]]
kind = "Phone"
regex = '^\+31\d{9}$'
threshold = 0.8

[fallback]
default_kind = "FreeText"
"#,
        );

        let patterns = ValuePatterns::from_file(&path).unwrap();
        let samples = vec!["+31612345678"; 20];
        assert_eq!(patterns.infer_kind(&samples), FieldKind::Phone);
    }

    #[test]
    fn invalid_toml_returns_error() {
        assert!(matches!(
            NameHeuristics::from_toml_str("this is not toml ]["),
            Err(ZerError::Config(_))
        ));
    }

    #[test]
    fn invalid_regex_returns_error() {
        let toml_src = r#"
[[patterns]]
kind = "Date"
regex = '[invalid'
threshold = 0.8

[fallback]
default_kind = "FreeText"
"#;
        assert!(matches!(
            ValuePatterns::from_toml_str(toml_src),
            Err(ZerError::Config(_))
        ));
    }
}