1use rayon::prelude::*;
2use zer_core::{
3 comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
4 field_mapping::{FieldMapping, NullPolicy},
5 record::Record,
6 record_pool::RecordPool,
7 schema::{FieldKind, Schema},
8 traits::Comparator,
9};
10
11use crate::{
12 discretize::LevelThresholds,
13 similarity::{default_fns_for, SimilarityFn},
14};
15
16pub struct FieldComparator {
18 field_fns: Vec<Vec<Box<dyn SimilarityFn>>>,
19 thresholds: Vec<LevelThresholds>,
20}
21
22impl FieldComparator {
23 pub fn from_schema(schema: &Schema) -> Self {
24 let field_fns = schema.fields.iter()
25 .map(|f| default_fns_for(f.kind))
26 .collect();
27 let thresholds = schema.fields.iter()
28 .map(|f| LevelThresholds::for_kind(f.kind))
29 .collect();
30 Self { field_fns, thresholds }
31 }
32
33 pub fn from_mapping(mappings: &[FieldMapping], a_schema: &Schema) -> Self {
38 let kind_of = |name: &str| {
39 a_schema.fields.iter()
40 .find(|f| f.name == name)
41 .map(|f| f.kind)
42 .unwrap_or(FieldKind::Categorical)
43 };
44 let (field_fns, thresholds): (Vec<_>, Vec<_>) = mappings.iter()
45 .map(|m| { let k = kind_of(&m.a_field); (default_fns_for(k), LevelThresholds::for_kind(k)) })
46 .unzip();
47 Self { field_fns, thresholds }
48 }
49
50 pub fn compare_pair_mapped(
56 &self,
57 a: &Record,
58 b: &Record,
59 mappings: &[FieldMapping],
60 ) -> ComparisonVector {
61 let levels: Vec<ComparisonLevel> = mappings.iter().enumerate()
62 .map(|(i, m)| {
63 let va = a.fields.get(&m.a_field);
64 let vb = b.fields.get(&m.b_field);
65 match (va, vb, &m.null_policy) {
66 (Some(va), Some(vb), _) => {
67 let sim = self.field_fns[i].iter()
68 .map(|f| f.similarity(va, vb))
69 .fold(0.0_f32, f32::max);
70 self.thresholds[i].apply(sim)
71 }
72 (_, _, NullPolicy::PenaliseAbsence) => ComparisonLevel::None,
73 (_, _, NullPolicy::Skip) => ComparisonLevel::Null,
74 }
75 })
76 .collect();
77 ComparisonVector::new(a.id, b.id, levels)
78 }
79
80 pub fn compare_batch_mapped(
85 &self,
86 records: &[Record],
87 indices: &[(usize, usize)],
88 mappings: &[FieldMapping],
89 ) -> ComparisonBatch {
90 let n_pairs = indices.len();
91 let n_fields = mappings.len();
92
93 if n_pairs == 0 {
94 return ComparisonBatch::new(0, n_fields, vec![]);
95 }
96
97 let pair_ids_and_levels: Vec<((u64, u64), Vec<u8>)> = indices
98 .par_iter()
99 .map(|&(i, j)| {
100 let ids = (records[i].id, records[j].id);
101 let cv = self.compare_pair_mapped(&records[i], &records[j], mappings);
102 let levels = cv.levels.iter().map(|&l| l as u8).collect();
103 (ids, levels)
104 })
105 .collect();
106
107 Self::scatter_to_batch(n_pairs, n_fields, pair_ids_and_levels)
108 }
109
110 fn scatter_to_batch(
111 n_pairs: usize,
112 n_fields: usize,
113 pair_ids_and_levels: Vec<((u64, u64), Vec<u8>)>,
114 ) -> ComparisonBatch {
115 let pair_ids: Vec<(u64, u64)> =
116 pair_ids_and_levels.iter().map(|(ids, _)| *ids).collect();
117 let mut levels = vec![0u8; n_fields * n_pairs];
118 for f in 0..n_fields {
119 let field_slice = &mut levels[f * n_pairs..(f + 1) * n_pairs];
120 for (p, (_, pair_lvls)) in pair_ids_and_levels.iter().enumerate() {
121 field_slice[p] = pair_lvls[f];
122 }
123 }
124 ComparisonBatch { n_pairs, n_fields, pair_ids, levels }
125 }
126
127 pub fn with_thresholds(mut self, field_idx: usize, thresholds: LevelThresholds) -> Self {
128 self.thresholds[field_idx] = thresholds;
129 self
130 }
131
132 pub fn with_fns(mut self, field_idx: usize, fns: Vec<Box<dyn SimilarityFn>>) -> Self {
133 self.field_fns[field_idx] = fns;
134 self
135 }
136
137 fn compare_pair(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
138 let levels: Vec<ComparisonLevel> = schema.fields.iter().enumerate()
139 .map(|(i, field)| {
140 let va = a.fields.get(&field.name);
141 let vb = b.fields.get(&field.name);
142 match (va, vb) {
143 (Some(va), Some(vb)) => {
144 let sim = self.field_fns[i].iter()
145 .map(|f| f.similarity(va, vb))
146 .fold(0.0_f32, f32::max);
147 self.thresholds[i].apply(sim)
148 }
149 _ => ComparisonLevel::None,
150 }
151 })
152 .collect();
153 ComparisonVector::new(a.id, b.id, levels)
154 }
155
156 #[inline]
158 fn compare_pool_field(&self, f: usize, a_str: &str, b_str: &str) -> u8 {
159 if a_str.is_empty() || b_str.is_empty() {
160 return ComparisonLevel::None as u8;
161 }
162 let sim = self.field_fns[f].iter()
163 .map(|fn_| fn_.similarity_str(a_str, b_str))
164 .fold(0.0_f32, f32::max);
165 self.thresholds[f].apply(sim) as u8
166 }
167
168 pub fn compare_batch_from_pool(
176 &self,
177 pool: &RecordPool,
178 indices: &[(usize, usize)],
179 schema: &Schema,
180 ) -> ComparisonBatch {
181 let n_pairs = indices.len();
182 let n_fields = schema.fields.len();
183
184 if n_pairs == 0 {
185 return ComparisonBatch::new(0, n_fields, vec![]);
186 }
187
188 let pair_ids: Vec<(u64, u64)> = indices.iter()
190 .map(|&(i, j)| (pool.ids[i], pool.ids[j]))
191 .collect();
192
193 let mut pair_major = vec![0u8; n_pairs * n_fields];
197 pair_major
198 .par_chunks_mut(n_fields)
199 .zip(indices.par_iter())
200 .for_each(|(chunk, &(i, j))| {
201 for f in 0..n_fields {
202 chunk[f] = self.compare_pool_field(f, pool.get(f, i), pool.get(f, j));
203 }
204 });
205
206 let mut levels = vec![0u8; n_fields * n_pairs];
209 for (p, chunk) in pair_major.chunks_exact(n_fields).enumerate() {
210 for (f, &lvl) in chunk.iter().enumerate() {
211 levels[f * n_pairs + p] = lvl;
212 }
213 }
214
215 ComparisonBatch { n_pairs, n_fields, pair_ids, levels }
216 }
217}
218
219impl Comparator for FieldComparator {
220 fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
221 self.compare_pair(a, b, schema)
222 }
223
224 fn compare_batch_from_pool(
225 &self,
226 pool: &RecordPool,
227 indices: &[(usize, usize)],
228 schema: &Schema,
229 ) -> ComparisonBatch {
230 self.compare_batch_from_pool(pool, indices, schema)
231 }
232}
233
#[cfg(test)]
mod tests {
    use zer_core::{
        comparison::ComparisonLevel,
        record::FieldValue,
        record_pool::RecordPool,
        schema::{FieldKind, SchemaBuilder},
    };

    use super::*;

    /// Four-field person schema shared by all tests: two names, a date and an id.
    fn person_schema() -> Schema {
        let builder = SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("postcode", FieldKind::Id);
        builder.build().unwrap()
    }

    /// Builds a record that has all four schema fields set as text values.
    fn make_record(id: u64, voornamen: &str, achternaam: &str, dob: &str, postcode: &str) -> Record {
        let text = |s: &str| FieldValue::Text(s.into());
        Record::new(id)
            .insert("voornamen", text(voornamen))
            .insert("achternaam", text(achternaam))
            .insert("geboortedatum", text(dob))
            .insert("postcode", text(postcode))
    }

    /// The reference person used wherever identical records are needed.
    fn jan_jansen(id: u64) -> Record {
        make_record(id, "Jan", "Jansen", "1990-06-15", "1011AB")
    }

    #[test]
    fn compare_returns_correct_field_count() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let (left, right) = (jan_jansen(1), jan_jansen(2));

        let result = comparator.compare(&left, &right, &schema);

        assert_eq!(result.levels.len(), schema.len());
    }

    #[test]
    fn identical_records_score_exact_on_all_fields() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let (left, right) = (jan_jansen(1), jan_jansen(2));

        let result = comparator.compare(&left, &right, &schema);

        let all_exact = result.levels.iter().all(|&l| l == ComparisonLevel::Exact);
        assert!(all_exact,
            "identical records should have all Exact levels: {:?}", result.levels);
    }

    #[test]
    fn completely_different_records_score_none_or_low() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let left = jan_jansen(1);
        let right = make_record(2, "Maria", "Bakker", "1955-12-01", "3001XY");

        let result = comparator.compare(&left, &right, &schema);

        let n_none = result
            .levels
            .iter()
            .filter(|l| **l == ComparisonLevel::None)
            .count();
        assert!(n_none >= 2, "dissimilar records should have several None levels: {:?}", result.levels);
    }

    #[test]
    fn missing_field_produces_none() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let complete = jan_jansen(1);
        // Same person, but without a postcode entry.
        let without_postcode = Record::new(2)
            .insert("voornamen", FieldValue::Text("Jan".into()))
            .insert("achternaam", FieldValue::Text("Jansen".into()))
            .insert("geboortedatum", FieldValue::Text("1990-06-15".into()));

        let result = comparator.compare(&complete, &without_postcode, &schema);

        assert_eq!(result.levels[3], ComparisonLevel::None,
            "missing postcode should yield None, got {:?}", result.levels[3]);
    }

    #[test]
    fn compare_batch_field_major_layout() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let n_fields = schema.len();

        // Ten identical records with ids 0..10, paired as (0,1), (2,3), ...
        let records: Vec<Record> = (0..10u64).map(jan_jansen).collect();
        let pool = RecordPool::from_records(&records, &schema);
        let indices: Vec<(usize, usize)> = (0..5).map(|k| (2 * k, 2 * k + 1)).collect();

        let batch = comparator.compare_batch_from_pool(&pool, &indices, &schema);

        assert_eq!(batch.n_pairs, 5);
        assert_eq!(batch.n_fields, n_fields);
        assert_eq!(batch.levels.len(), n_fields * 5);

        for f in 0..n_fields {
            for p in 0..5 {
                assert_eq!(
                    batch.level(f, p),
                    ComparisonLevel::Exact,
                    "field {f} pair {p} should be Exact"
                );
            }
        }
    }

    #[test]
    fn compare_batch_from_pool_matches_individual_compare() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);

        // Forty identical records with ids 0..40, paired as (0,1), (2,3), ...
        let records: Vec<Record> = (0..40u64).map(jan_jansen).collect();
        let pool = RecordPool::from_records(&records, &schema);
        let indices: Vec<(usize, usize)> = (0..20).map(|k| (2 * k, 2 * k + 1)).collect();

        let batch = comparator.compare_batch_from_pool(&pool, &indices, &schema);

        for (p, &(i, j)) in indices.iter().enumerate() {
            let single = comparator.compare(&records[i], &records[j], &schema);
            for (f, &expected) in single.levels.iter().enumerate() {
                assert_eq!(
                    batch.level(f, p), expected,
                    "batch and individual disagree at field {f} pair {p}"
                );
            }
        }
    }

    #[test]
    fn empty_batch_is_valid() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let empty_pool = RecordPool::new(schema.fields.len());

        let batch = comparator.compare_batch_from_pool(&empty_pool, &[], &schema);

        assert_eq!(batch.n_pairs, 0);
        assert!(batch.levels.is_empty());
    }
}