1use std::sync::Arc;
4
5use zer_core::{
6 comparison::{ComparisonBatch, ComparisonVector},
7 scoring::{ModelParams, ScoredPair},
8 traits::{Result, Scorer},
9};
10
11use crate::{
12 backend::{cpu::CpuFallbackScorer, DeviceBackend},
13 error::GpuError,
14};
15
/// Minimum number of comparison pairs a batch must contain before
/// `estimate_params` attempts EM on an accelerated backend; smaller batches
/// are handed straight to the CPU implementation.
pub(crate) const EM_GPU_MIN_PAIRS: usize = 50_000;
18
19pub struct DeviceScorer {
20 backend: Arc<DeviceBackend>,
21 cpu_fallback: CpuFallbackScorer,
22}
23
24impl DeviceScorer {
25 pub fn new(backend: Arc<DeviceBackend>) -> Self {
26 Self { backend, cpu_fallback: CpuFallbackScorer }
27 }
28
29 pub fn backend_name(&self) -> &'static str {
30 self.backend.name()
31 }
32}
33
34impl Scorer for DeviceScorer {
35 fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
36 self.cpu_fallback.score(vector, params)
37 }
38
39 fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
40 self.cpu_fallback.score_batch(batch, params)
41 }
42
43 fn estimate_params(
44 &self,
45 batch: &ComparisonBatch,
46 init: Option<ModelParams>,
47 max_iter: usize,
48 ) -> Result<ModelParams> {
49 if self.backend.is_accelerated() && batch.n_pairs >= EM_GPU_MIN_PAIRS {
50 let result = zer_prof::trace!("zer_compute::estimate_params_accelerated", {
51 gpu_em_estimate(&self.backend, batch, init.clone(), max_iter)
52 });
53 match result {
54 Ok(params) => {
55 tracing::info!(backend = %self.backend.name(), "EM converged via accelerated backend");
56 return Ok(params);
57 }
58 Err(e) => {
59 tracing::warn!(%e, backend = %self.backend.name(), "accelerated EM failed, falling back to CPU");
60 }
61 }
62 } else if self.backend.is_accelerated() {
63 tracing::debug!(
64 n_pairs = batch.n_pairs,
65 threshold = EM_GPU_MIN_PAIRS,
66 "EM: batch below GPU threshold, using CPU path"
67 );
68 }
69 self.cpu_fallback.estimate_params(batch, init, max_iter)
70 }
71}
72
/// Flattens the model into a field-major table of log-likelihood ratios.
///
/// Entry `field * 4 + level` holds `ln(m / u)` for that field and level, with
/// both probabilities floored at `1e-15` before the division.
///
/// # Panics
///
/// Panics if `params.m` or `params.u` has fewer than `n_fields` rows or a row
/// with fewer than four entries.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn build_estep_weights(params: &ModelParams, n_fields: usize) -> Vec<f32> {
    const LEVELS: usize = 4;
    const FLOOR: f32 = 1e-15_f32;

    let cells = (0..n_fields).flat_map(|field| (0..LEVELS).map(move |level| (field, level)));

    let mut weights = Vec::with_capacity(n_fields * LEVELS);
    weights.extend(cells.map(|(field, level)| {
        let m = params.m[field][level].max(FLOOR);
        let u = params.u[field][level].max(FLOOR);
        (m / u).ln()
    }));
    weights
}
88
/// Runs EM parameter estimation through the backend's session API.
///
/// * A backend that is not a GPU is served by `cpu_estimate_params`, with its
///   error converted into [`GpuError::LaunchFailed`].
/// * When `init` is `None`, starting parameters are built from fixed `m`/`u`
///   rows and a prior derived from `zer_compare::em::estimate_lambda`.
/// * Each iteration uploads the current weight table, reads back the
///   accumulated counts, renormalises them and stops early once
///   `em_converged` reports convergence. After `max_iter` iterations the
///   latest parameters are returned as-is.
///
/// The device session is released on both the success and the error path of
/// the iteration loop.
///
/// NOTE(review): `em_normalize` always writes thresholds `0.9` / `0.1`, so
/// thresholds carried by `init` survive only when no iteration runs; confirm
/// this is intended.
///
/// # Errors
///
/// Returns [`GpuError::LaunchFailed`] for an empty batch or for an `init`
/// whose `m`/`u` tables do not cover `batch.n_fields` fields with four levels
/// each, and propagates any error from session creation or an iteration.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn gpu_em_estimate(
    backend: &DeviceBackend,
    batch: &ComparisonBatch,
    init: Option<ModelParams>,
    max_iter: usize,
) -> std::result::Result<ModelParams, GpuError> {
    // Levels per field; must agree with `build_estep_weights` and `em_normalize`.
    const LEVELS: usize = 4;

    // True when `rows` has at least `n_fields` rows of at least `LEVELS` entries.
    fn covers(rows: &[Vec<f32>], n_fields: usize) -> bool {
        rows.len() >= n_fields && rows[..n_fields].iter().all(|row| row.len() >= LEVELS)
    }

    if batch.n_pairs == 0 {
        return Err(GpuError::LaunchFailed("EM requires at least one comparison pair".into()));
    }

    if !backend.is_gpu() {
        return crate::backend::cpu::cpu_estimate_params(batch, init, max_iter)
            .map_err(|e| GpuError::LaunchFailed(e.to_string()));
    }

    let n_fields = batch.n_fields;
    let n_pairs = batch.n_pairs;

    let mut params = match init {
        Some(supplied) => {
            // The weight builder and convergence check index `m`/`u` directly, so
            // a short table would panic. Report it as an error instead, which lets
            // the caller take its CPU fallback.
            if !(covers(&supplied.m, n_fields) && covers(&supplied.u, n_fields)) {
                return Err(GpuError::LaunchFailed(format!(
                    "initial params must provide m and u for {n_fields} fields x {LEVELS} levels"
                )));
            }
            supplied
        }
        None => {
            let lambda = zer_compare::em::estimate_lambda(batch);
            let log_prior_odds = (lambda / (1.0 - lambda)).ln();
            ModelParams {
                m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
                u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
                log_prior_odds,
                upper_threshold: 0.9,
                lower_threshold: 0.1,
            }
        }
    };

    // The backend session takes levels as `u32`.
    let comparison_levels: Vec<u32> = batch.levels.iter().map(|&l| l as u32).collect();

    let mut session = zer_prof::trace!("zer_compute::em_init_session", {
        backend.em_init_session(&comparison_levels, n_pairs, n_fields)
    })?;

    // The loop lives in a closure so that `?` exits the closure rather than the
    // function, guaranteeing the session is dropped below on every path.
    let result: std::result::Result<ModelParams, GpuError> = (|| {
        for _iter in 0..max_iter {
            let weights = build_estep_weights(&params, n_fields);

            let out = zer_prof::trace!("zer_compute::em_full_iteration", {
                backend.em_run_iteration(&mut session, &weights, params.log_prior_odds)
            })?;

            let new_params = em_normalize(
                &out.m_counts,
                &out.u_counts,
                out.total_match,
                out.total_nonmatch,
                n_fields,
            );

            if em_converged(&params, &new_params, n_fields) {
                return Ok(new_params);
            }
            params = new_params;
        }
        Ok(params)
    })();

    backend.em_drop_session(session);
    result
}
158
159#[cfg(not(any(feature = "cuda", feature = "vulkan", feature = "avx2")))]
160fn gpu_em_estimate(
161 _backend: &DeviceBackend,
162 _batch: &ComparisonBatch,
163 _init: Option<ModelParams>,
164 _max_iter: usize,
165) -> std::result::Result<ModelParams, GpuError> {
166 Err(GpuError::BackendUnavailable(
167 "full-GPU EM requires the cuda or vulkan feature".into(),
168 ))
169}
170
/// Turns accumulated counts into a fresh set of model parameters.
///
/// `m_counts` and `u_counts` are read field-major with four levels per field.
/// Every count receives an additive `1e-3` before being divided by the
/// matching total plus four times that amount. The prior is derived from
/// `total_match / (total_match + total_nonmatch)`, limited to
/// `[0.001, 0.999]`. Thresholds are always set to `0.9` and `0.1`.
///
/// # Panics
///
/// Panics if either count slice holds fewer than `n_fields * 4` entries.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn em_normalize(
    m_counts: &[f32],
    u_counts: &[f32],
    total_match: f32,
    total_nonmatch: f32,
    n_fields: usize,
) -> ModelParams {
    const ALPHA: f32 = 1e-3;
    const LEVELS: usize = 4;

    // Builds one smoothed probability row per field from a flat count slice.
    fn smoothed_rows(counts: &[f32], denom: f32, n_fields: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n_fields);
        for field in 0..n_fields {
            let mut row = Vec::with_capacity(LEVELS);
            for level in 0..LEVELS {
                row.push((counts[field * LEVELS + level] + ALPHA) / denom);
            }
            rows.push(row);
        }
        rows
    }

    let denom_m = (total_match + LEVELS as f32 * ALPHA).max(1e-9_f32);
    let denom_u = (total_nonmatch + LEVELS as f32 * ALPHA).max(1e-9_f32);

    let m = smoothed_rows(m_counts, denom_m, n_fields);
    let u = smoothed_rows(u_counts, denom_u, n_fields);

    // Guard the division for a near-empty total, then keep lambda away from
    // 0 and 1 so the log-odds stay finite.
    let n_total = (total_match + total_nonmatch).max(1.0_f32);
    let lambda = (total_match / n_total).max(0.001_f32).min(0.999_f32);

    ModelParams {
        m,
        u,
        log_prior_odds: (lambda / (1.0 - lambda)).ln(),
        upper_threshold: 0.9,
        lower_threshold: 0.1,
    }
}
198
/// Reports whether two successive parameter sets are close enough to stop.
///
/// Only the `m` and `u` tables are compared: the largest absolute difference
/// over the first `n_fields` fields and four levels must be below `1e-6`.
/// `log_prior_odds` and the thresholds are not inspected.
///
/// # Panics
///
/// Panics if either parameter set has fewer than `n_fields` rows or a row
/// with fewer than four entries.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
fn em_converged(old: &ModelParams, new: &ModelParams, n_fields: usize) -> bool {
    const TOL: f32 = 1e-6;
    const LEVELS: usize = 4;

    let largest_change = (0..n_fields)
        .flat_map(|field| (0..LEVELS).map(move |level| (field, level)))
        .fold(0.0_f32, |worst, (field, level)| {
            let dm = (old.m[field][level] - new.m[field][level]).abs();
            let du = (old.u[field][level] - new.u[field][level]).abs();
            worst.max(dm).max(du)
        });

    largest_change < TOL
}
213
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
        scoring::{MatchBand, ModelParams},
    };

    /// Position of the `Exact` level inside each per-field `m`/`u` row, as
    /// referred to by the assertion messages below.
    const EXACT: usize = 3;

    /// Offset of (`field`, `level`) in a flat, field-major weight table with
    /// four levels per field.
    fn weight_index(field: usize, level: usize) -> usize {
        const LEVELS: usize = 4;
        field * LEVELS + level
    }

    /// Parameters with identical rows for every field and a neutral prior.
    fn uniform_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.05, 0.10, 0.15, 0.70]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: 0.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    /// A vector whose every field compares as `Exact`.
    fn all_exact_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact; n_fields])
    }

    /// A vector whose every field compares as `None`.
    fn all_none_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(3, 4, vec![ComparisonLevel::None; n_fields])
    }

    /// A batch of `n_matches` all-`Exact` pairs followed by `n_nonmatches`
    /// all-`None` pairs, each with distinct record ids.
    fn separable_batch(n_matches: usize, n_nonmatches: usize, n_fields: usize) -> ComparisonBatch {
        let mut v = Vec::with_capacity(n_matches + n_nonmatches);
        for i in 0..n_matches as u64 {
            v.push(ComparisonVector::new(i * 2, i * 2 + 1, vec![ComparisonLevel::Exact; n_fields]));
        }
        let off = (n_matches as u64) * 2;
        for i in 0..n_nonmatches as u64 {
            v.push(ComparisonVector::new(
                off + i * 2,
                off + i * 2 + 1,
                vec![ComparisonLevel::None; n_fields],
            ));
        }
        ComparisonBatch::from_vectors(&v)
    }

    #[test]
    fn score_exact_match_gives_high_probability() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = uniform_params(3);
        let v = all_exact_vector(3);
        let pair = scorer.score(&v, &params);

        assert!(pair.match_probability > 0.9,
            "all-Exact vector should have high match_probability, got {}", pair.match_probability);
        assert_eq!(pair.band, MatchBand::AutoMatch);
    }

    #[test]
    fn score_none_gives_low_probability() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = uniform_params(3);
        let v = all_none_vector(3);
        let pair = scorer.score(&v, &params);

        assert!(pair.match_probability < 0.1,
            "all-None vector should have low match_probability, got {}", pair.match_probability);
        assert_eq!(pair.band, MatchBand::AutoReject);
    }

    #[test]
    fn score_batch_matches_individual_scores() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = uniform_params(4);
        let vectors = vec![
            all_exact_vector(4),
            all_none_vector(4),
            ComparisonVector::new(5, 6, vec![
                ComparisonLevel::Exact,
                ComparisonLevel::None,
                ComparisonLevel::Close,
                ComparisonLevel::Partial,
            ]),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        let batch_results = scorer.score_batch(&batch, &params);

        for (v, br) in vectors.iter().zip(batch_results.iter()) {
            let single = scorer.score(v, &params);
            assert!(
                (single.match_probability - br.match_probability).abs() < 1e-6,
                "batch and individual scores must agree"
            );
        }
    }

    #[test]
    fn estimate_params_converges_on_separable_data() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let n_fields = 4;
        let batch = separable_batch(200, 1_000, n_fields);

        let params = scorer.estimate_params(&batch, None, 30)
            .expect("EM should not return an error");

        for f in 0..n_fields {
            assert!(params.m[f][EXACT] > params.u[f][EXACT],
                "m[Exact] should exceed u[Exact] for separable data (field {f})");
        }
    }

    #[test]
    fn estimate_params_returns_error_on_empty_input() {
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let batch = ComparisonBatch::new(0, 0, vec![]);
        let result = scorer.estimate_params(&batch, None, 10);
        assert!(result.is_err(), "empty input should return an error");
    }

    #[test]
    fn weight_table_is_consistent_with_params() {
        use crate::soa::build_weight_table;

        let params = uniform_params(3);
        let table = build_weight_table(&params);

        // Exact entry of the first field; computed through a helper so the
        // offset is not spelled as a literal `0 * 4 + 3`.
        let weight_exact = table[weight_index(0, EXACT)];
        let expected = (0.70_f32 / 0.05_f32).ln();
        assert!(
            (weight_exact - expected).abs() < 1e-5,
            "weight_table Exact entry mismatch: {weight_exact} vs {expected}"
        );
    }

    #[test]
    fn em_cpu_path_correct_below_threshold() {
        let batch = separable_batch(200, 800, 4);
        assert!(batch.n_pairs < EM_GPU_MIN_PAIRS);

        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = scorer.estimate_params(&batch, None, 30).unwrap();
        for f in 0..4 {
            assert!(params.m[f][EXACT] > params.u[f][EXACT],
                "field {f}: m[Exact] must exceed u[Exact]");
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn em_gpu_path_correct_above_threshold() {
        let n_fields = 4;
        let n_matches = EM_GPU_MIN_PAIRS / 5;
        let n_nonmatches = EM_GPU_MIN_PAIRS;
        let batch = separable_batch(n_matches, n_nonmatches, n_fields);
        assert!(batch.n_pairs >= EM_GPU_MIN_PAIRS);

        let params = gpu_em_estimate(&DeviceBackend::auto_detect(), &batch, None, 50)
            .expect("gpu_em_estimate must not fail");
        for f in 0..n_fields {
            assert!(params.m[f][EXACT] > params.u[f][EXACT],
                "field {f}: m[Exact] must exceed u[Exact]");
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn em_gpu_cpu_agree_on_key_parameters() {
        let n_fields = 4;
        let n_matches = EM_GPU_MIN_PAIRS / 5;
        let n_nonmatches = EM_GPU_MIN_PAIRS;
        let batch = separable_batch(n_matches, n_nonmatches, n_fields);
        assert!(batch.n_pairs >= EM_GPU_MIN_PAIRS);

        let cpu_params = gpu_em_estimate(&DeviceBackend::cpu(), &batch, None, 50).unwrap();
        let gpu_params = gpu_em_estimate(&DeviceBackend::auto_detect(), &batch, None, 50).unwrap();

        for f in 0..n_fields {
            assert!(cpu_params.m[f][EXACT] > cpu_params.u[f][EXACT],
                "CPU path field {f}: m[Exact] must exceed u[Exact]");
            assert!(gpu_params.m[f][EXACT] > gpu_params.u[f][EXACT],
                "GPU path field {f}: m[Exact] must exceed u[Exact]");

            let dm_exact = (cpu_params.m[f][EXACT] - gpu_params.m[f][EXACT]).abs();
            let du_exact = (cpu_params.u[f][EXACT] - gpu_params.u[f][EXACT]).abs();
            assert!(dm_exact < 0.15,
                "field {f}: CPU/GPU m[Exact] differ by {dm_exact:.4} (cpu={:.4}, gpu={:.4})",
                cpu_params.m[f][EXACT], gpu_params.m[f][EXACT]);
            assert!(du_exact < 0.15,
                "field {f}: CPU/GPU u[Exact] differ by {du_exact:.4} (cpu={:.4}, gpu={:.4})",
                cpu_params.u[f][EXACT], gpu_params.u[f][EXACT]);
        }

        assert!(cpu_params.log_prior_odds < 0.0,
            "CPU log_prior_odds should be negative for rare matches: {}", cpu_params.log_prior_odds);
        assert!(gpu_params.log_prior_odds < 0.0,
            "GPU log_prior_odds should be negative for rare matches: {}", gpu_params.log_prior_odds);
        let dlpo = (cpu_params.log_prior_odds - gpu_params.log_prior_odds).abs();
        assert!(dlpo < 1.0,
            "log_prior_odds differ too much: cpu={:.4}, gpu={:.4}",
            cpu_params.log_prior_odds, gpu_params.log_prior_odds);
    }

    #[test]
    fn em_cpu_log_prior_odds_tracks_match_rate() {
        let n_fields = 2;
        // 100 matches out of 1000 pairs.
        let batch = separable_batch(100, 900, n_fields);
        let scorer = DeviceScorer::new(Arc::new(DeviceBackend::cpu()));
        let params = scorer.estimate_params(&batch, None, 50).unwrap();

        assert!(params.log_prior_odds < 0.0,
            "log_prior_odds must be negative for 10% match rate: {}", params.log_prior_odds);
        assert!(params.log_prior_odds > -5.0,
            "log_prior_odds too negative for 10% match rate: {}", params.log_prior_odds);
    }

    #[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
    #[test]
    fn em_normalize_updates_log_prior_odds() {
        // One field, four levels; the totals give a 10% match share.
        let m_counts = vec![25.0_f32, 25.0, 25.0, 25.0];
        let u_counts = vec![225.0_f32, 225.0, 225.0, 225.0];
        let total_match = 100.0_f32;
        let total_nonmatch = 900.0_f32;
        let params = em_normalize(&m_counts, &u_counts, total_match, total_nonmatch, 1);

        let expected_lpo = (0.1_f32 / 0.9_f32).ln();
        assert!(
            (params.log_prior_odds - expected_lpo).abs() < 0.01,
            "log_prior_odds mismatch: got {:.4}, expected {:.4}",
            params.log_prior_odds, expected_lpo
        );
    }

    #[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
    #[test]
    fn em_converged_uses_raw_delta() {
        let n_fields = 2;

        let p1 = ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };

        // A change just under the tolerance counts as converged.
        let mut p2 = p1.clone();
        p2.m[0][EXACT] += 5e-7;
        assert!(em_converged(&p1, &p2, n_fields), "should converge for delta < 1e-6");

        // A change just over the tolerance does not.
        let mut p3 = p1.clone();
        p3.m[0][EXACT] += 2e-6;
        assert!(!em_converged(&p1, &p3, n_fields), "should not converge for delta > 1e-6");
    }
}