Skip to main content

zer_compute/
comparator.rs

//! `DeviceComparator` implements the `Comparator` trait.
//!
//! `compare_batch_from_pool` always routes to the CPU (Rayon parallel) path.
//! String comparison (Jaro-Winkler) is branch-heavy, and on a GPU its cost is
//! dominated by PCIe transfer overhead, so the CPU is faster for all observed
//! batch sizes.
6
7use std::sync::Arc;
8
9use zer_core::{
10    comparison::{ComparisonBatch, ComparisonVector},
11    record::Record,
12    record_pool::RecordPool,
13    schema::Schema,
14    traits::Comparator,
15};
16
17use crate::{
18    backend::{cpu::CpuFallbackComparator, DeviceBackend},
19    error::GpuError,
20};
21
22pub struct DeviceComparator {
23    backend:      Arc<DeviceBackend>,
24    cpu_fallback: CpuFallbackComparator,
25}
26
27impl DeviceComparator {
28    pub fn new(backend: Arc<DeviceBackend>, schema: &Schema) -> Result<Self, GpuError> {
29        let cpu_fallback = CpuFallbackComparator::from_schema(schema);
30        Ok(Self { backend, cpu_fallback })
31    }
32
33    pub fn backend_name(&self) -> &'static str {
34        self.backend.name()
35    }
36}
37
38impl Comparator for DeviceComparator {
39    fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
40        self.cpu_fallback.compare(a, b, schema)
41    }
42
43    fn compare_batch_from_pool(
44        &self,
45        pool:    &RecordPool,
46        indices: &[(usize, usize)],
47        schema:  &Schema,
48    ) -> ComparisonBatch {
49        if indices.is_empty() {
50            return ComparisonBatch::new(0, schema.fields.len(), vec![]);
51        }
52
53        // compare_batch always runs on CPU: string comparison (Jaro-Winkler) is branch-heavy
54        // and dominated by PCIe transfer overhead on GPU for all observed batch sizes.
55        self.cpu_fallback.compare_batch_from_pool(pool, indices, schema)
56    }
57}
58
59// ── Unit tests ────────────────────────────────────────────────────────────────
60
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        comparison::ComparisonLevel,
        record::{FieldValue, Record},
        record_pool::RecordPool,
        schema::{FieldKind, SchemaBuilder},
    };

    /// Three-field schema (name, date, licence plate) shared by every test.
    fn test_schema() -> Schema {
        SchemaBuilder::new()
            .field("naam", FieldKind::Name)
            .field("datum", FieldKind::Date)
            .field("kenteken", FieldKind::LicensePlate)
            .build()
            .expect("the three-field test schema must build")
    }

    /// Builds a comparator on the explicit CPU backend for `schema`.
    fn cpu_comparator(schema: &Schema) -> DeviceComparator {
        DeviceComparator::new(Arc::new(DeviceBackend::cpu()), schema)
            .expect("constructing a DeviceComparator on the CPU backend must succeed")
    }

    /// Record with the given name and the date/plate shared by both fixtures.
    fn record_named(id: u64, naam: &'static str) -> Record {
        Record::new(id)
            .insert("naam", FieldValue::Text(naam.into()))
            .insert("datum", FieldValue::Text("1990-03-15".into()))
            .insert("kenteken", FieldValue::Text("12-ABC-3".into()))
    }

    /// Fixture record "A".
    fn make_record(id: u64) -> Record {
        record_named(id, "Alice de Vries")
    }

    /// Fixture record "B": same date and plate as "A", slightly different name.
    fn make_record_b(id: u64) -> Record {
        record_named(id, "Alicia de Vrees")
    }

    /// A single `compare` call reports both record ids and one level per field.
    #[test]
    fn single_pair_uses_cpu_path() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let result = cmp.compare(&make_record(1), &make_record_b(2), &schema);

        assert_eq!(result.record_a, 1);
        assert_eq!(result.record_b, 2);
        assert_eq!(result.levels.len(), 3);
    }

    /// Ten pairs drawn from twenty pooled records yield a 10 x 3 batch.
    #[test]
    fn small_batch_uses_cpu_fallback() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        // Even ids get fixture A, odd ids fixture B, so each pair (2k, 2k + 1) is A vs B.
        let records: Vec<Record> = (0..20)
            .map(|id| if id % 2 == 0 { make_record(id) } else { make_record_b(id) })
            .collect();
        let pool = RecordPool::from_records(&records, &schema);
        let indices: Vec<(usize, usize)> = (0..10).map(|k| (k * 2, k * 2 + 1)).collect();

        let batch = cmp.compare_batch_from_pool(&pool, &indices, &schema);

        assert_eq!(batch.n_pairs, 10);
        assert_eq!(batch.n_fields, 3);
        assert_eq!(batch.levels.len(), 3 * 10);
    }

    /// An empty index slice produces a batch with no pairs and no levels.
    #[test]
    fn empty_batch_returns_empty() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);
        let pool = RecordPool::new(schema.fields.len());

        let batch = cmp.compare_batch_from_pool(&pool, &[], &schema);

        assert_eq!(batch.n_pairs, 0);
        assert!(batch.levels.is_empty());
    }

    /// Comparing a record with itself gives `Exact` on every field.
    #[test]
    fn exact_match_produces_exact_levels() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let r = Record::new(1)
            .insert("naam", FieldValue::Text("Jan Jansen".into()))
            .insert("datum", FieldValue::Text("1980-01-01".into()))
            .insert("kenteken", FieldValue::Text("AB-123-C".into()));

        // `compare` only borrows its inputs, so the same record can serve as both sides.
        let result = cmp.compare(&r, &r, &schema);

        for level in &result.levels {
            assert_eq!(*level, ComparisonLevel::Exact, "identical records should give Exact");
        }
    }

    /// Records that differ in every field never reach a level above `Partial`
    /// (both `None` and `Partial` are accepted).
    #[test]
    fn completely_different_records_produce_none_levels() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let a = Record::new(1)
            .insert("naam", FieldValue::Text("Henk".into()))
            .insert("datum", FieldValue::Text("1950-01-01".into()))
            .insert("kenteken", FieldValue::Text("XX-000-X".into()));
        let b = Record::new(2)
            .insert("naam", FieldValue::Text("Zäzä".into()))
            .insert("datum", FieldValue::Text("2010-12-31".into()))
            .insert("kenteken", FieldValue::Text("YY-999-Y".into()));

        let result = cmp.compare(&a, &b, &schema);

        for level in &result.levels {
            assert!(
                matches!(level, ComparisonLevel::None | ComparisonLevel::Partial),
                "very different records should produce None or Partial levels"
            );
        }
    }

    /// Generates `n` records with ids `0..n`, filling each schema field with a
    /// randomly chosen sample value (or `"unknown"` for unrecognised field names).
    ///
    /// The values are drawn from an unseeded thread-local RNG, so callers should
    /// assert only on properties that hold for any draw.
    fn synthetic_records(n: usize, schema: &Schema) -> Vec<Record> {
        use rand::Rng;

        const NAMES: [&str; 12] = [
            "Alice", "Alicia", "Bob", "Robert", "Eva", "Eva-Marie",
            "Jan", "Johan", "Petra", "Pietra", "Lena", "Lena-Marie",
        ];
        const DATES: [&str; 8] = [
            "1990-01-15", "1990-01-16", "1985-06-20", "1975-03-03",
            "2000-12-31", "2001-01-01", "1960-07-07", "1970-11-22",
        ];
        const PLATES: [&str; 8] = [
            "12-ABC-3", "12-ABD-3", "45-XYZ-6", "46-XYZ-6",
            "AB-123-C", "AB-124-C", "ZZ-999-Z", "ZZ-998-Z",
        ];

        let mut rng = rand::thread_rng();

        (0..n)
            .map(|i| {
                // usize -> u64 is lossless on every target Rust supports.
                schema.fields.iter().fold(Record::new(i as u64), |record, field| {
                    let field_name = field.name.as_str();
                    let value = match field_name {
                        "naam" => NAMES[rng.gen_range(0..NAMES.len())],
                        "datum" => DATES[rng.gen_range(0..DATES.len())],
                        "kenteken" => PLATES[rng.gen_range(0..PLATES.len())],
                        _ => "unknown",
                    };
                    record.insert(field_name, FieldValue::Text(value.into()))
                })
            })
            .collect()
    }

    /// With an auto-detected backend, a 2 000-pair batch still comes back with
    /// one row per pair and one column per field.
    #[test]
    fn large_batch_auto_detect_returns_correct_count() {
        let schema = test_schema();
        let backend = Arc::new(DeviceBackend::auto_detect());
        let cmp = DeviceComparator::new(backend, &schema)
            .expect("constructing a DeviceComparator on the auto-detected backend must succeed");

        let records = synthetic_records(4_000, &schema);
        let pool = RecordPool::from_records(&records, &schema);
        let indices: Vec<(usize, usize)> = (0..2_000).map(|k| (k * 2, k * 2 + 1)).collect();

        let batch = cmp.compare_batch_from_pool(&pool, &indices, &schema);

        assert_eq!(batch.n_pairs, 2_000, "compare_batch_from_pool must return one row per pair");
        assert_eq!(batch.n_fields, 3, "each batch must have one column per field");
        // Same shape check the small-batch test performs: one level per (pair, field) cell.
        assert_eq!(batch.levels.len(), 3 * 2_000, "levels must hold one entry per pair and field");
    }
}