1use std::sync::Arc;
8
9use zer_core::{
10 comparison::{ComparisonBatch, ComparisonVector},
11 record::Record,
12 record_pool::RecordPool,
13 schema::Schema,
14 traits::Comparator,
15};
16
17use crate::{
18 backend::{cpu::CpuFallbackComparator, DeviceBackend},
19 error::GpuError,
20};
21
22pub struct DeviceComparator {
23 backend: Arc<DeviceBackend>,
24 cpu_fallback: CpuFallbackComparator,
25}
26
27impl DeviceComparator {
28 pub fn new(backend: Arc<DeviceBackend>, schema: &Schema) -> Result<Self, GpuError> {
29 let cpu_fallback = CpuFallbackComparator::from_schema(schema);
30 Ok(Self { backend, cpu_fallback })
31 }
32
33 pub fn backend_name(&self) -> &'static str {
34 self.backend.name()
35 }
36}
37
38impl Comparator for DeviceComparator {
39 fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
40 self.cpu_fallback.compare(a, b, schema)
41 }
42
43 fn compare_batch_from_pool(
44 &self,
45 pool: &RecordPool,
46 indices: &[(usize, usize)],
47 schema: &Schema,
48 ) -> ComparisonBatch {
49 if indices.is_empty() {
50 return ComparisonBatch::new(0, schema.fields.len(), vec![]);
51 }
52
53 self.cpu_fallback.compare_batch_from_pool(pool, indices, schema)
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        comparison::ComparisonLevel,
        record::{FieldValue, Record},
        record_pool::RecordPool,
        schema::{FieldKind, SchemaBuilder},
    };

    /// Three-field schema (name, date, licence plate) shared by every test.
    fn test_schema() -> Schema {
        SchemaBuilder::new()
            .field("naam", FieldKind::Name)
            .field("datum", FieldKind::Date)
            .field("kenteken", FieldKind::LicensePlate)
            .build()
            .expect("test schema must be valid")
    }

    /// Comparator backed by the CPU device backend.
    fn cpu_comparator(schema: &Schema) -> DeviceComparator {
        DeviceComparator::new(Arc::new(DeviceBackend::cpu()), schema)
            .expect("constructing a CPU-backed comparator should not fail")
    }

    /// Builds a record with text values for the three schema fields.
    fn record_with(
        id: u64,
        naam: &'static str,
        datum: &'static str,
        kenteken: &'static str,
    ) -> Record {
        Record::new(id)
            .insert("naam", FieldValue::Text(naam.into()))
            .insert("datum", FieldValue::Text(datum.into()))
            .insert("kenteken", FieldValue::Text(kenteken.into()))
    }

    fn make_record(id: u64) -> Record {
        record_with(id, "Alice de Vries", "1990-03-15", "12-ABC-3")
    }

    /// Near-duplicate of [`make_record`]: only the name is slightly different.
    fn make_record_b(id: u64) -> Record {
        record_with(id, "Alicia de Vrees", "1990-03-15", "12-ABC-3")
    }

    #[test]
    fn single_pair_uses_cpu_path() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let cv = cmp.compare(&make_record(1), &make_record_b(2), &schema);

        assert_eq!(cv.record_a, 1);
        assert_eq!(cv.record_b, 2);
        assert_eq!(cv.levels.len(), 3);
    }

    #[test]
    fn small_batch_uses_cpu_fallback() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let records: Vec<Record> = (0..20)
            .map(|i| if i % 2 == 0 { make_record(i) } else { make_record_b(i) })
            .collect();
        let pool = RecordPool::from_records(&records, &schema);
        let indices: Vec<(usize, usize)> = (0..10).map(|i| (2 * i, 2 * i + 1)).collect();

        let batch = cmp.compare_batch_from_pool(&pool, &indices, &schema);

        assert_eq!(batch.n_pairs, 10);
        assert_eq!(batch.n_fields, 3);
        assert_eq!(batch.levels.len(), 3 * 10);
    }

    #[test]
    fn empty_batch_returns_empty() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);
        let pool = RecordPool::new(schema.fields.len());

        let batch = cmp.compare_batch_from_pool(&pool, &[], &schema);

        assert_eq!(batch.n_pairs, 0);
        assert_eq!(batch.n_fields, 3);
        assert!(batch.levels.is_empty());
    }

    #[test]
    fn exact_match_produces_exact_levels() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let r = record_with(1, "Jan Jansen", "1980-01-01", "AB-123-C");
        let cv = cmp.compare(&r, &r, &schema);

        // Guard against the loop below passing vacuously on an empty result.
        assert_eq!(cv.levels.len(), 3, "expected one level per schema field");
        for level in &cv.levels {
            assert_eq!(*level, ComparisonLevel::Exact, "identical records should give Exact");
        }
    }

    #[test]
    fn completely_different_records_produce_none_levels() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let a = record_with(1, "Henk", "1950-01-01", "XX-000-X");
        let b = record_with(2, "Zäzä", "2010-12-31", "YY-999-Y");
        let cv = cmp.compare(&a, &b, &schema);

        // Guard against the loop below passing vacuously on an empty result.
        assert_eq!(cv.levels.len(), 3, "expected one level per schema field");
        for level in &cv.levels {
            assert!(
                matches!(level, ComparisonLevel::None | ComparisonLevel::Partial),
                "very different records should produce None or Partial levels"
            );
        }
    }

    /// Builds `n` records whose field values are drawn from small pools of
    /// near-duplicate values.
    ///
    /// Values are chosen by a fixed-seed linear congruential generator, so the
    /// fixture is identical on every run (a failing run can be reproduced) and
    /// needs no external RNG crate.
    fn synthetic_records(n: usize, schema: &Schema) -> Vec<Record> {
        const NAMES: [&str; 12] = [
            "Alice", "Alicia", "Bob", "Robert", "Eva", "Eva-Marie",
            "Jan", "Johan", "Petra", "Pietra", "Lena", "Lena-Marie",
        ];
        const DATES: [&str; 8] = [
            "1990-01-15", "1990-01-16", "1985-06-20", "1975-03-03",
            "2000-12-31", "2001-01-01", "1960-07-07", "1970-11-22",
        ];
        const PLATES: [&str; 8] = [
            "12-ABC-3", "12-ABD-3", "45-XYZ-6", "46-XYZ-6",
            "AB-123-C", "AB-124-C", "ZZ-999-Z", "ZZ-998-Z",
        ];

        // Knuth's MMIX LCG constants; only used to spread picks across the pools.
        let mut state: u64 = 0x5EED_1234_ABCD_0001;
        let mut pick = |len: usize| -> usize {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // `state >> 33` fits in 31 bits, so the cast is lossless on any platform.
            ((state >> 33) as usize) % len
        };

        let fields: Vec<&str> = schema.fields.iter().map(|f| f.name.as_str()).collect();

        (0..n)
            .map(|i| {
                fields.iter().fold(Record::new(i as u64), |record, &field| {
                    let value = match field {
                        "naam" => NAMES[pick(NAMES.len())],
                        "datum" => DATES[pick(DATES.len())],
                        "kenteken" => PLATES[pick(PLATES.len())],
                        _ => "unknown",
                    };
                    record.insert(field, FieldValue::Text(value.into()))
                })
            })
            .collect()
    }

    #[test]
    fn large_batch_auto_detect_returns_correct_count() {
        let schema = test_schema();
        let backend = Arc::new(DeviceBackend::auto_detect());
        let cmp = DeviceComparator::new(backend, &schema)
            .expect("constructing an auto-detected comparator should not fail");

        let records = synthetic_records(4_000, &schema);
        let pool = RecordPool::from_records(&records, &schema);
        let indices: Vec<(usize, usize)> = (0..2_000).map(|i| (2 * i, 2 * i + 1)).collect();

        let batch = cmp.compare_batch_from_pool(&pool, &indices, &schema);

        assert_eq!(batch.n_pairs, 2_000, "compare_batch_from_pool must return one row per pair");
        assert_eq!(batch.n_fields, 3, "each batch must have one column per field");
        assert_eq!(batch.levels.len(), 3 * 2_000, "levels must hold n_pairs * n_fields entries");
    }
}