zer_compute/backend/cpu/launch/
em_reduce.rs1use crate::{
4 backend::cpu::device::CpuDevice,
5 error::GpuError,
6 kernel::KernelDispatch,
7 kernels::em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
8};
9
/// Number of comparison levels tracked per field.
///
/// Each field owns this many consecutive bins in the `m_counts` / `u_counts` buffers, so the
/// bin for field `f` at level `l` lives at index `f * NUM_LEVELS + l`.
const NUM_LEVELS: usize = 4;
11
12impl KernelDispatch<EmReduce> for CpuDevice {
13 fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
14 let EmReduceInput { match_probs, comparison_levels, n_pairs, n_fields } = input;
15
16 if n_pairs == 0 || n_fields == 0 {
17 let zeros = vec![0.0f32; n_fields * NUM_LEVELS];
18 return Ok(EmReduceOutput {
19 m_counts: zeros.clone(), u_counts: zeros,
20 total_match: 0.0, total_nonmatch: 0.0,
21 });
22 }
23
24 let mut m_counts = vec![0.0f32; n_fields * NUM_LEVELS];
25 let mut u_counts = vec![0.0f32; n_fields * NUM_LEVELS];
26 let mut total_match = 0.0f32;
27 let mut total_nonmatch = 0.0f32;
28
29 for p in 0..n_pairs {
30 let pm = match_probs[p];
31 let pnm = 1.0 - pm;
32 total_match += pm;
33 total_nonmatch += pnm;
34
35 for f in 0..n_fields {
36 let level = comparison_levels[f * n_pairs + p] as usize;
37 let idx = f * NUM_LEVELS + level;
38 m_counts[idx] += pm;
39 u_counts[idx] += pnm;
40 }
41 }
42
43 Ok(EmReduceOutput { m_counts, u_counts, total_match, total_nonmatch })
44 }
45}