Skip to main content

zer_compute/kernels/
em_reduce.rs

1use crate::kernel::Kernel;
2
/// Marker for the GPU M-step reduction kernel.
///
/// One `EmReduce` dispatch corresponds to one M-step of the EM algorithm:
/// given `match_probs` from the E-step, it accumulates weighted level counts
/// that the caller normalizes into updated `m`/`u` probability tables.
///
/// The type carries no data, so the common value-type traits are derived to
/// let it be copied, compared, and printed in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct EmReduce;
9
/// Input to one M-step reduction pass.
///
/// The struct only holds borrowed slices and two sizes, so it is a cheap
/// `Copy` view over caller-owned buffers.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EmReduceInput<'a> {
    /// P(match | vector) for each pair, E-step output, computed on CPU.
    pub match_probs: &'a [f32],
    /// Field-major comparison levels:
    /// `comparison_levels[field * n_pairs + pair]`, values 0–3.
    pub comparison_levels: &'a [u32],
    /// Number of pairs; also the stride between consecutive fields in
    /// `comparison_levels`.
    // NOTE(review): presumably `match_probs.len() == n_pairs` — confirm against the caller.
    pub n_pairs: usize,
    /// Number of fields.
    // NOTE(review): presumably `comparison_levels.len() == n_fields * n_pairs` — confirm.
    pub n_fields: usize,
}
19
/// Raw M-step counts returned from one reduction pass.
///
/// Normalize to get updated Fellegi-Sunter probability tables:
/// ```text
/// m[f][l] = (m_counts[f*4 + l] + smoothing) / (total_match    + 4*smoothing)
/// u[f][l] = (u_counts[f*4 + l] + smoothing) / (total_nonmatch + 4*smoothing)
/// ```
///
/// `Default` yields empty count vectors and zero totals.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EmReduceOutput {
    /// Unnormalized m-counts: `m_counts[f*4 + l] = Σ_pairs P(match) × 1[level==l]`.
    /// Length is `n_fields * 4`.
    pub m_counts: Vec<f32>,
    /// Unnormalized u-counts: `u_counts[f*4 + l] = Σ_pairs P(nonmatch) × 1[level==l]`.
    /// Length is `n_fields * 4`.
    pub u_counts: Vec<f32>,
    /// Σ P(match) across all pairs.
    pub total_match: f32,
    /// Σ (1 - P(match)) across all pairs.
    pub total_nonmatch: f32,
}
39
40impl Kernel for EmReduce {
41    type Input<'a> = EmReduceInput<'a>;
42    type Output    = EmReduceOutput;
43}