Skip to main content

zer_compare/
comparator.rs

1use rayon::prelude::*;
2use zer_core::{
3    comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
4    field_mapping::{FieldMapping, NullPolicy},
5    record::Record,
6    record_pool::RecordPool,
7    schema::{FieldKind, Schema},
8    traits::Comparator,
9};
10
11use crate::{
12    discretize::LevelThresholds,
13    similarity::{default_fns_for, SimilarityFn},
14};
15
16/// Pairwise field comparator that applies similarity functions to produce a field-major `ComparisonBatch`.
17pub struct FieldComparator {
18    field_fns:  Vec<Vec<Box<dyn SimilarityFn>>>,
19    thresholds: Vec<LevelThresholds>,
20}
21
22impl FieldComparator {
23    pub fn from_schema(schema: &Schema) -> Self {
24        let field_fns = schema.fields.iter()
25            .map(|f| default_fns_for(f.kind))
26            .collect();
27        let thresholds = schema.fields.iter()
28            .map(|f| LevelThresholds::for_kind(f.kind))
29            .collect();
30        Self { field_fns, thresholds }
31    }
32
33    /// Build a comparator for cross-schema linkage from an explicit field-mapping list.
34    ///
35    /// Field kinds are inferred from `a_schema` by looking up each `a_field`.
36    /// Fields not found in `a_schema` default to `FieldKind::Categorical`.
37    pub fn from_mapping(mappings: &[FieldMapping], a_schema: &Schema) -> Self {
38        let kind_of = |name: &str| {
39            a_schema.fields.iter()
40                .find(|f| f.name == name)
41                .map(|f| f.kind)
42                .unwrap_or(FieldKind::Categorical)
43        };
44        let (field_fns, thresholds): (Vec<_>, Vec<_>) = mappings.iter()
45            .map(|m| { let k = kind_of(&m.a_field); (default_fns_for(k), LevelThresholds::for_kind(k)) })
46            .unzip();
47        Self { field_fns, thresholds }
48    }
49
50    /// Compare a cross-schema pair using an explicit field-mapping list.
51    ///
52    /// For each mapping, looks up `a_field` in record `a` and `b_field` in
53    /// record `b`.  When a field is missing the `NullPolicy` decides the level:
54    /// `Skip` gives `Null` (EM ignores it), `PenaliseAbsence` gives `None` (hard fail).
55    pub fn compare_pair_mapped(
56        &self,
57        a:        &Record,
58        b:        &Record,
59        mappings: &[FieldMapping],
60    ) -> ComparisonVector {
61        let levels: Vec<ComparisonLevel> = mappings.iter().enumerate()
62            .map(|(i, m)| {
63                let va = a.fields.get(&m.a_field);
64                let vb = b.fields.get(&m.b_field);
65                match (va, vb, &m.null_policy) {
66                    (Some(va), Some(vb), _) => {
67                        let sim = self.field_fns[i].iter()
68                            .map(|f| f.similarity(va, vb))
69                            .fold(0.0_f32, f32::max);
70                        self.thresholds[i].apply(sim)
71                    }
72                    (_, _, NullPolicy::PenaliseAbsence) => ComparisonLevel::None,
73                    (_, _, NullPolicy::Skip)             => ComparisonLevel::Null,
74                }
75            })
76            .collect();
77        ComparisonVector::new(a.id, b.id, levels)
78    }
79
80    /// Batch comparison for cross-schema linkage using explicit field mappings.
81    ///
82    /// Equivalent to calling `compare_pair_mapped` per pair then assembling the
83    /// field-major `ComparisonBatch`.  `n_fields = mappings.len()`.
84    pub fn compare_batch_mapped(
85        &self,
86        records: &[Record],
87        indices: &[(usize, usize)],
88        mappings: &[FieldMapping],
89    ) -> ComparisonBatch {
90        let n_pairs  = indices.len();
91        let n_fields = mappings.len();
92
93        if n_pairs == 0 {
94            return ComparisonBatch::new(0, n_fields, vec![]);
95        }
96
97        let pair_ids_and_levels: Vec<((u64, u64), Vec<u8>)> = indices
98            .par_iter()
99            .map(|&(i, j)| {
100                let ids    = (records[i].id, records[j].id);
101                let cv     = self.compare_pair_mapped(&records[i], &records[j], mappings);
102                let levels = cv.levels.iter().map(|&l| l as u8).collect();
103                (ids, levels)
104            })
105            .collect();
106
107        Self::scatter_to_batch(n_pairs, n_fields, pair_ids_and_levels)
108    }
109
110    fn scatter_to_batch(
111        n_pairs:  usize,
112        n_fields: usize,
113        pair_ids_and_levels: Vec<((u64, u64), Vec<u8>)>,
114    ) -> ComparisonBatch {
115        let pair_ids: Vec<(u64, u64)> =
116            pair_ids_and_levels.iter().map(|(ids, _)| *ids).collect();
117        let mut levels = vec![0u8; n_fields * n_pairs];
118        for f in 0..n_fields {
119            let field_slice = &mut levels[f * n_pairs..(f + 1) * n_pairs];
120            for (p, (_, pair_lvls)) in pair_ids_and_levels.iter().enumerate() {
121                field_slice[p] = pair_lvls[f];
122            }
123        }
124        ComparisonBatch { n_pairs, n_fields, pair_ids, levels }
125    }
126
127    pub fn with_thresholds(mut self, field_idx: usize, thresholds: LevelThresholds) -> Self {
128        self.thresholds[field_idx] = thresholds;
129        self
130    }
131
132    pub fn with_fns(mut self, field_idx: usize, fns: Vec<Box<dyn SimilarityFn>>) -> Self {
133        self.field_fns[field_idx] = fns;
134        self
135    }
136
137    fn compare_pair(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
138        let levels: Vec<ComparisonLevel> = schema.fields.iter().enumerate()
139            .map(|(i, field)| {
140                let va = a.fields.get(&field.name);
141                let vb = b.fields.get(&field.name);
142                match (va, vb) {
143                    (Some(va), Some(vb)) => {
144                        let sim = self.field_fns[i].iter()
145                            .map(|f| f.similarity(va, vb))
146                            .fold(0.0_f32, f32::max);
147                        self.thresholds[i].apply(sim)
148                    }
149                    _ => ComparisonLevel::None,
150                }
151            })
152            .collect();
153        ComparisonVector::new(a.id, b.id, levels)
154    }
155
156    /// Compare field `f` using the zero-alloc `similarity_str` hot path.
157    #[inline]
158    fn compare_pool_field(&self, f: usize, a_str: &str, b_str: &str) -> u8 {
159        if a_str.is_empty() || b_str.is_empty() {
160            return ComparisonLevel::None as u8;
161        }
162        let sim = self.field_fns[f].iter()
163            .map(|fn_| fn_.similarity_str(a_str, b_str))
164            .fold(0.0_f32, f32::max);
165        self.thresholds[f].apply(sim) as u8
166    }
167
168    /// Pool-native batch comparison, the primary hot path.
169    ///
170    /// Reads `RecordPool` columns directly: zero HashMap lookups, no
171    /// `Record::clone()`.  Uses Rayon for parallel per-pair comparison
172    /// into a flat pair-major buffer (zero per-pair heap allocations), then
173    /// transposes to the field-major `ComparisonBatch` layout required by
174    /// all GPU EM kernels (CUDA/Vulkan/AVX2): `levels[f * n_pairs + p]`.
175    pub fn compare_batch_from_pool(
176        &self,
177        pool:    &RecordPool,
178        indices: &[(usize, usize)],
179        schema:  &Schema,
180    ) -> ComparisonBatch {
181        let n_pairs  = indices.len();
182        let n_fields = schema.fields.len();
183
184        if n_pairs == 0 {
185            return ComparisonBatch::new(0, n_fields, vec![]);
186        }
187
188        // Pre-compute pair IDs (cheap, serial).
189        let pair_ids: Vec<(u64, u64)> = indices.iter()
190            .map(|&(i, j)| (pool.ids[i], pool.ids[j]))
191            .collect();
192
193        // Phase 1: parallel pair-major fill.  Each pair owns a contiguous
194        // n_fields-byte slice, no per-pair allocation.
195        // pair_major[p * n_fields + f] = level for pair p, field f.
196        let mut pair_major = vec![0u8; n_pairs * n_fields];
197        pair_major
198            .par_chunks_mut(n_fields)
199            .zip(indices.par_iter())
200            .for_each(|(chunk, &(i, j))| {
201                for f in 0..n_fields {
202                    chunk[f] = self.compare_pool_field(f, pool.get(f, i), pool.get(f, j));
203                }
204            });
205
206        // Phase 2: transpose pair-major → field-major.
207        // Output: levels[f * n_pairs + p]  (required by GPU EM kernels).
208        let mut levels = vec![0u8; n_fields * n_pairs];
209        for (p, chunk) in pair_major.chunks_exact(n_fields).enumerate() {
210            for (f, &lvl) in chunk.iter().enumerate() {
211                levels[f * n_pairs + p] = lvl;
212            }
213        }
214
215        ComparisonBatch { n_pairs, n_fields, pair_ids, levels }
216    }
217}
218
219impl Comparator for FieldComparator {
220    fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
221        self.compare_pair(a, b, schema)
222    }
223
224    fn compare_batch_from_pool(
225        &self,
226        pool:    &RecordPool,
227        indices: &[(usize, usize)],
228        schema:  &Schema,
229    ) -> ComparisonBatch {
230        self.compare_batch_from_pool(pool, indices, schema)
231    }
232}
233
#[cfg(test)]
mod tests {
    use zer_core::{
        comparison::ComparisonLevel,
        record::FieldValue,
        record_pool::RecordPool,
        schema::{FieldKind, SchemaBuilder},
    };

    use super::*;

    /// Four-field person schema shared by every test in this module.
    fn person_schema() -> Schema {
        SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("postcode", FieldKind::Id)
            .build()
            .unwrap()
    }

    /// A record with all four schema fields populated as text values.
    fn make_record(id: u64, voornamen: &str, achternaam: &str, dob: &str, postcode: &str) -> Record {
        let rec = Record::new(id);
        let rec = rec.insert("voornamen", FieldValue::Text(voornamen.into()));
        let rec = rec.insert("achternaam", FieldValue::Text(achternaam.into()));
        let rec = rec.insert("geboortedatum", FieldValue::Text(dob.into()));
        rec.insert("postcode", FieldValue::Text(postcode.into()))
    }

    /// `n` adjacent pairs of identical records (ids `0..2n`) plus the
    /// `(2k, 2k + 1)` index pairs that address them.
    fn identical_pairs(n: usize) -> (Vec<Record>, Vec<(usize, usize)>) {
        let records = (0..2 * n as u64)
            .map(|id| make_record(id, "Jan", "Jansen", "1990-06-15", "1011AB"))
            .collect();
        let indices = (0..n).map(|k| (2 * k, 2 * k + 1)).collect();
        (records, indices)
    }

    #[test]
    fn compare_returns_correct_field_count() {
        let schema = person_schema();
        let left = make_record(1, "Jan", "Jansen", "1990-06-15", "1011AB");
        let right = make_record(2, "Jan", "Jansen", "1990-06-15", "1011AB");
        let vector = FieldComparator::from_schema(&schema).compare(&left, &right, &schema);
        assert_eq!(vector.levels.len(), schema.len());
    }

    #[test]
    fn identical_records_score_exact_on_all_fields() {
        let schema = person_schema();
        let left = make_record(1, "Jan", "Jansen", "1990-06-15", "1011AB");
        let right = make_record(2, "Jan", "Jansen", "1990-06-15", "1011AB");
        let vector = FieldComparator::from_schema(&schema).compare(&left, &right, &schema);
        let all_exact = vector.levels.iter().all(|&lvl| lvl == ComparisonLevel::Exact);
        assert!(all_exact, "identical records should have all Exact levels: {:?}", vector.levels);
    }

    #[test]
    fn completely_different_records_score_none_or_low() {
        let schema = person_schema();
        let left = make_record(1, "Jan", "Jansen", "1990-06-15", "1011AB");
        let right = make_record(2, "Maria", "Bakker", "1955-12-01", "3001XY");
        let vector = FieldComparator::from_schema(&schema).compare(&left, &right, &schema);
        let none_count = vector
            .levels
            .iter()
            .filter(|lvl| **lvl == ComparisonLevel::None)
            .count();
        assert!(
            none_count >= 2,
            "dissimilar records should have several None levels: {:?}",
            vector.levels
        );
    }

    #[test]
    fn missing_field_produces_none() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let full = make_record(1, "Jan", "Jansen", "1990-06-15", "1011AB");
        // Same person, but without a postcode.
        let partial = Record::new(2)
            .insert("voornamen", FieldValue::Text("Jan".into()))
            .insert("achternaam", FieldValue::Text("Jansen".into()))
            .insert("geboortedatum", FieldValue::Text("1990-06-15".into()));
        let vector = comparator.compare(&full, &partial, &schema);
        let postcode_level = vector.levels[3];
        assert_eq!(
            postcode_level,
            ComparisonLevel::None,
            "missing postcode should yield None, got {:?}",
            postcode_level
        );
    }

    #[test]
    fn compare_batch_field_major_layout() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let field_count = schema.len();

        let (records, indices) = identical_pairs(5);
        let pool = RecordPool::from_records(&records, &schema);
        let batch = comparator.compare_batch_from_pool(&pool, &indices, &schema);

        assert_eq!(batch.n_pairs, 5);
        assert_eq!(batch.n_fields, field_count);
        assert_eq!(batch.levels.len(), field_count * 5);

        // Every pair is identical, so every (field, pair) cell must be Exact.
        let cells = (0..field_count).flat_map(|f| (0..5).map(move |p| (f, p)));
        for (f, p) in cells {
            assert_eq!(
                batch.level(f, p),
                ComparisonLevel::Exact,
                "field {f} pair {p} should be Exact"
            );
        }
    }

    #[test]
    fn compare_batch_from_pool_matches_individual_compare() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let (records, indices) = identical_pairs(20);
        let pool = RecordPool::from_records(&records, &schema);

        let batch = comparator.compare_batch_from_pool(&pool, &indices, &schema);
        for (p, &(i, j)) in indices.iter().enumerate() {
            let single = comparator.compare(&records[i], &records[j], &schema);
            single.levels.iter().enumerate().for_each(|(f, &expected)| {
                assert_eq!(
                    batch.level(f, p),
                    expected,
                    "batch and individual disagree at field {f} pair {p}"
                );
            });
        }
    }

    #[test]
    fn empty_batch_is_valid() {
        let schema = person_schema();
        let pool = RecordPool::new(schema.fields.len());
        let batch = FieldComparator::from_schema(&schema).compare_batch_from_pool(&pool, &[], &schema);
        assert_eq!(batch.n_pairs, 0);
        assert!(batch.levels.is_empty());
    }
}