Skip to main content

zer_compute/
scorer.rs

1//! `DeviceScorer`, implements the `Scorer` trait with GPU-accelerated EM.
2
3use std::sync::Arc;
4
5use zer_core::{
6    comparison::{ComparisonBatch, ComparisonVector},
7    scoring::{ModelParams, ScoredPair},
8    traits::{Result, Scorer},
9};
10
11use crate::{
12    backend::{cpu::CpuFallbackScorer, DeviceBackend},
13    error::GpuError,
14};
15
/// Smallest batch size (in comparison pairs) for which `estimate_params`
/// attempts the accelerated EM path; smaller batches go straight to the CPU
/// implementation.
pub(crate) const EM_GPU_MIN_PAIRS: usize = 50_000;
18
19pub struct DeviceScorer {
20    backend: Arc<DeviceBackend>,
21    cpu_fallback: CpuFallbackScorer,
22}
23
24impl DeviceScorer {
25    pub fn new(backend: Arc<DeviceBackend>) -> Self {
26        Self {
27            backend,
28            cpu_fallback: CpuFallbackScorer,
29        }
30    }
31
32    pub fn backend_name(&self) -> &'static str {
33        self.backend.name()
34    }
35}
36
37impl Scorer for DeviceScorer {
38    fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
39        self.cpu_fallback.score(vector, params)
40    }
41
42    fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
43        self.cpu_fallback.score_batch(batch, params)
44    }
45
46    fn estimate_params(
47        &self,
48        batch: &ComparisonBatch,
49        init: Option<ModelParams>,
50        max_iter: usize,
51    ) -> Result<ModelParams> {
52        if self.backend.is_accelerated() && batch.n_pairs >= EM_GPU_MIN_PAIRS {
53            let result = zer_prof::trace!("zer_compute::estimate_params_accelerated", {
54                gpu_em_estimate(&self.backend, batch, init.clone(), max_iter)
55            });
56            match result {
57                Ok(params) => {
58                    tracing::info!(backend = %self.backend.name(), "EM converged via accelerated backend");
59                    return Ok(params);
60                }
61                Err(e) => {
62                    tracing::warn!(%e, backend = %self.backend.name(), "accelerated EM failed, falling back to CPU");
63                }
64            }
65        } else if self.backend.is_accelerated() {
66            tracing::debug!(
67                n_pairs = batch.n_pairs,
68                threshold = EM_GPU_MIN_PAIRS,
69                "EM: batch below GPU threshold, using CPU path"
70            );
71        }
72        self.cpu_fallback.estimate_params(batch, init, max_iter)
73    }
74}
75
76// ── GPU EM loop ───────────────────────────────────────────────────────────────
77
/// Flatten the model into a field-major table of log-likelihood ratios.
///
/// Entry `field * 4 + level` holds `ln(m / u)` for that field and level. Both
/// probabilities are floored at `1e-15` before the division so the logarithm
/// never sees a zero numerator or denominator.
///
/// # Panics
///
/// Panics if `params.m` or `params.u` has fewer than `n_fields` rows or any
/// of those rows has fewer than 4 entries.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn build_estep_weights(params: &ModelParams, n_fields: usize) -> Vec<f32> {
    const LEVELS: usize = 4;
    const FLOOR: f32 = 1e-15_f32;

    let mut weights = Vec::with_capacity(n_fields * LEVELS);
    for field in 0..n_fields {
        let m_row = &params.m[field];
        let u_row = &params.u[field];
        weights.extend((0..LEVELS).map(|level| {
            let m = m_row[level].max(FLOOR);
            let u = u_row[level].max(FLOOR);
            (m / u).ln()
        }));
    }
    weights
}
91
/// Run the full EM algorithm on the accelerated backend.
///
/// The batch's comparison levels are widened from `u8` to `u32` once and
/// handed to the backend when the session is created; the `ComparisonBatch`
/// is already field-major, so no transposition is needed. Each iteration then
/// sends the current log-likelihood-ratio weights and prior odds, receives
/// the accumulated counts, and re-normalises them into new parameters until
/// the parameters stop moving or `max_iter` iterations have run.
///
/// If the backend is not a GPU, the work is delegated to
/// `cpu_estimate_params` and its error (if any) is wrapped in
/// [`GpuError::LaunchFailed`].
///
/// # Errors
///
/// Returns an error when
/// * the batch contains no pairs,
/// * the starting parameters do not cover `n_fields` fields × 4 levels,
/// * session creation or any iteration fails on the backend, or
/// * an iteration returns fewer than `n_fields * 4` counts.
///
/// The shape checks turn what would otherwise be an index-out-of-bounds panic
/// into an `Err`, so the session is always released through
/// `em_drop_session` and the caller can fall back to the CPU path.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn gpu_em_estimate(
    backend: &DeviceBackend,
    batch: &ComparisonBatch,
    init: Option<ModelParams>,
    max_iter: usize,
) -> std::result::Result<ModelParams, GpuError> {
    const LEVELS: usize = 4;

    if batch.n_pairs == 0 {
        return Err(GpuError::LaunchFailed(
            "EM requires at least one comparison pair".into(),
        ));
    }

    if !backend.is_gpu() {
        return crate::backend::cpu::cpu_estimate_params(batch, init, max_iter)
            .map_err(|e| GpuError::LaunchFailed(e.to_string()));
    }

    let n_fields = batch.n_fields;
    let n_pairs = batch.n_pairs;

    let mut params = init.unwrap_or_else(|| {
        let lambda = zer_compare::em::estimate_lambda(batch);
        let log_prior_odds = (lambda / (1.0 - lambda)).ln();
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    });

    // Reject undersized parameter tables up front: the weight builder and the
    // convergence check index `[field][level]` directly and would panic.
    let params_shape_ok = [&params.m, &params.u].iter().all(|table| {
        table.len() >= n_fields && table.iter().take(n_fields).all(|row| row.len() >= LEVELS)
    });
    if !params_shape_ok {
        return Err(GpuError::LaunchFailed(format!(
            "initial ModelParams must provide at least {n_fields} fields x {LEVELS} levels for m and u"
        )));
    }

    // ComparisonBatch.levels is already field-major u8, just widen to u32.
    let comparison_levels: Vec<u32> = batch.levels.iter().map(|&l| l as u32).collect();

    let mut session = zer_prof::trace!("zer_compute::em_init_session", {
        backend.em_init_session(&comparison_levels, n_pairs, n_fields)
    })?;

    let expected_counts = n_fields * LEVELS;

    // Closure ensures em_drop_session runs even if an iteration returns Err.
    let result: std::result::Result<ModelParams, GpuError> = (|| {
        for _ in 0..max_iter {
            let weights = build_estep_weights(&params, n_fields);

            let out = zer_prof::trace!("zer_compute::em_full_iteration", {
                backend.em_run_iteration(&mut session, &weights, params.log_prior_odds)
            })?;

            // Normalisation indexes the count buffers directly; fail cleanly
            // instead of panicking if the backend returned too little data.
            if out.m_counts.len() < expected_counts || out.u_counts.len() < expected_counts {
                return Err(GpuError::LaunchFailed(format!(
                    "EM iteration returned {} m-counts and {} u-counts, expected at least {expected_counts}",
                    out.m_counts.len(),
                    out.u_counts.len()
                )));
            }

            let new_params = em_normalize(
                &out.m_counts,
                &out.u_counts,
                out.total_match,
                out.total_nonmatch,
                n_fields,
            );

            if em_converged(&params, &new_params, n_fields) {
                return Ok(new_params);
            }
            params = new_params;
        }
        Ok(params)
    })();

    backend.em_drop_session(session);
    result
}
165
166#[cfg(not(any(feature = "cuda", feature = "vulkan", feature = "avx2")))]
167fn gpu_em_estimate(
168    _backend: &DeviceBackend,
169    _batch: &ComparisonBatch,
170    _init: Option<ModelParams>,
171    _max_iter: usize,
172) -> std::result::Result<ModelParams, GpuError> {
173    Err(GpuError::BackendUnavailable(
174        "full-GPU EM requires the cuda or vulkan feature".into(),
175    ))
176}
177
/// Turn accumulated match / non-match counts into a fresh set of parameters.
///
/// `m_counts` and `u_counts` are flat, field-major tables with 4 levels per
/// field. Every count receives an additive smoothing term of `1e-3`, and the
/// denominators are the corresponding totals plus `4 * 1e-3` (floored at
/// `1e-9`). The match prior is `total_match / (total_match + total_nonmatch)`
/// with the denominator floored at 1 and the ratio held inside
/// `[0.001, 0.999]`; it is stored as log-odds. Thresholds are fixed at
/// 0.9 / 0.1.
///
/// # Panics
///
/// Panics if either count slice has fewer than `n_fields * 4` entries.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn em_normalize(
    m_counts: &[f32],
    u_counts: &[f32],
    total_match: f32,
    total_nonmatch: f32,
    n_fields: usize,
) -> ModelParams {
    const ALPHA: f32 = 1e-3;
    const LEVELS: usize = 4;

    let smoothing_mass = LEVELS as f32 * ALPHA;

    // Shared recipe for both tables: smooth each count, divide by the
    // (smoothed) total, and regroup the flat buffer into one row per field.
    let to_probabilities = |counts: &[f32], denom: f32| -> Vec<Vec<f32>> {
        (0..n_fields)
            .map(|field| {
                let start = field * LEVELS;
                (start..start + LEVELS)
                    .map(|idx| (counts[idx] + ALPHA) / denom)
                    .collect()
            })
            .collect()
    };

    let m = to_probabilities(m_counts, (total_match + smoothing_mass).max(1e-9_f32));
    let u = to_probabilities(u_counts, (total_nonmatch + smoothing_mass).max(1e-9_f32));

    let pair_mass = (total_match + total_nonmatch).max(1.0_f32);
    // Kept as max-then-min: unlike `clamp`, this chain also maps a NaN ratio
    // to the lower bound.
    let match_rate = (total_match / pair_mass).max(0.001_f32).min(0.999_f32);

    ModelParams {
        m,
        u,
        log_prior_odds: (match_rate / (1.0 - match_rate)).ln(),
        upper_threshold: 0.9,
        lower_threshold: 0.1,
    }
}
219
/// Report whether an EM step left every `m` / `u` probability essentially
/// unchanged.
///
/// The first `n_fields` rows and 4 levels of both tables are compared; the
/// step counts as converged only if every absolute difference is below
/// `1e-6`. `log_prior_odds` and the thresholds are not part of the check.
///
/// A difference that is NaN (for example because either side holds NaN, or
/// both hold the same infinity) is treated as *not* converged. Folding the
/// deltas with `f32::max` would silently discard NaN and could declare
/// convergence on non-finite parameters.
///
/// # Panics
///
/// Panics if either parameter set has fewer than `n_fields` rows or a row
/// with fewer than 4 entries.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn em_converged(old: &ModelParams, new: &ModelParams, n_fields: usize) -> bool {
    const TOL: f32 = 1e-6;
    const LEVELS: usize = 4;

    let mut converged = true;
    for f in 0..n_fields {
        for l in 0..LEVELS {
            let dm = (old.m[f][l] - new.m[f][l]).abs();
            let du = (old.u[f][l] - new.u[f][l]).abs();
            // `<` is false for NaN, so a NaN delta clears the flag.
            converged &= dm < TOL && du < TOL;
        }
    }
    converged
}
234
235// ── Unit tests ────────────────────────────────────────────────────────────────
236
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
        scoring::{MatchBand, ModelParams},
    };

    /// Position of the `Exact` level inside a field's 4-entry m/u row.
    const EXACT: usize = 3;

    /// Scorer backed by the plain CPU backend, as used by most tests here.
    fn cpu_scorer() -> DeviceScorer {
        DeviceScorer::new(Arc::new(DeviceBackend::cpu()))
    }

    /// Parameters with identical, clearly separated m/u rows for every field
    /// and neutral prior odds.
    fn uniform_params(n_fields: usize) -> ModelParams {
        let m_row = vec![0.05, 0.10, 0.15, 0.70];
        let u_row = vec![0.70, 0.15, 0.10, 0.05];
        ModelParams {
            m: vec![m_row; n_fields],
            u: vec![u_row; n_fields],
            log_prior_odds: 0.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    /// Vector for records (1, 2) that agrees exactly on every field.
    fn all_exact_vector(n_fields: usize) -> ComparisonVector {
        let levels = vec![ComparisonLevel::Exact; n_fields];
        ComparisonVector::new(1, 2, levels)
    }

    /// Vector for records (3, 4) with no agreement on any field.
    fn all_none_vector(n_fields: usize) -> ComparisonVector {
        let levels = vec![ComparisonLevel::None; n_fields];
        ComparisonVector::new(3, 4, levels)
    }

    /// Batch of `n_matches` all-`Exact` vectors followed by `n_nonmatches`
    /// all-`None` vectors, each over a distinct pair of record ids.
    fn separable_batch(n_matches: usize, n_nonmatches: usize, n_fields: usize) -> ComparisonBatch {
        let id_offset = (n_matches as u64) * 2;

        let matches = (0..n_matches as u64).map(|i| {
            ComparisonVector::new(i * 2, i * 2 + 1, vec![ComparisonLevel::Exact; n_fields])
        });
        let nonmatches = (0..n_nonmatches as u64).map(|i| {
            ComparisonVector::new(
                id_offset + i * 2,
                id_offset + i * 2 + 1,
                vec![ComparisonLevel::None; n_fields],
            )
        });

        let vectors: Vec<ComparisonVector> = matches.chain(nonmatches).collect();
        ComparisonBatch::from_vectors(&vectors)
    }

    #[test]
    fn score_exact_match_gives_high_probability() {
        let params = uniform_params(3);
        let pair = cpu_scorer().score(&all_exact_vector(3), &params);

        assert!(
            pair.match_probability > 0.9,
            "all-Exact vector should have high match_probability, got {}",
            pair.match_probability
        );
        assert_eq!(pair.band, MatchBand::AutoMatch);
    }

    #[test]
    fn score_none_gives_low_probability() {
        let params = uniform_params(3);
        let pair = cpu_scorer().score(&all_none_vector(3), &params);

        assert!(
            pair.match_probability < 0.1,
            "all-None vector should have low match_probability, got {}",
            pair.match_probability
        );
        assert_eq!(pair.band, MatchBand::AutoReject);
    }

    #[test]
    fn score_batch_matches_individual_scores() {
        let scorer = cpu_scorer();
        let params = uniform_params(4);

        let mixed_levels = vec![
            ComparisonLevel::Exact,
            ComparisonLevel::None,
            ComparisonLevel::Close,
            ComparisonLevel::Partial,
        ];
        let vectors = vec![
            all_exact_vector(4),
            all_none_vector(4),
            ComparisonVector::new(5, 6, mixed_levels),
        ];

        let batch = ComparisonBatch::from_vectors(&vectors);
        let batch_results = scorer.score_batch(&batch, &params);

        for (vector, from_batch) in vectors.iter().zip(batch_results.iter()) {
            let single = scorer.score(vector, &params);
            let gap = (single.match_probability - from_batch.match_probability).abs();
            assert!(gap < 1e-6, "batch and individual scores must agree");
        }
    }

    #[test]
    fn estimate_params_converges_on_separable_data() {
        let n_fields = 4;
        let batch = separable_batch(200, 1_000, n_fields);

        let params = cpu_scorer()
            .estimate_params(&batch, None, 30)
            .expect("EM should not return an error");

        for f in 0..n_fields {
            assert!(
                params.m[f][EXACT] > params.u[f][EXACT],
                "m[Exact] should exceed u[Exact] for separable data (field {f})"
            );
        }
    }

    #[test]
    fn estimate_params_returns_error_on_empty_input() {
        let empty = ComparisonBatch::new(0, 0, vec![]);
        let result = cpu_scorer().estimate_params(&empty, None, 10);
        assert!(result.is_err(), "empty input should return an error");
    }

    #[test]
    fn weight_table_is_consistent_with_params() {
        use crate::soa::build_weight_table;

        let params = uniform_params(3);
        let table = build_weight_table(&params);

        // Field 0, level Exact: flat index 0 * 4 + 3.
        let weight_exact = table[EXACT];
        let expected = (0.70_f32 / 0.05_f32).ln();
        assert!(
            (weight_exact - expected).abs() < 1e-5,
            "weight_table Exact entry mismatch: {weight_exact} vs {expected}"
        );
    }

    #[test]
    fn em_cpu_path_correct_below_threshold() {
        let batch = separable_batch(200, 800, 4);
        assert!(batch.n_pairs < EM_GPU_MIN_PAIRS);

        let params = cpu_scorer().estimate_params(&batch, None, 30).unwrap();
        for f in 0..4 {
            assert!(
                params.m[f][EXACT] > params.u[f][EXACT],
                "field {f}: m[Exact] must exceed u[Exact]"
            );
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn em_gpu_path_correct_above_threshold() {
        let n_fields = 4;
        let batch = separable_batch(EM_GPU_MIN_PAIRS / 5, EM_GPU_MIN_PAIRS, n_fields);
        assert!(batch.n_pairs >= EM_GPU_MIN_PAIRS);

        let backend = DeviceBackend::auto_detect();
        let params =
            gpu_em_estimate(&backend, &batch, None, 50).expect("gpu_em_estimate must not fail");
        for f in 0..n_fields {
            assert!(
                params.m[f][EXACT] > params.u[f][EXACT],
                "field {f}: m[Exact] must exceed u[Exact]"
            );
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn em_gpu_cpu_agree_on_key_parameters() {
        let n_fields = 4;
        let batch = separable_batch(EM_GPU_MIN_PAIRS / 5, EM_GPU_MIN_PAIRS, n_fields);
        assert!(batch.n_pairs >= EM_GPU_MIN_PAIRS);

        let cpu_params = gpu_em_estimate(&DeviceBackend::cpu(), &batch, None, 50).unwrap();
        let gpu_params = gpu_em_estimate(&DeviceBackend::auto_detect(), &batch, None, 50).unwrap();

        for f in 0..n_fields {
            let (cpu_m, cpu_u) = (cpu_params.m[f][EXACT], cpu_params.u[f][EXACT]);
            let (gpu_m, gpu_u) = (gpu_params.m[f][EXACT], gpu_params.u[f][EXACT]);

            assert!(
                cpu_m > cpu_u,
                "CPU path field {f}: m[Exact] must exceed u[Exact]"
            );
            assert!(
                gpu_m > gpu_u,
                "GPU path field {f}: m[Exact] must exceed u[Exact]"
            );

            // Both paths must agree that Exact is a strong match signal.
            // Allow ≤ 0.15 absolute difference in m[Exact] and u[Exact].
            let dm_exact = (cpu_m - gpu_m).abs();
            let du_exact = (cpu_u - gpu_u).abs();
            assert!(
                dm_exact < 0.15,
                "field {f}: CPU/GPU m[Exact] differ by {dm_exact:.4} (cpu={:.4}, gpu={:.4})",
                cpu_m,
                gpu_m
            );
            assert!(
                du_exact < 0.15,
                "field {f}: CPU/GPU u[Exact] differ by {du_exact:.4} (cpu={:.4}, gpu={:.4})",
                cpu_u,
                gpu_u
            );
        }

        // Both paths must produce a negative log_prior_odds for a 1:5 match rate.
        assert!(
            cpu_params.log_prior_odds < 0.0,
            "CPU log_prior_odds should be negative for rare matches: {}",
            cpu_params.log_prior_odds
        );
        assert!(
            gpu_params.log_prior_odds < 0.0,
            "GPU log_prior_odds should be negative for rare matches: {}",
            gpu_params.log_prior_odds
        );
        let prior_gap = (cpu_params.log_prior_odds - gpu_params.log_prior_odds).abs();
        assert!(
            prior_gap < 1.0,
            "log_prior_odds differ too much: cpu={:.4}, gpu={:.4}",
            cpu_params.log_prior_odds,
            gpu_params.log_prior_odds
        );
    }

    #[test]
    fn em_cpu_log_prior_odds_tracks_match_rate() {
        // 1 match in 10 pairs → lambda ≈ 0.1 → log_prior_odds ≈ ln(0.1/0.9) ≈ -2.2
        let batch = separable_batch(100, 900, 2);
        let params = cpu_scorer().estimate_params(&batch, None, 50).unwrap();
        let lpo = params.log_prior_odds;

        assert!(
            lpo < 0.0,
            "log_prior_odds must be negative for 10% match rate: {}",
            lpo
        );
        assert!(
            lpo > -5.0,
            "log_prior_odds too negative for 10% match rate: {}",
            lpo
        );
    }

    #[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
    #[test]
    fn em_normalize_updates_log_prior_odds() {
        // Verify em_normalize computes log_prior_odds from total_match/total_nonmatch.
        // With total_match=100, total_nonmatch=900, lambda=0.1, log_prior_odds≈-2.2.
        let m_counts = [25.0_f32; 4]; // uniform over levels
        let u_counts = [225.0_f32; 4];
        let params = em_normalize(&m_counts, &u_counts, 100.0_f32, 900.0_f32, 1);

        let expected_lpo = (0.1_f32 / 0.9_f32).ln();
        assert!(
            (params.log_prior_odds - expected_lpo).abs() < 0.01,
            "log_prior_odds mismatch: got {:.4}, expected {:.4}",
            params.log_prior_odds,
            expected_lpo
        );
    }

    #[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
    #[test]
    fn em_converged_uses_raw_delta() {
        let n_fields = 2;

        let baseline = ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };
        // Copy of the baseline with m[0][Exact] nudged by `delta`.
        let nudged = |delta: f32| {
            let mut shifted = baseline.clone();
            shifted.m[0][EXACT] += delta;
            shifted
        };

        // Params that differ by < 1e-6, should be considered converged.
        assert!(
            em_converged(&baseline, &nudged(5e-7), n_fields),
            "should converge for delta < 1e-6"
        );

        // Params that differ by > 1e-6, should not converge.
        assert!(
            !em_converged(&baseline, &nudged(2e-6), n_fields),
            "should not converge for delta > 1e-6"
        );
    }
}
546}