1use std::sync::Arc;
8
9use zer_core::{
10 comparison::{ComparisonBatch, ComparisonVector},
11 record::Record,
12 record_pool::RecordPool,
13 schema::Schema,
14 traits::Comparator,
15};
16
17use crate::{
18 backend::{cpu::CpuFallbackComparator, DeviceBackend},
19 error::GpuError,
20};
21
22pub struct DeviceComparator {
23 backend: Arc<DeviceBackend>,
24 cpu_fallback: CpuFallbackComparator,
25}
26
27impl DeviceComparator {
28 pub fn new(backend: Arc<DeviceBackend>, schema: &Schema) -> Result<Self, GpuError> {
29 let cpu_fallback = CpuFallbackComparator::from_schema(schema);
30 Ok(Self {
31 backend,
32 cpu_fallback,
33 })
34 }
35
36 pub fn backend_name(&self) -> &'static str {
37 self.backend.name()
38 }
39}
40
41impl Comparator for DeviceComparator {
42 fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
43 self.cpu_fallback.compare(a, b, schema)
44 }
45
46 fn compare_batch_from_pool(
47 &self,
48 pool: &RecordPool,
49 indices: &[(usize, usize)],
50 schema: &Schema,
51 ) -> ComparisonBatch {
52 if indices.is_empty() {
53 return ComparisonBatch::new(0, schema.fields.len(), vec![]);
54 }
55
56 self.cpu_fallback
59 .compare_batch_from_pool(pool, indices, schema)
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        comparison::ComparisonLevel,
        record::{FieldValue, Record},
        record_pool::RecordPool,
        schema::{FieldKind, SchemaBuilder},
    };

    /// Schema with three fields: a name (`naam`), a date (`datum`) and a licence plate
    /// (`kenteken`). Every comparison over it is expected to yield three levels.
    fn test_schema() -> Schema {
        SchemaBuilder::new()
            .field("naam", FieldKind::Name)
            .field("datum", FieldKind::Date)
            .field("kenteken", FieldKind::LicensePlate)
            .build()
            .unwrap()
    }

    /// Comparator built for `schema` on the CPU backend.
    fn cpu_comparator(schema: &Schema) -> DeviceComparator {
        DeviceComparator::new(Arc::new(DeviceBackend::cpu()), schema).unwrap()
    }

    /// Record whose three `test_schema` fields hold the given text values.
    fn record_with(
        id: u64,
        naam: &'static str,
        datum: &'static str,
        kenteken: &'static str,
    ) -> Record {
        Record::new(id)
            .insert("naam", FieldValue::Text(naam.into()))
            .insert("datum", FieldValue::Text(datum.into()))
            .insert("kenteken", FieldValue::Text(kenteken.into()))
    }

    /// Baseline record used by the pairwise tests.
    fn make_record(id: u64) -> Record {
        record_with(id, "Alice de Vries", "1990-03-15", "12-ABC-3")
    }

    /// Near-duplicate of `make_record`: same date and plate, slightly different name.
    fn make_record_b(id: u64) -> Record {
        record_with(id, "Alicia de Vrees", "1990-03-15", "12-ABC-3")
    }

    #[test]
    fn single_pair_uses_cpu_path() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let vec = cmp.compare(&make_record(1), &make_record_b(2), &schema);

        assert_eq!(vec.record_a, 1);
        assert_eq!(vec.record_b, 2);
        assert_eq!(vec.levels.len(), 3);
    }

    #[test]
    fn small_batch_uses_cpu_fallback() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        // Alternate baseline / near-duplicate so every (2i, 2i + 1) pair is a fuzzy match.
        let records: Vec<Record> = (0..20)
            .map(|i| {
                if i % 2 == 0 {
                    make_record(i)
                } else {
                    make_record_b(i)
                }
            })
            .collect();
        let pool = RecordPool::from_records(&records, &schema);
        let indices: Vec<(usize, usize)> = (0..10).map(|i| (i * 2, i * 2 + 1)).collect();

        let batch = cmp.compare_batch_from_pool(&pool, &indices, &schema);
        assert_eq!(batch.n_pairs, 10);
        assert_eq!(batch.n_fields, 3);
        assert_eq!(batch.levels.len(), 3 * 10);
    }

    #[test]
    fn empty_batch_returns_empty() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);
        let pool = RecordPool::new(schema.fields.len());

        let batch = cmp.compare_batch_from_pool(&pool, &[], &schema);

        assert_eq!(batch.n_pairs, 0);
        // The empty short-circuit must still report the schema's field count.
        assert_eq!(batch.n_fields, 3);
        assert!(batch.levels.is_empty());
    }

    #[test]
    fn exact_match_produces_exact_levels() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let r = record_with(1, "Jan Jansen", "1980-01-01", "AB-123-C");
        // Comparing a record with itself only needs two shared borrows; no clone.
        let vec = cmp.compare(&r, &r, &schema);

        // Without this, an empty `levels` would make the loop below pass vacuously.
        assert_eq!(vec.levels.len(), 3);
        for level in &vec.levels {
            assert_eq!(
                *level,
                ComparisonLevel::Exact,
                "identical records should give Exact"
            );
        }
    }

    #[test]
    fn completely_different_records_produce_none_levels() {
        let schema = test_schema();
        let cmp = cpu_comparator(&schema);

        let a = record_with(1, "Henk", "1950-01-01", "XX-000-X");
        let b = record_with(2, "Zäzä", "2010-12-31", "YY-999-Y");

        let vec = cmp.compare(&a, &b, &schema);

        // Without this, an empty `levels` would make the loop below pass vacuously.
        assert_eq!(vec.levels.len(), 3);
        for level in &vec.levels {
            assert!(
                matches!(level, ComparisonLevel::None | ComparisonLevel::Partial),
                "very different records should produce None or Partial levels"
            );
        }
    }

    /// Generates `n` records whose fields (taken from `schema`) are drawn at random from
    /// small pools of similar-looking names, dates and plates. Fields other than
    /// `naam`/`datum`/`kenteken` are filled with `"unknown"`.
    fn synthetic_records(n: usize, schema: &Schema) -> Vec<Record> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        let names = [
            "Alice",
            "Alicia",
            "Bob",
            "Robert",
            "Eva",
            "Eva-Marie",
            "Jan",
            "Johan",
            "Petra",
            "Pietra",
            "Lena",
            "Lena-Marie",
        ];
        let dates = [
            "1990-01-15",
            "1990-01-16",
            "1985-06-20",
            "1975-03-03",
            "2000-12-31",
            "2001-01-01",
            "1960-07-07",
            "1970-11-22",
        ];
        let plates = [
            "12-ABC-3", "12-ABD-3", "45-XYZ-6", "46-XYZ-6", "AB-123-C", "AB-124-C", "ZZ-999-Z",
            "ZZ-998-Z",
        ];

        let fields: Vec<&str> = schema.fields.iter().map(|f| f.name.as_str()).collect();

        (0..n)
            .map(|i| {
                let mut r = Record::new(i as u64);
                for field in &fields {
                    let val = match *field {
                        "naam" => names[rng.gen_range(0..names.len())],
                        "datum" => dates[rng.gen_range(0..dates.len())],
                        "kenteken" => plates[rng.gen_range(0..plates.len())],
                        _ => "unknown",
                    };
                    r = r.insert(*field, FieldValue::Text(val.into()));
                }
                r
            })
            .collect()
    }

    #[test]
    fn large_batch_auto_detect_returns_correct_count() {
        let schema = test_schema();
        let backend = Arc::new(DeviceBackend::auto_detect());
        let cmp = DeviceComparator::new(backend, &schema).unwrap();

        let records = synthetic_records(4_000, &schema);
        let pool = RecordPool::from_records(&records, &schema);
        let indices: Vec<(usize, usize)> = (0..2_000).map(|i| (i * 2, i * 2 + 1)).collect();

        let batch = cmp.compare_batch_from_pool(&pool, &indices, &schema);

        assert_eq!(
            batch.n_pairs, 2_000,
            "compare_batch_from_pool must return one row per pair"
        );
        assert_eq!(
            batch.n_fields, 3,
            "each batch must have one column per field"
        );
        assert_eq!(
            batch.levels.len(),
            3 * 2_000,
            "levels buffer must hold n_fields entries per pair"
        );
    }
}