Skip to main content

zer_compute/backend/cpu/launch/
em_reduce.rs

1//! CPU dispatch for [`EmReduce`], pure-CPU M-step reduction.
2
3use crate::{
4    backend::cpu::device::CpuDevice,
5    error::GpuError,
6    kernel::KernelDispatch,
7    kernels::em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
8};
9
10const NUM_LEVELS: usize = 4;
11
12impl KernelDispatch<EmReduce> for CpuDevice {
13    fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
14        let EmReduceInput { match_probs, comparison_levels, n_pairs, n_fields } = input;
15
16        if n_pairs == 0 || n_fields == 0 {
17            let zeros = vec![0.0f32; n_fields * NUM_LEVELS];
18            return Ok(EmReduceOutput {
19                m_counts: zeros.clone(), u_counts: zeros,
20                total_match: 0.0, total_nonmatch: 0.0,
21            });
22        }
23
24        let mut m_counts       = vec![0.0f32; n_fields * NUM_LEVELS];
25        let mut u_counts       = vec![0.0f32; n_fields * NUM_LEVELS];
26        let mut total_match    = 0.0f32;
27        let mut total_nonmatch = 0.0f32;
28
29        for p in 0..n_pairs {
30            let pm  = match_probs[p];
31            let pnm = 1.0 - pm;
32            total_match    += pm;
33            total_nonmatch += pnm;
34
35            for f in 0..n_fields {
36                let level = comparison_levels[f * n_pairs + p] as usize;
37                let idx   = f * NUM_LEVELS + level;
38                m_counts[idx] += pm;
39                u_counts[idx] += pnm;
40            }
41        }
42
43        Ok(EmReduceOutput { m_counts, u_counts, total_match, total_nonmatch })
44    }
45}