Skip to main content

zer_compute/backend/cpu/launch/
em_reduce.rs

1//! CPU dispatch for [`EmReduce`], pure-CPU M-step reduction.
2
3use crate::{
4    backend::cpu::device::CpuDevice,
5    error::GpuError,
6    kernel::KernelDispatch,
7    kernels::em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
8};
9
/// Number of level buckets reserved per field in the `m_counts` / `u_counts`
/// tables: the bucket for field `f` and level `l` lives at `f * NUM_LEVELS + l`.
///
/// NOTE(review): presumably this has to agree with the level count used by the
/// other backends for the same kernel; confirm before changing it.
const NUM_LEVELS: usize = 4;
11
12impl KernelDispatch<EmReduce> for CpuDevice {
13    fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
14        let EmReduceInput {
15            match_probs,
16            comparison_levels,
17            n_pairs,
18            n_fields,
19        } = input;
20
21        if n_pairs == 0 || n_fields == 0 {
22            let zeros = vec![0.0f32; n_fields * NUM_LEVELS];
23            return Ok(EmReduceOutput {
24                m_counts: zeros.clone(),
25                u_counts: zeros,
26                total_match: 0.0,
27                total_nonmatch: 0.0,
28            });
29        }
30
31        let mut m_counts = vec![0.0f32; n_fields * NUM_LEVELS];
32        let mut u_counts = vec![0.0f32; n_fields * NUM_LEVELS];
33        let mut total_match = 0.0f32;
34        let mut total_nonmatch = 0.0f32;
35
36        for p in 0..n_pairs {
37            let pm = match_probs[p];
38            let pnm = 1.0 - pm;
39            total_match += pm;
40            total_nonmatch += pnm;
41
42            for f in 0..n_fields {
43                let level = comparison_levels[f * n_pairs + p] as usize;
44                let idx = f * NUM_LEVELS + level;
45                m_counts[idx] += pm;
46                u_counts[idx] += pnm;
47            }
48        }
49
50        Ok(EmReduceOutput {
51            m_counts,
52            u_counts,
53            total_match,
54            total_nonmatch,
55        })
56    }
57}