1use std::sync::Arc;
4
5use zer_core::{
6 comparison::{ComparisonBatch, ComparisonVector},
7 scoring::{ModelParams, ScoredPair},
8 traits::{Result, Scorer},
9};
10
11use crate::{
12 backend::{cpu::CpuFallbackScorer, DeviceBackend},
13 error::GpuError,
14};
15
/// Minimum number of comparison pairs before EM estimation is attempted on an
/// accelerated backend; smaller batches go straight to the CPU estimator.
pub(crate) const EM_GPU_MIN_PAIRS: usize = 50 * 1_000;
18
/// [`Scorer`] that pairs a shared [`DeviceBackend`] with a [`CpuFallbackScorer`].
///
/// Single-pair and batch scoring always go through the CPU scorer; only EM
/// parameter estimation may be dispatched to the backend.
pub struct DeviceScorer {
    /// Shared handle to the compute backend used for accelerated EM.
    backend: Arc<DeviceBackend>,
    /// CPU scorer used for `score`, `score_batch`, and as the EM fallback.
    cpu_fallback: CpuFallbackScorer,
}
23
24impl DeviceScorer {
25 pub fn new(backend: Arc<DeviceBackend>) -> Self {
26 Self {
27 backend,
28 cpu_fallback: CpuFallbackScorer,
29 }
30 }
31
32 pub fn backend_name(&self) -> &'static str {
33 self.backend.name()
34 }
35}
36
37impl Scorer for DeviceScorer {
38 fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
39 self.cpu_fallback.score(vector, params)
40 }
41
42 fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
43 self.cpu_fallback.score_batch(batch, params)
44 }
45
46 fn estimate_params(
47 &self,
48 batch: &ComparisonBatch,
49 init: Option<ModelParams>,
50 max_iter: usize,
51 ) -> Result<ModelParams> {
52 if self.backend.is_accelerated() && batch.n_pairs >= EM_GPU_MIN_PAIRS {
53 let result = zer_prof::trace!("zer_compute::estimate_params_accelerated", {
54 gpu_em_estimate(&self.backend, batch, init.clone(), max_iter)
55 });
56 match result {
57 Ok(params) => {
58 tracing::info!(backend = %self.backend.name(), "EM converged via accelerated backend");
59 return Ok(params);
60 }
61 Err(e) => {
62 tracing::warn!(%e, backend = %self.backend.name(), "accelerated EM failed, falling back to CPU");
63 }
64 }
65 } else if self.backend.is_accelerated() {
66 tracing::debug!(
67 n_pairs = batch.n_pairs,
68 threshold = EM_GPU_MIN_PAIRS,
69 "EM: batch below GPU threshold, using CPU path"
70 );
71 }
72 self.cpu_fallback.estimate_params(batch, init, max_iter)
73 }
74}
75
/// Builds the flat E-step weight table uploaded to the device.
///
/// Entry `f * 4 + l` is `ln(m[f][l] / u[f][l])` for the first `n_fields`
/// fields and 4 levels each. Both probabilities are floored at `1e-15` first,
/// so neither side of the ratio is ever zero.
///
/// # Panics
///
/// Panics if `params.m` / `params.u` have fewer than `n_fields` rows or any
/// row has fewer than 4 levels.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn build_estep_weights(params: &ModelParams, n_fields: usize) -> Vec<f32> {
    const LEVELS: usize = 4;
    const FLOOR: f32 = 1e-15;

    let mut weights = Vec::with_capacity(n_fields * LEVELS);
    weights.extend((0..n_fields).flat_map(|field| {
        (0..LEVELS).map(move |level| {
            let m = params.m[field][level].max(FLOOR);
            let u = params.u[field][level].max(FLOOR);
            (m / u).ln()
        })
    }));
    weights
}
91
/// Runs EM parameter estimation on `backend`.
///
/// The function proceeds in three cases:
///
/// * An empty batch is rejected with an error.
/// * A backend that is not a GPU delegates to
///   `crate::backend::cpu::cpu_estimate_params`. Its error is wrapped as
///   `GpuError::LaunchFailed`.
/// * Otherwise the comparison levels are uploaded once into a backend EM
///   session. Each iteration then runs `em_run_iteration` on the device and
///   normalizes the returned counts on the host with `em_normalize`. This
///   repeats until `em_converged` reports convergence or `max_iter`
///   iterations have run. The last parameters are returned even if
///   convergence was not reached.
///
/// When `init` is `None`, the starting m/u tables are fixed and the starting
/// prior odds come from `zer_compare::em::estimate_lambda`.
///
/// NOTE(review): after at least one iteration, the returned thresholds are
/// the fixed 0.9 / 0.1 set by `em_normalize`, even if `init` carried other
/// thresholds. Confirm this matches the CPU estimator.
///
/// # Errors
///
/// Returns a `GpuError` in any of these cases:
///
/// * the batch is empty;
/// * the CPU delegate fails;
/// * the session cannot be created;
/// * an iteration fails;
/// * an iteration produces non-finite m/u probabilities.
///
/// The last case lets the caller fall back to the CPU path instead of
/// receiving NaN parameters.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn gpu_em_estimate(
    backend: &DeviceBackend,
    batch: &ComparisonBatch,
    init: Option<ModelParams>,
    max_iter: usize,
) -> std::result::Result<ModelParams, GpuError> {
    if batch.n_pairs == 0 {
        return Err(GpuError::LaunchFailed(
            "EM requires at least one comparison pair".into(),
        ));
    }

    if !backend.is_gpu() {
        return crate::backend::cpu::cpu_estimate_params(batch, init, max_iter)
            .map_err(|e| GpuError::LaunchFailed(e.to_string()));
    }

    let n_fields = batch.n_fields;
    let n_pairs = batch.n_pairs;

    // The device session takes the levels as u32.
    let comparison_levels: Vec<u32> = batch.levels.iter().map(|&l| l as u32).collect();

    let mut params = init.unwrap_or_else(|| {
        let lambda = zer_compare::em::estimate_lambda(batch);
        let log_prior_odds = (lambda / (1.0 - lambda)).ln();
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    });

    let mut session = zer_prof::trace!("zer_compute::em_init_session", {
        backend.em_init_session(&comparison_levels, n_pairs, n_fields)
    })?;

    // The iterations run inside a closure so that `?` exits only the closure.
    // This guarantees the session is released below on success and on error.
    let result: std::result::Result<ModelParams, GpuError> = (|| {
        for iter in 0..max_iter {
            let weights = build_estep_weights(&params, n_fields);

            let out = zer_prof::trace!("zer_compute::em_full_iteration", {
                backend.em_run_iteration(&mut session, &weights, params.log_prior_odds)
            })?;

            let new_params = em_normalize(
                &out.m_counts,
                &out.u_counts,
                out.total_match,
                out.total_nonmatch,
                n_fields,
            );

            // Non-finite probabilities can come from non-finite device counts,
            // e.g. after an infinite initial prior odds. Without this check
            // they would be returned as a valid estimate. Report them instead,
            // so the caller falls back to the CPU path.
            let finite = new_params.log_prior_odds.is_finite()
                && new_params
                    .m
                    .iter()
                    .chain(&new_params.u)
                    .flatten()
                    .all(|p| p.is_finite());
            if !finite {
                return Err(GpuError::LaunchFailed(format!(
                    "EM iteration {iter} produced non-finite parameters"
                )));
            }

            if em_converged(&params, &new_params, n_fields) {
                return Ok(new_params);
            }
            params = new_params;
        }
        Ok(params)
    })();

    backend.em_drop_session(session);
    result
}
165
166#[cfg(not(any(feature = "cuda", feature = "vulkan", feature = "avx2")))]
167fn gpu_em_estimate(
168 _backend: &DeviceBackend,
169 _batch: &ComparisonBatch,
170 _init: Option<ModelParams>,
171 _max_iter: usize,
172) -> std::result::Result<ModelParams, GpuError> {
173 Err(GpuError::BackendUnavailable(
174 "full-GPU EM requires the cuda or vulkan feature".into(),
175 ))
176}
177
/// Host-side normalization step: turns expected counts returned by the device
/// into smoothed [`ModelParams`].
///
/// `m_counts` / `u_counts` are read as `[field][level]` tables with 4 levels
/// per field (index `f * 4 + l`), for the first `n_fields` fields.
///
/// * Each count gets additive smoothing of `ALPHA`. It is then divided by its
///   class total plus `4 * ALPHA`, floored at `1e-9`.
/// * The match rate `total_match / max(total_match + total_nonmatch, 1)` is
///   clamped to `[0.001, 0.999]` and then converted to log prior odds.
/// * The thresholds are always set to 0.9 / 0.1.
///
/// # Panics
///
/// Panics if either count slice has fewer than `n_fields * 4` entries.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn em_normalize(
    m_counts: &[f32],
    u_counts: &[f32],
    total_match: f32,
    total_nonmatch: f32,
    n_fields: usize,
) -> ModelParams {
    const ALPHA: f32 = 1e-3;
    const LEVELS: usize = 4;

    // Smooths one `[field][level]` count table into per-field probability rows.
    let smooth = |counts: &[f32], class_total: f32| -> Vec<Vec<f32>> {
        let denom = (class_total + LEVELS as f32 * ALPHA).max(1e-9_f32);
        counts[..n_fields * LEVELS]
            .chunks_exact(LEVELS)
            .map(|row| row.iter().map(|&c| (c + ALPHA) / denom).collect())
            .collect()
    };

    let m = smooth(m_counts, total_match);
    let u = smooth(u_counts, total_nonmatch);

    // `max().min()` rather than `clamp()`: a NaN rate maps to 0.001 here,
    // whereas `clamp` would propagate NaN.
    let match_rate = {
        let n_total = (total_match + total_nonmatch).max(1.0_f32);
        (total_match / n_total).max(0.001_f32).min(0.999_f32)
    };

    ModelParams {
        m,
        u,
        log_prior_odds: (match_rate / (1.0 - match_rate)).ln(),
        upper_threshold: 0.9,
        lower_threshold: 0.1,
    }
}
219
/// Returns `true` when every `m` and `u` entry of the first `n_fields` fields
/// (4 levels each) differs by less than `1e-6` between `old` and `new`.
///
/// A NaN difference counts as "not converged". A `max`-based fold would
/// silently drop it, because `f32::max` returns the non-NaN operand. That
/// could declare NaN parameters converged.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn em_converged(old: &ModelParams, new: &ModelParams, n_fields: usize) -> bool {
    const TOL: f32 = 1e-6;
    const LEVELS: usize = 4;

    // `d < TOL` is false for NaN, so any NaN entry blocks convergence.
    let within_tol = |a: f32, b: f32| (a - b).abs() < TOL;
    (0..n_fields).all(|f| {
        (0..LEVELS).all(|l| {
            within_tol(old.m[f][l], new.m[f][l]) && within_tol(old.u[f][l], new.u[f][l])
        })
    })
}
234
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
        scoring::{MatchBand, ModelParams},
    };

    /// Fixed parameters for `n_fields` fields with even prior odds.
    ///
    /// Level index 3 (treated as Exact throughout these tests) strongly
    /// favours a match. Level index 0 strongly favours a non-match.
    fn uniform_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.05, 0.10, 0.15, 0.70]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: 0.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    /// A single pair whose fields all compare `Exact`.
    fn all_exact_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact; n_fields])
    }

    /// A single pair whose fields all compare `None`.
    fn all_none_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(3, 4, vec![ComparisonLevel::None; n_fields])
    }

    /// Builds `n_matches` all-`Exact` pairs followed by `n_nonmatches`
    /// all-`None` pairs, each pair with its own distinct record ids.
    fn separable_batch(n_matches: usize, n_nonmatches: usize, n_fields: usize) -> ComparisonBatch {
        let mut v = Vec::with_capacity(n_matches + n_nonmatches);
        for i in 0..n_matches as u64 {
            v.push(ComparisonVector::new(
                i * 2,
                i * 2 + 1,
                vec![ComparisonLevel::Exact; n_fields],
            ));
        }
        let off = (n_matches as u64) * 2;
        for i in 0..n_nonmatches as u64 {
            v.push(ComparisonVector::new(
                off + i * 2,
                off + i * 2 + 1,
                vec![ComparisonLevel::None; n_fields],
            ));
        }
        ComparisonBatch::from_vectors(&v)
    }

    #[test]
    fn score_exact_match_gives_high_probability() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = uniform_params(3);
        let v = all_exact_vector(3);
        let pair = scorer.score(&v, &params);

        assert!(
            pair.match_probability > 0.9,
            "all-Exact vector should have high match_probability, got {}",
            pair.match_probability
        );
        assert_eq!(pair.band, MatchBand::AutoMatch);
    }

    #[test]
    fn score_none_gives_low_probability() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = uniform_params(3);
        let v = all_none_vector(3);
        let pair = scorer.score(&v, &params);

        assert!(
            pair.match_probability < 0.1,
            "all-None vector should have low match_probability, got {}",
            pair.match_probability
        );
        assert_eq!(pair.band, MatchBand::AutoReject);
    }

    #[test]
    fn score_batch_matches_individual_scores() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = uniform_params(4);
        let vectors = vec![
            all_exact_vector(4),
            all_none_vector(4),
            ComparisonVector::new(
                5,
                6,
                vec![
                    ComparisonLevel::Exact,
                    ComparisonLevel::None,
                    ComparisonLevel::Close,
                    ComparisonLevel::Partial,
                ],
            ),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        let batch_results = scorer.score_batch(&batch, &params);

        for (v, br) in vectors.iter().zip(batch_results.iter()) {
            let single = scorer.score(v, &params);
            assert!(
                (single.match_probability - br.match_probability).abs() < 1e-6,
                "batch and individual scores must agree"
            );
        }
    }

    #[test]
    fn estimate_params_converges_on_separable_data() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let n_fields = 4;
        let batch = separable_batch(200, 1_000, n_fields);

        let params = scorer
            .estimate_params(&batch, None, 30)
            .expect("EM should not return an error");

        for f in 0..n_fields {
            assert!(
                params.m[f][3] > params.u[f][3],
                "m[Exact] should exceed u[Exact] for separable data (field {f})"
            );
        }
    }

    #[test]
    fn estimate_params_returns_error_on_empty_input() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let batch = ComparisonBatch::new(0, 0, vec![]);
        let result = scorer.estimate_params(&batch, None, 10);
        assert!(result.is_err(), "empty input should return an error");
    }

    #[test]
    fn weight_table_is_consistent_with_params() {
        use crate::soa::build_weight_table;

        let params = uniform_params(3);
        let table = build_weight_table(&params);

        // Field 0, level index 3 (Exact): flat index 0 * 4 + 3.
        let weight_exact = table[3];
        let expected = (0.70_f32 / 0.05_f32).ln();
        assert!(
            (weight_exact - expected).abs() < 1e-5,
            "weight_table Exact entry mismatch: {weight_exact} vs {expected}"
        );
    }

    #[test]
    fn em_cpu_path_correct_below_threshold() {
        let batch = separable_batch(200, 800, 4);
        assert!(batch.n_pairs < EM_GPU_MIN_PAIRS);

        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = scorer.estimate_params(&batch, None, 30).unwrap();
        for f in 0..4 {
            assert!(
                params.m[f][3] > params.u[f][3],
                "field {f}: m[Exact] must exceed u[Exact]"
            );
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn em_gpu_path_correct_above_threshold() {
        let n_fields = 4;
        let n_matches = EM_GPU_MIN_PAIRS / 5;
        let n_nonmatches = EM_GPU_MIN_PAIRS;
        let batch = separable_batch(n_matches, n_nonmatches, n_fields);
        assert!(batch.n_pairs >= EM_GPU_MIN_PAIRS);

        let params = gpu_em_estimate(&DeviceBackend::auto_detect(), &batch, None, 50)
            .expect("gpu_em_estimate must not fail");
        for f in 0..n_fields {
            assert!(
                params.m[f][3] > params.u[f][3],
                "field {f}: m[Exact] must exceed u[Exact]"
            );
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn em_gpu_cpu_agree_on_key_parameters() {
        let n_fields = 4;
        let n_matches = EM_GPU_MIN_PAIRS / 5;
        let n_nonmatches = EM_GPU_MIN_PAIRS;
        let batch = separable_batch(n_matches, n_nonmatches, n_fields);
        assert!(batch.n_pairs >= EM_GPU_MIN_PAIRS);

        let cpu_params = gpu_em_estimate(&DeviceBackend::cpu(), &batch, None, 50).unwrap();
        let gpu_params = gpu_em_estimate(&DeviceBackend::auto_detect(), &batch, None, 50).unwrap();

        for f in 0..n_fields {
            assert!(
                cpu_params.m[f][3] > cpu_params.u[f][3],
                "CPU path field {f}: m[Exact] must exceed u[Exact]"
            );
            assert!(
                gpu_params.m[f][3] > gpu_params.u[f][3],
                "GPU path field {f}: m[Exact] must exceed u[Exact]"
            );

            let dm_exact = (cpu_params.m[f][3] - gpu_params.m[f][3]).abs();
            let du_exact = (cpu_params.u[f][3] - gpu_params.u[f][3]).abs();
            assert!(
                dm_exact < 0.15,
                "field {f}: CPU/GPU m[Exact] differ by {dm_exact:.4} (cpu={:.4}, gpu={:.4})",
                cpu_params.m[f][3],
                gpu_params.m[f][3]
            );
            assert!(
                du_exact < 0.15,
                "field {f}: CPU/GPU u[Exact] differ by {du_exact:.4} (cpu={:.4}, gpu={:.4})",
                cpu_params.u[f][3],
                gpu_params.u[f][3]
            );
        }

        assert!(
            cpu_params.log_prior_odds < 0.0,
            "CPU log_prior_odds should be negative for rare matches: {}",
            cpu_params.log_prior_odds
        );
        assert!(
            gpu_params.log_prior_odds < 0.0,
            "GPU log_prior_odds should be negative for rare matches: {}",
            gpu_params.log_prior_odds
        );
        let dlpo = (cpu_params.log_prior_odds - gpu_params.log_prior_odds).abs();
        assert!(
            dlpo < 1.0,
            "log_prior_odds differ too much: cpu={:.4}, gpu={:.4}",
            cpu_params.log_prior_odds,
            gpu_params.log_prior_odds
        );
    }

    #[test]
    fn em_cpu_log_prior_odds_tracks_match_rate() {
        let n_fields = 2;
        let batch = separable_batch(100, 900, n_fields);
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = scorer.estimate_params(&batch, None, 50).unwrap();

        assert!(
            params.log_prior_odds < 0.0,
            "log_prior_odds must be negative for 10% match rate: {}",
            params.log_prior_odds
        );
        assert!(
            params.log_prior_odds > -5.0,
            "log_prior_odds too negative for 10% match rate: {}",
            params.log_prior_odds
        );
    }

    #[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
    #[test]
    fn em_normalize_updates_log_prior_odds() {
        let m_counts = vec![25.0_f32, 25.0, 25.0, 25.0];
        let u_counts = vec![225.0_f32, 225.0, 225.0, 225.0];
        let total_match = 100.0_f32;
        let total_nonmatch = 900.0_f32;
        let params = em_normalize(&m_counts, &u_counts, total_match, total_nonmatch, 1);

        let expected_lpo = (0.1_f32 / 0.9_f32).ln();
        assert!(
            (params.log_prior_odds - expected_lpo).abs() < 0.01,
            "log_prior_odds mismatch: got {:.4}, expected {:.4}",
            params.log_prior_odds,
            expected_lpo
        );
    }

    #[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
    #[test]
    fn em_converged_uses_raw_delta() {
        let n_fields = 2;

        let p1 = ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };

        let mut p2 = p1.clone();
        p2.m[0][3] += 5e-7;
        assert!(
            em_converged(&p1, &p2, n_fields),
            "should converge for delta < 1e-6"
        );

        let mut p3 = p1.clone();
        p3.m[0][3] += 2e-6;
        assert!(
            !em_converged(&p1, &p3, n_fields),
            "should not converge for delta > 1e-6"
        );
    }
}
546}