Skip to main content

zer_schema/
infer.rs

1use std::collections::{HashMap, HashSet};
2use std::path::Path;
3
4use zer_core::{
5    error::ZerError,
6    record::{FieldValue, Record},
7    schema::{FieldDef, FieldKind, Schema},
8};
9
10use crate::config::{NameHeuristics, ValuePatterns};
11
12// ── Helpers ───────────────────────────────────────────────────────────────────
13
14fn text_samples<'a>(field_name: &str, records: &'a [Record], n: usize) -> Vec<&'a str> {
15    records
16        .iter()
17        .filter_map(|r| match r.fields.get(field_name) {
18            Some(FieldValue::Text(s)) if !s.is_empty() => Some(s.as_str()),
19            _ => None,
20        })
21        .take(n)
22        .collect()
23}
24
25fn collect_field_names(records: &[Record]) -> Vec<String> {
26    let mut names: HashSet<String> = HashSet::new();
27    for record in records {
28        for name in record.fields.keys() {
29            names.insert(name.clone());
30        }
31    }
32    let mut sorted: Vec<String> = names.into_iter().collect();
33    sorted.sort();
34    sorted
35}
36
37// ── Public API ────────────────────────────────────────────────────────────────
38
/// Automatic schema detector.
///
/// Samples column names and record values to produce a best-effort [`Schema`].
/// The inferred schema should be reviewed before use in production; call
/// `override_field` for any column the heuristics might misclassify.
///
/// For each field, the kind is resolved in this order: an explicit override,
/// then the name heuristics, then the value patterns applied to sampled text
/// values.
///
/// # Example
///
/// ```rust,no_run
/// # use zer_schema::infer::SchemaInferrer;
/// # use zer_core::schema::FieldKind;
/// # let records = vec![];
/// let schema = SchemaInferrer::new()
///     .override_field("internal_code", FieldKind::Id)
///     .override_field("notes",         FieldKind::FreeText)
///     .infer(&records)
///     .unwrap();
/// ```
pub struct SchemaInferrer {
    // Field name -> kind forced by the caller; checked before any heuristic.
    overrides:       HashMap<String, FieldKind>,
    // Rules that classify a field from its name alone.
    name_heuristics: NameHeuristics,
    // Fallback classifier run on sampled text values when the name rules
    // return nothing.
    value_patterns:  ValuePatterns,
}
62
63impl SchemaInferrer {
64    /// Create a new inferrer loading heuristics from the embedded defaults
65    /// (or from `ZER_NAME_HEURISTICS` / `ZER_VALUE_PATTERNS` env vars if set).
66    pub fn new() -> Self {
67        Self {
68            overrides:       HashMap::new(),
69            name_heuristics: NameHeuristics::load_default(),
70            value_patterns:  ValuePatterns::load_default(),
71        }
72    }
73
74    /// Override the name-based heuristics with rules loaded from a TOML file.
75    ///
76    /// Returns `Err` if the file cannot be read or parsed.
77    pub fn with_name_heuristics_file(mut self, path: impl AsRef<Path>) -> Result<Self, ZerError> {
78        self.name_heuristics = NameHeuristics::from_file(path.as_ref())?;
79        Ok(self)
80    }
81
82    /// Override the value-pattern sampling with patterns loaded from a TOML file.
83    ///
84    /// Returns `Err` if the file cannot be read, parsed, or contains an invalid regex.
85    pub fn with_value_patterns_file(mut self, path: impl AsRef<Path>) -> Result<Self, ZerError> {
86        self.value_patterns = ValuePatterns::from_file(path.as_ref())?;
87        Ok(self)
88    }
89
90    /// Force a specific `FieldKind` for one field, bypassing inference.
91    ///
92    /// This always takes precedence over both name-based and value-based
93    /// heuristics.
94    pub fn override_field(mut self, name: impl Into<String>, kind: FieldKind) -> Self {
95        self.overrides.insert(name.into(), kind);
96        self
97    }
98
99    /// Infer a [`Schema`] from a sample of records.
100    ///
101    /// 50–100 non-null values per field is enough for reliable inference.
102    ///
103    /// Returns `Err(ZerError::EmptySchema)` when `records` is empty (no
104    /// field names can be discovered).
105    pub fn infer(&self, records: &[Record]) -> Result<Schema, ZerError> {
106        let field_names = collect_field_names(records);
107        if field_names.is_empty() {
108            return Err(ZerError::EmptySchema);
109        }
110
111        let fields: Vec<FieldDef> = field_names
112            .into_iter()
113            .map(|name| {
114                let kind = self.overrides.get(&name).copied().unwrap_or_else(|| {
115                    self.name_heuristics.infer_kind(&name).unwrap_or_else(|| {
116                        let samples = text_samples(&name, records, 50);
117                        self.value_patterns.infer_kind(&samples)
118                    })
119                });
120                FieldDef { name, kind }
121            })
122            .collect();
123
124        Ok(Schema { fields })
125    }
126}
127
128impl Default for SchemaInferrer {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134// ── Unit tests ────────────────────────────────────────────────────────────────
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a record with the given id whose fields are all `FieldValue::Text`.
    fn text_record(id: u64, fields: &[(&str, &str)]) -> Record {
        let mut r = Record::new(id);
        for (k, v) in fields {
            // `insert` is used builder-style: its result is assigned back to `r`.
            r = r.insert(*k, FieldValue::Text(v.to_string()));
        }
        r
    }

    // Thin helpers so existing test bodies don't need rewiring.

    /// Classify a column name using the default name heuristics only.
    fn infer_name(col: &str) -> Option<FieldKind> {
        NameHeuristics::load_default().infer_kind(col)
    }

    /// Classify a field from its values using the default value patterns,
    /// sampling at most 50 non-empty text values (same limit as `infer`).
    fn infer_values(field: &str, records: &[Record]) -> FieldKind {
        let samples = text_samples(field, records, 50);
        ValuePatterns::load_default().infer_kind(&samples)
    }

    // ── Name-heuristic tests ──────────────────────────────────────────────────

    /// English and Dutch personal-name columns map to `FieldKind::Name`.
    #[test]
    fn infer_common_name_fields() {
        let cases = [
            ("first_name", FieldKind::Name),
            ("last_name", FieldKind::Name),
            ("voornamen", FieldKind::Name),
            ("achternaam", FieldKind::Name),
            ("surname", FieldKind::Name),
        ];
        for (col, expected) in cases {
            assert_eq!(
                infer_name(col),
                Some(expected),
                "'{col}' should infer as {expected:?}"
            );
        }
    }

    /// Date-like column names map to `FieldKind::Date`.
    #[test]
    fn infer_date_fields_by_name() {
        for col in ["dob", "geboortedatum", "birth_date", "created_at"] {
            assert_eq!(
                infer_name(col),
                Some(FieldKind::Date),
                "'{col}' should infer as Date"
            );
        }
    }

    /// Phone-like column names map to `FieldKind::Phone`.
    #[test]
    fn infer_phone_fields_by_name() {
        for col in ["phone", "tel", "mobile", "msisdn"] {
            assert_eq!(
                infer_name(col),
                Some(FieldKind::Phone),
                "'{col}' should infer as Phone"
            );
        }
    }

    /// Dutch address-component column names map to `FieldKind::Address`.
    #[test]
    fn infer_address_fields_by_name() {
        for col in ["straatnaam", "postcode", "woonplaats", "huisnummer"] {
            assert_eq!(
                infer_name(col),
                Some(FieldKind::Address),
                "'{col}' should infer as Address"
            );
        }
    }

    /// Identifier-like column names map to `FieldKind::Id`.
    #[test]
    fn infer_id_fields_by_name() {
        for col in ["bsn", "imsi", "iccid", "document_nummer", "passport_id"] {
            let result = infer_name(col);
            assert_eq!(
                result,
                Some(FieldKind::Id),
                "'{col}' should infer as Id, got {result:?}"
            );
        }
    }

    // ── Value-pattern tests ───────────────────────────────────────────────────

    /// Twenty identical ISO-8601 dates are classified as `Date` by value alone.
    #[test]
    fn infer_date_from_iso_values() {
        let records: Vec<Record> = (0..20)
            .map(|i| text_record(i, &[("col_1", "2024-03-15")]))
            .collect();
        assert_eq!(infer_values("col_1", &records), FieldKind::Date);
    }

    /// The decimal strings "0".."19" are classified as `Numeric`.
    #[test]
    fn infer_numeric_from_number_values() {
        let records: Vec<Record> = (0..20)
            .map(|i| text_record(i, &[("col_1", &i.to_string())]))
            .collect();
        assert_eq!(infer_values("col_1", &records), FieldKind::Numeric);
    }

    /// Ten values drawn from only two distinct strings are `Categorical`.
    #[test]
    fn infer_categorical_from_low_cardinality_values() {
        let values = ["M", "V", "M", "V", "M", "V", "M", "V", "M", "V"];
        let records: Vec<Record> = values
            .iter()
            .enumerate()
            .map(|(i, v)| text_record(i as u64, &[("geslacht", v)]))
            .collect();
        assert_eq!(infer_values("geslacht", &records), FieldKind::Categorical);
    }

    /// A record without the field produces no samples; the default patterns
    /// are expected to answer `FreeText` for an empty sample set.
    #[test]
    fn infer_falls_back_to_freetext_for_empty_field() {
        let records = vec![Record::new(1)];
        assert_eq!(infer_values("col_1", &records), FieldKind::FreeText);
    }

    // ── Override tests ────────────────────────────────────────────────────────

    /// "dob" is a `Date` by name (see `infer_date_fields_by_name`), yet the
    /// explicit override must decide the kind.
    #[test]
    fn override_takes_precedence_over_name_heuristic() {
        let records = vec![text_record(1, &[("dob", "1990-01-01")])];
        let schema = SchemaInferrer::new()
            .override_field("dob", FieldKind::Id)
            .infer(&records)
            .unwrap();

        let dob = schema.fields.iter().find(|f| f.name == "dob").unwrap();
        assert_eq!(dob.kind, FieldKind::Id, "override must win over name heuristic");
    }

    /// The values look like dates, but the override forces `FreeText`.
    #[test]
    fn override_takes_precedence_over_value_pattern() {
        let records: Vec<Record> = (0..20)
            .map(|i| text_record(i, &[("col_x", "2024-01-01")]))
            .collect();
        let schema = SchemaInferrer::new()
            .override_field("col_x", FieldKind::FreeText)
            .infer(&records)
            .unwrap();

        let field = schema.fields.iter().find(|f| f.name == "col_x").unwrap();
        assert_eq!(field.kind, FieldKind::FreeText);
    }

    // ── Custom config file tests ──────────────────────────────────────────────

    /// A name rule loaded from a temporary TOML file is applied by `infer`.
    #[test]
    fn with_name_heuristics_file_overrides_default() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("names.toml");
        std::fs::write(
            &path,
            r#"
[[rules]]
kind  = "Id"
exact = ["custom_col"]
"#,
        )
        .unwrap();

        let records = vec![text_record(1, &[("custom_col", "ABC123")])];
        let schema = SchemaInferrer::new()
            .with_name_heuristics_file(&path)
            .unwrap()
            .infer(&records)
            .unwrap();

        let f = schema.fields.iter().find(|f| f.name == "custom_col").unwrap();
        assert_eq!(f.kind, FieldKind::Id);
    }

    /// A value pattern loaded from a temporary TOML file is applied by `infer`.
    #[test]
    fn with_value_patterns_file_overrides_default() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("values.toml");
        std::fs::write(
            &path,
            r#"
[[patterns]]
kind      = "Phone"
regex     = '^\+31\d{9}$'
threshold = 0.8

[fallback]
default_kind = "FreeText"
"#,
        )
        .unwrap();

        // Every value matches the custom regex, so the 0.8 threshold is met.
        let records: Vec<Record> = (0..20)
            .map(|i| text_record(i, &[("col", "+31612345678")]))
            .collect();
        let schema = SchemaInferrer::new()
            .with_value_patterns_file(&path)
            .unwrap()
            .infer(&records)
            .unwrap();

        let f = schema.fields.iter().find(|f| f.name == "col").unwrap();
        assert_eq!(f.kind, FieldKind::Phone);
    }

    /// Pointing at a path that does not exist yields `Err`, not a panic.
    #[test]
    fn with_name_heuristics_file_missing_returns_error() {
        let result =
            SchemaInferrer::new().with_name_heuristics_file("/nonexistent/path/names.toml");
        assert!(result.is_err());
    }

    // ── Full-inference integration tests ─────────────────────────────────────

    /// End-to-end inference over records with Dutch column names.
    ///
    /// Only the name and date columns are asserted; `postcode` and
    /// `nationaliteit` are part of the fixture but not checked here.
    #[test]
    fn infer_brp_like_records() {
        let records: Vec<Record> = (0..10)
            .map(|i| {
                text_record(
                    i,
                    &[
                        ("voornamen", "Erik"),
                        ("achternaam", "Hendriks"),
                        ("geboortedatum", "1980-06-15"),
                        ("postcode", "1234AB"),
                        ("nationaliteit", "Nederland"),
                    ],
                )
            })
            .collect();

        let schema = SchemaInferrer::new().infer(&records).unwrap();
        let kind_of = |n: &str| schema.fields.iter().find(|f| f.name == n).map(|f| f.kind);

        assert_eq!(kind_of("voornamen"), Some(FieldKind::Name));
        assert_eq!(kind_of("achternaam"), Some(FieldKind::Name));
        assert_eq!(kind_of("geboortedatum"), Some(FieldKind::Date));
    }

    /// An empty slice has no field names, so `infer` reports `EmptySchema`.
    #[test]
    fn infer_empty_records_returns_error() {
        let result = SchemaInferrer::new().infer(&[]);
        assert!(
            matches!(result, Err(ZerError::EmptySchema)),
            "empty record slice must return EmptySchema"
        );
    }

    /// Records exist but none has a field: still `EmptySchema`.
    #[test]
    fn infer_record_with_no_fields_returns_error() {
        let records = vec![Record::new(1), Record::new(2)];
        let result = SchemaInferrer::new().infer(&records);
        assert!(
            matches!(result, Err(ZerError::EmptySchema)),
            "records with no fields must return EmptySchema"
        );
    }

    /// Alternating text and `Null` values must not break inference.
    ///
    /// Only the field count is asserted; the inferred kind is not checked.
    #[test]
    fn infer_handles_null_values_gracefully() {
        let mut records = vec![];
        for i in 0..10u64 {
            let mut r = Record::new(i);
            if i % 2 == 0 {
                r = r.insert("col", FieldValue::Text("2024-01-01".into()));
            } else {
                r = r.insert("col", FieldValue::Null);
            }
            records.push(r);
        }
        let schema = SchemaInferrer::new().infer(&records).unwrap();
        assert_eq!(schema.len(), 1);
    }

    /// Schema fields come out sorted by name regardless of insertion order.
    #[test]
    fn infer_field_names_sorted_deterministically() {
        let records = vec![text_record(1, &[("zzz", "a"), ("aaa", "b"), ("mmm", "c")])];
        let schema = SchemaInferrer::new().infer(&records).unwrap();
        let names: Vec<&str> = schema.fields.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, vec!["aaa", "mmm", "zzz"]);
    }
}