zer_compute/backend/cpu/launch/
em_reduce.rs1use crate::{
4 backend::cpu::device::CpuDevice,
5 error::GpuError,
6 kernel::KernelDispatch,
7 kernels::em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
8};
9
10const NUM_LEVELS: usize = 4;
11
12impl KernelDispatch<EmReduce> for CpuDevice {
13 fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
14 let EmReduceInput {
15 match_probs,
16 comparison_levels,
17 n_pairs,
18 n_fields,
19 } = input;
20
21 if n_pairs == 0 || n_fields == 0 {
22 let zeros = vec![0.0f32; n_fields * NUM_LEVELS];
23 return Ok(EmReduceOutput {
24 m_counts: zeros.clone(),
25 u_counts: zeros,
26 total_match: 0.0,
27 total_nonmatch: 0.0,
28 });
29 }
30
31 let mut m_counts = vec![0.0f32; n_fields * NUM_LEVELS];
32 let mut u_counts = vec![0.0f32; n_fields * NUM_LEVELS];
33 let mut total_match = 0.0f32;
34 let mut total_nonmatch = 0.0f32;
35
36 for p in 0..n_pairs {
37 let pm = match_probs[p];
38 let pnm = 1.0 - pm;
39 total_match += pm;
40 total_nonmatch += pnm;
41
42 for f in 0..n_fields {
43 let level = comparison_levels[f * n_pairs + p] as usize;
44 let idx = f * NUM_LEVELS + level;
45 m_counts[idx] += pm;
46 u_counts[idx] += pnm;
47 }
48 }
49
50 Ok(EmReduceOutput {
51 m_counts,
52 u_counts,
53 total_match,
54 total_nonmatch,
55 })
56 }
57}