zer_compute/kernels/em_reduce.rs
1use crate::kernel::Kernel;
2
/// Marker type selecting the GPU M-step reduction kernel.
///
/// One `EmReduce` dispatch corresponds to one M-step of the EM algorithm:
/// given `match_probs` from the E-step it accumulates weighted level counts
/// that the caller normalises into updated `m`/`u` probability tables.
///
/// The type is a field-less unit struct, so it is trivially `Copy`, all
/// values compare equal, and `Default` yields the only possible value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct EmReduce;
9
/// Input to one M-step reduction pass.
///
/// This is a borrowed view: both slices are owned by the caller and must
/// outlive `'a`. The struct only holds references and sizes, so it is cheap
/// to copy. Nothing here validates that the slice lengths agree with
/// `n_pairs` / `n_fields`; that consistency is the caller's responsibility.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EmReduceInput<'a> {
    /// P(match | vector) for each pair, E-step output, computed on CPU.
    ///
    /// One entry per pair, i.e. expected length `n_pairs`.
    pub match_probs: &'a [f32],
    /// Field-major comparison levels: `levels[field * n_pairs + pair]`, values 0–3.
    ///
    /// The indexing scheme implies an expected length of `n_fields * n_pairs`.
    pub comparison_levels: &'a [u32],
    /// Number of record pairs covered by this pass.
    pub n_pairs: usize,
    /// Number of compared fields per pair.
    pub n_fields: usize,
}
19
/// Raw M-step counts returned from one reduction pass.
///
/// Normalize to get updated Fellegi-Sunter probability tables:
/// ```text
/// m[f][l] = (m_counts[f*4 + l] + smoothing) / (total_match + 4*smoothing)
/// u[f][l] = (u_counts[f*4 + l] + smoothing) / (total_nonmatch + 4*smoothing)
/// ```
///
/// The `4*smoothing` term in the denominators matches the four comparison
/// levels per field, so each smoothed row sums to one.
///
/// `Default` produces an empty result (no counts, zero totals).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EmReduceOutput {
    /// Unnormalized m-counts: `m_counts[f*4 + l] = Σ_pairs P(match) × 1[level==l]`.
    /// Length is `n_fields * 4`.
    pub m_counts: Vec<f32>,
    /// Unnormalized u-counts: `u_counts[f*4 + l] = Σ_pairs P(nonmatch) × 1[level==l]`.
    /// Length is `n_fields * 4`.
    pub u_counts: Vec<f32>,
    /// Σ P(match) across all pairs.
    pub total_match: f32,
    /// Σ (1 - P(match)) across all pairs.
    pub total_nonmatch: f32,
}
39
40impl Kernel for EmReduce {
41 type Input<'a> = EmReduceInput<'a>;
42 type Output = EmReduceOutput;
43}