// zer_compute/scorer.rs

//! `DeviceScorer`: implements the `Scorer` trait, estimating EM parameters on
//! an accelerated backend for large batches and falling back to the CPU.

3use std::sync::Arc;
4
5use zer_core::{
6    comparison::{ComparisonBatch, ComparisonVector},
7    scoring::{ModelParams, ScoredPair},
8    traits::{Result, Scorer},
9};
10
11use crate::{
12    backend::{cpu::CpuFallbackScorer, DeviceBackend},
13    error::GpuError,
14};
15
/// Smallest batch, counted in comparison pairs, for which `DeviceScorer`
/// attempts the accelerated EM path; smaller batches use the CPU fallback.
pub(crate) const EM_GPU_MIN_PAIRS: usize = 50_000;

19pub struct DeviceScorer {
20    backend:      Arc<DeviceBackend>,
21    cpu_fallback: CpuFallbackScorer,
22}
23
24impl DeviceScorer {
25    pub fn new(backend: Arc<DeviceBackend>) -> Self {
26        Self { backend, cpu_fallback: CpuFallbackScorer }
27    }
28
29    pub fn backend_name(&self) -> &'static str {
30        self.backend.name()
31    }
32}
33
34impl Scorer for DeviceScorer {
35    fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
36        self.cpu_fallback.score(vector, params)
37    }
38
39    fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
40        self.cpu_fallback.score_batch(batch, params)
41    }
42
43    fn estimate_params(
44        &self,
45        batch:    &ComparisonBatch,
46        init:     Option<ModelParams>,
47        max_iter: usize,
48    ) -> Result<ModelParams> {
49        if self.backend.is_accelerated() && batch.n_pairs >= EM_GPU_MIN_PAIRS {
50            let result = zer_prof::trace!("zer_compute::estimate_params_accelerated", {
51                gpu_em_estimate(&self.backend, batch, init.clone(), max_iter)
52            });
53            match result {
54                Ok(params) => {
55                    tracing::info!(backend = %self.backend.name(), "EM converged via accelerated backend");
56                    return Ok(params);
57                }
58                Err(e) => {
59                    tracing::warn!(%e, backend = %self.backend.name(), "accelerated EM failed, falling back to CPU");
60                }
61            }
62        } else if self.backend.is_accelerated() {
63            tracing::debug!(
64                n_pairs = batch.n_pairs,
65                threshold = EM_GPU_MIN_PAIRS,
66                "EM: batch below GPU threshold, using CPU path"
67            );
68        }
69        self.cpu_fallback.estimate_params(batch, init, max_iter)
70    }
71}
72
// ── GPU EM loop ───────────────────────────────────────────────────────────────

/// Builds the flattened E-step weight table passed to the backend.
///
/// Entry `[field * 4 + level]` holds `ln(m / u)` for that field and level.
/// Both probabilities are floored at `1e-15` so the log-ratio stays finite.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn build_estep_weights(params: &ModelParams, n_fields: usize) -> Vec<f32> {
    const LEVELS: usize = 4;
    const FLOOR: f32 = 1e-15;

    let mut weights = Vec::with_capacity(n_fields * LEVELS);
    weights.extend(
        (0..n_fields)
            .flat_map(|field| (0..LEVELS).map(move |level| (field, level)))
            .map(|(field, level)| {
                let m_prob = params.m[field][level].max(FLOOR);
                let u_prob = params.u[field][level].max(FLOOR);
                (m_prob / u_prob).ln()
            }),
    );
    weights
}

/// Run the full EM algorithm on the accelerated backend.
///
/// A backend that is not a GPU is delegated to
/// `crate::backend::cpu::cpu_estimate_params`. On a GPU backend, the
/// comparison levels are uploaded once into an EM session before the loop.
/// `ComparisonBatch` is already field-major, so the upload is a plain
/// `u8 → u32` widening with no transposition.
///
/// When `init` is `None`, every field starts from fixed m/u priors. The log
/// prior odds come from `zer_compare::em::estimate_lambda`, clamped to the
/// same `[0.001, 0.999]` range that `em_normalize` uses so they stay finite.
/// EM re-estimates only `m`, `u` and `log_prior_odds`. The starting
/// parameters' `upper_threshold` / `lower_threshold` are carried through
/// unchanged.
///
/// # Errors
///
/// Returns `GpuError::LaunchFailed` in three cases:
/// - the batch is empty;
/// - the starting `m` or `u` has fewer than `n_fields` rows of four levels;
/// - an iteration yields non-finite parameters.
///
/// Errors from session setup or from an iteration are propagated. Once
/// created, the session is dropped on every non-panicking return path.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn gpu_em_estimate(
    backend:  &DeviceBackend,
    batch:    &ComparisonBatch,
    init:     Option<ModelParams>,
    max_iter: usize,
) -> std::result::Result<ModelParams, GpuError> {
    const LEVELS: usize = 4;

    /// True when `table` has at least `n_fields` rows of at least `levels` entries.
    fn well_shaped(table: &[Vec<f32>], n_fields: usize, levels: usize) -> bool {
        table.len() >= n_fields && table.iter().take(n_fields).all(|row| row.len() >= levels)
    }

    if batch.n_pairs == 0 {
        return Err(GpuError::LaunchFailed("EM requires at least one comparison pair".into()));
    }

    if !backend.is_gpu() {
        return crate::backend::cpu::cpu_estimate_params(batch, init, max_iter)
            .map_err(|e| GpuError::LaunchFailed(e.to_string()));
    }

    let n_fields = batch.n_fields;
    let n_pairs  = batch.n_pairs;

    let mut params = init.unwrap_or_else(|| {
        // Clamp as `em_normalize` does: a lambda of exactly 0 or 1 would make
        // the log prior odds infinite on the very first E-step.
        let lambda = zer_compare::em::estimate_lambda(batch).max(0.001).min(0.999);
        let log_prior_odds = (lambda / (1.0 - lambda)).ln();
        ModelParams {
            m:               vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u:               vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    });

    // `build_estep_weights` and `em_converged` index `m`/`u` as
    // `[field][level]`. Reject a mis-shaped `init` here rather than panicking
    // later while a device session is open.
    if !well_shaped(&params.m, n_fields, LEVELS) || !well_shaped(&params.u, n_fields, LEVELS) {
        return Err(GpuError::LaunchFailed(format!(
            "EM initial parameters must have {n_fields} fields with {LEVELS} levels each"
        )));
    }

    // ComparisonBatch.levels is already field-major u8, just widen to u32.
    let comparison_levels: Vec<u32> = batch.levels.iter().map(|&l| l as u32).collect();

    let mut session = zer_prof::trace!("zer_compute::em_init_session", {
        backend.em_init_session(&comparison_levels, n_pairs, n_fields)
    })?;

    // Closure ensures em_drop_session runs even if an iteration returns Err.
    let result: std::result::Result<ModelParams, GpuError> = (|| {
        for _iter in 0..max_iter {
            let weights = build_estep_weights(&params, n_fields);

            let out = zer_prof::trace!("zer_compute::em_full_iteration", {
                backend.em_run_iteration(&mut session, &weights, params.log_prior_odds)
            })?;

            let mut new_params = em_normalize(
                &out.m_counts, &out.u_counts,
                out.total_match, out.total_nonmatch,
                n_fields,
            );
            // `em_normalize` always emits default thresholds; EM does not
            // estimate them, so keep the ones the caller started with.
            new_params.upper_threshold = params.upper_threshold;
            new_params.lower_threshold = params.lower_threshold;

            // A NaN/inf coming back from the device must not be reported as a
            // result. Surfacing it as an error lets the caller retry on the CPU.
            let finite = new_params.log_prior_odds.is_finite()
                && new_params.m.iter().chain(&new_params.u).flatten().all(|p| p.is_finite());
            if !finite {
                return Err(GpuError::LaunchFailed(
                    "accelerated EM produced non-finite parameters".into(),
                ));
            }

            if em_converged(&params, &new_params, n_fields) {
                return Ok(new_params);
            }
            params = new_params;
        }
        Ok(params)
    })();

    backend.em_drop_session(session);
    result
}

159#[cfg(not(any(feature = "cuda", feature = "vulkan", feature = "avx2")))]
160fn gpu_em_estimate(
161    _backend:  &DeviceBackend,
162    _batch:    &ComparisonBatch,
163    _init:     Option<ModelParams>,
164    _max_iter: usize,
165) -> std::result::Result<ModelParams, GpuError> {
166    Err(GpuError::BackendUnavailable(
167        "full-GPU EM requires the cuda or vulkan feature".into(),
168    ))
169}
170
/// M-step: turns the per-level counts from one E-step into new parameters.
///
/// `m_counts` / `u_counts` are flattened as `[field * 4 + level]` for the
/// match / non-match class. Each entry is smoothed as
/// `(count + ALPHA) / (total + 4 * ALPHA)`, so no probability is exactly zero.
/// The match rate `lambda = total_match / (total_match + total_nonmatch)` is
/// clamped to `[0.001, 0.999]` before it becomes `log_prior_odds`. The
/// thresholds are always set to 0.9 / 0.1.
///
/// # Panics
///
/// Panics if either counts slice holds fewer than `n_fields * 4` entries.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn em_normalize(
    m_counts:       &[f32],
    u_counts:       &[f32],
    total_match:    f32,
    total_nonmatch: f32,
    n_fields:       usize,
) -> ModelParams {
    const ALPHA: f32 = 1e-3;
    const LEVELS: usize = 4;

    // Additively smoothed per-field distributions for one class.
    let smooth = |counts: &[f32], total: f32| -> Vec<Vec<f32>> {
        let denom = (total + LEVELS as f32 * ALPHA).max(1e-9_f32);
        (0..n_fields)
            .map(|field| {
                counts[field * LEVELS..(field + 1) * LEVELS]
                    .iter()
                    .map(|&count| (count + ALPHA) / denom)
                    .collect()
            })
            .collect()
    };

    let m = smooth(m_counts, total_match);
    let u = smooth(u_counts, total_nonmatch);

    // `max`/`min` (not `clamp`) so a NaN ratio still lands on 0.001.
    let lambda = {
        let n_total = (total_match + total_nonmatch).max(1.0_f32);
        (total_match / n_total).max(0.001_f32).min(0.999_f32)
    };

    ModelParams {
        m,
        u,
        log_prior_odds:  (lambda / (1.0 - lambda)).ln(),
        upper_threshold: 0.9,
        lower_threshold: 0.1,
    }
}

/// Returns `true` when every `m` and `u` entry for the first `n_fields`
/// fields moved by less than `1e-6` between `old` and `new`.
///
/// A NaN delta (e.g. a NaN or infinite parameter) counts as *not* converged.
/// The previous max-fold used `f32::max`, which discards NaN operands, so
/// corrupted parameters could be reported as converged.
///
/// # Panics
///
/// Panics if either parameter set has fewer than `n_fields` rows of four
/// levels in `m` or `u` and that entry is reached before a large delta.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn em_converged(old: &ModelParams, new: &ModelParams, n_fields: usize) -> bool {
    const TOL: f32 = 1e-6;
    const LEVELS: usize = 4;
    (0..n_fields).all(|f| {
        (0..LEVELS).all(|l| {
            let dm = (old.m[f][l] - new.m[f][l]).abs();
            let du = (old.u[f][l] - new.u[f][l]).abs();
            // `<` is false for NaN, so non-finite parameters never converge.
            dm < TOL && du < TOL
        })
    })
}

// ── Unit tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
        scoring::{MatchBand, ModelParams},
    };

    /// Same m/u table for every field ("uniform" across fields, not levels).
    /// m puts most mass on level 3 and u puts most mass on level 0.
    fn uniform_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m:               vec![vec![0.05, 0.10, 0.15, 0.70]; n_fields],
            u:               vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds:  0.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    /// Pair (1, 2) with every field at `ComparisonLevel::Exact`.
    fn all_exact_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact; n_fields])
    }

    /// Pair (3, 4) with every field at `ComparisonLevel::None`.
    fn all_none_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(3, 4, vec![ComparisonLevel::None; n_fields])
    }

    /// `n_matches` all-Exact pairs followed by `n_nonmatches` all-None pairs,
    /// each pair using fresh record ids.
    fn separable_batch(n_matches: usize, n_nonmatches: usize, n_fields: usize) -> ComparisonBatch {
        let mut v = Vec::with_capacity(n_matches + n_nonmatches);
        for i in 0..n_matches as u64 {
            v.push(ComparisonVector::new(i * 2, i * 2 + 1, vec![ComparisonLevel::Exact; n_fields]));
        }
        let off = (n_matches as u64) * 2;
        for i in 0..n_nonmatches as u64 {
            v.push(ComparisonVector::new(off + i * 2, off + i * 2 + 1, vec![ComparisonLevel::None; n_fields]));
        }
        ComparisonBatch::from_vectors(&v)
    }

    #[test]
    fn score_exact_match_gives_high_probability() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = uniform_params(3);
        let v      = all_exact_vector(3);
        let pair   = scorer.score(&v, &params);

        assert!(pair.match_probability > 0.9,
            "all-Exact vector should have high match_probability, got {}", pair.match_probability);
        assert_eq!(pair.band, MatchBand::AutoMatch);
    }

    #[test]
    fn score_none_gives_low_probability() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = uniform_params(3);
        let v      = all_none_vector(3);
        let pair   = scorer.score(&v, &params);

        assert!(pair.match_probability < 0.1,
            "all-None vector should have low match_probability, got {}", pair.match_probability);
        assert_eq!(pair.band, MatchBand::AutoReject);
    }

    #[test]
    fn score_batch_matches_individual_scores() {
        let scorer  = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params  = uniform_params(4);
        let vectors = vec![
            all_exact_vector(4),
            all_none_vector(4),
            ComparisonVector::new(5, 6, vec![
                ComparisonLevel::Exact,
                ComparisonLevel::None,
                ComparisonLevel::Close,
                ComparisonLevel::Partial,
            ]),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        let batch_results = scorer.score_batch(&batch, &params);

        for (v, br) in vectors.iter().zip(batch_results.iter()) {
            let single = scorer.score(v, &params);
            assert!(
                (single.match_probability - br.match_probability).abs() < 1e-6,
                "batch and individual scores must agree"
            );
        }
    }

    #[test]
    fn estimate_params_converges_on_separable_data() {
        let scorer   = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let n_fields = 4;
        let batch    = separable_batch(200, 1_000, n_fields);

        let params = scorer.estimate_params(&batch, None, 30)
            .expect("EM should not return an error");

        for f in 0..n_fields {
            assert!(params.m[f][3] > params.u[f][3],
                "m[Exact] should exceed u[Exact] for separable data (field {f})");
        }
    }

    #[test]
    fn estimate_params_returns_error_on_empty_input() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let batch  = ComparisonBatch::new(0, 0, vec![]);
        let result = scorer.estimate_params(&batch, None, 10);
        assert!(result.is_err(), "empty input should return an error");
    }

    #[test]
    fn weight_table_is_consistent_with_params() {
        use crate::soa::build_weight_table;

        let params = uniform_params(3);
        let table  = build_weight_table(&params);

        // Field 0 occupies the first four slots of the field-major table, and
        // level 3 is Exact. Written as a plain index because `0 * 4 + 3`
        // trips clippy's deny-by-default `erasing_op` lint.
        let weight_exact = table[3];
        let expected     = (0.70_f32 / 0.05_f32).ln();
        assert!(
            (weight_exact - expected).abs() < 1e-5,
            "weight_table Exact entry mismatch: {weight_exact} vs {expected}"
        );
    }

    #[test]
    fn em_cpu_path_correct_below_threshold() {
        let batch = separable_batch(200, 800, 4);
        assert!(batch.n_pairs < EM_GPU_MIN_PAIRS);

        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = scorer.estimate_params(&batch, None, 30).unwrap();
        for f in 0..4 {
            assert!(params.m[f][3] > params.u[f][3], "field {f}: m[Exact] must exceed u[Exact]");
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn em_gpu_path_correct_above_threshold() {
        let n_fields     = 4;
        let n_matches    = EM_GPU_MIN_PAIRS / 5;
        let n_nonmatches = EM_GPU_MIN_PAIRS;
        let batch        = separable_batch(n_matches, n_nonmatches, n_fields);
        assert!(batch.n_pairs >= EM_GPU_MIN_PAIRS);

        let params = gpu_em_estimate(&DeviceBackend::auto_detect(), &batch, None, 50)
            .expect("gpu_em_estimate must not fail");
        for f in 0..n_fields {
            assert!(params.m[f][3] > params.u[f][3], "field {f}: m[Exact] must exceed u[Exact]");
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn em_gpu_cpu_agree_on_key_parameters() {
        let n_fields     = 4;
        let n_matches    = EM_GPU_MIN_PAIRS / 5;
        let n_nonmatches = EM_GPU_MIN_PAIRS;
        let batch        = separable_batch(n_matches, n_nonmatches, n_fields);
        assert!(batch.n_pairs >= EM_GPU_MIN_PAIRS);

        let cpu_params = gpu_em_estimate(&DeviceBackend::cpu(), &batch, None, 50).unwrap();
        let gpu_params = gpu_em_estimate(&DeviceBackend::auto_detect(), &batch, None, 50).unwrap();

        for f in 0..n_fields {
            assert!(cpu_params.m[f][3] > cpu_params.u[f][3],
                "CPU path field {f}: m[Exact] must exceed u[Exact]");
            assert!(gpu_params.m[f][3] > gpu_params.u[f][3],
                "GPU path field {f}: m[Exact] must exceed u[Exact]");

            // Both paths must agree that Exact is a strong match signal.
            // Allow ≤ 0.15 absolute difference in m[Exact] and u[Exact].
            let dm_exact = (cpu_params.m[f][3] - gpu_params.m[f][3]).abs();
            let du_exact = (cpu_params.u[f][3] - gpu_params.u[f][3]).abs();
            assert!(dm_exact < 0.15,
                "field {f}: CPU/GPU m[Exact] differ by {dm_exact:.4} (cpu={:.4}, gpu={:.4})",
                cpu_params.m[f][3], gpu_params.m[f][3]);
            assert!(du_exact < 0.15,
                "field {f}: CPU/GPU u[Exact] differ by {du_exact:.4} (cpu={:.4}, gpu={:.4})",
                cpu_params.u[f][3], gpu_params.u[f][3]);
        }

        // Both paths must produce a negative log_prior_odds for a 1:5 match rate.
        assert!(cpu_params.log_prior_odds < 0.0,
            "CPU log_prior_odds should be negative for rare matches: {}", cpu_params.log_prior_odds);
        assert!(gpu_params.log_prior_odds < 0.0,
            "GPU log_prior_odds should be negative for rare matches: {}", gpu_params.log_prior_odds);
        let dlpo = (cpu_params.log_prior_odds - gpu_params.log_prior_odds).abs();
        assert!(dlpo < 1.0,
            "log_prior_odds differ too much: cpu={:.4}, gpu={:.4}",
            cpu_params.log_prior_odds, gpu_params.log_prior_odds);
    }

    #[test]
    fn em_cpu_log_prior_odds_tracks_match_rate() {
        // 1 match in 10 pairs → lambda ≈ 0.1 → log_prior_odds ≈ ln(0.1/0.9) ≈ -2.2
        let n_fields = 2;
        let batch    = separable_batch(100, 900, n_fields);
        let scorer   = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params   = scorer.estimate_params(&batch, None, 50).unwrap();

        assert!(params.log_prior_odds < 0.0,
            "log_prior_odds must be negative for 10% match rate: {}", params.log_prior_odds);
        assert!(params.log_prior_odds > -5.0,
            "log_prior_odds too negative for 10% match rate: {}", params.log_prior_odds);
    }

    #[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
    #[test]
    fn em_normalize_updates_log_prior_odds() {
        // Verify em_normalize computes log_prior_odds from total_match/total_nonmatch.
        // With total_match=100, total_nonmatch=900, lambda=0.1, log_prior_odds≈-2.2.
        let m_counts = vec![25.0_f32, 25.0, 25.0, 25.0];  // uniform over levels
        let u_counts = vec![225.0_f32, 225.0, 225.0, 225.0];
        let total_match    = 100.0_f32;
        let total_nonmatch = 900.0_f32;
        let params = em_normalize(&m_counts, &u_counts, total_match, total_nonmatch, 1);

        let expected_lpo = (0.1_f32 / 0.9_f32).ln();
        assert!(
            (params.log_prior_odds - expected_lpo).abs() < 0.01,
            "log_prior_odds mismatch: got {:.4}, expected {:.4}",
            params.log_prior_odds, expected_lpo
        );
    }

    #[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
    #[test]
    fn em_converged_uses_raw_delta() {
        let n_fields = 2;

        // Params that differ by < 1e-6, should be considered converged.
        let p1 = ModelParams {
            m:               vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u:               vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds:  -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };
        let mut p2 = p1.clone();
        p2.m[0][3] += 5e-7;  // tiny delta
        assert!(em_converged(&p1, &p2, n_fields), "should converge for delta < 1e-6");

        // Params that differ by > 1e-6, should not converge.
        let mut p3 = p1.clone();
        p3.m[0][3] += 2e-6;
        assert!(!em_converged(&p1, &p3, n_fields), "should not converge for delta > 1e-6");
    }
}