Skip to main content

zer_compute/
soa.rs

/// Per-string stride of the interleaved byte buffers.
///
/// Every string is assigned a slot of this fixed width; `BatchSizer` uses the value
/// when estimating how large those buffers need to be.
pub const STRING_STRIDE: usize = 64;
4
5/// Pre-compute `ln(m[f][l] / u[f][l])` weight table for GPU upload.
6pub fn build_weight_table(params: &zer_core::scoring::ModelParams) -> Vec<f32> {
7    let n_fields = params.m.len();
8    let n_levels = if n_fields > 0 { params.m[0].len() } else { 0 };
9    let mut table = Vec::with_capacity(n_fields * n_levels);
10    for f in 0..n_fields {
11        for l in 0..n_levels {
12            let m = params.m[f][l].max(1e-9_f32);
13            let u = params.u[f][l].max(1e-9_f32);
14            table.push((m / u).ln());
15        }
16    }
17    table
18}
19
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::scoring::ModelParams;

    const N_FIELDS: usize = 3;
    const N_LEVELS: usize = 4;

    /// Builds a `ModelParams` from the given m/u tables. `build_weight_table` reads only
    /// `m` and `u`, so the remaining fields get fixed placeholder values.
    fn params_with(m: Vec<Vec<f32>>, u: Vec<Vec<f32>>) -> ModelParams {
        ModelParams {
            m,
            u,
            log_prior_odds: 0.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    #[test]
    fn weight_table_log_ratios_are_finite() {
        let params = params_with(
            vec![vec![0.05, 0.10, 0.15, 0.70]; N_FIELDS],
            vec![vec![0.70, 0.15, 0.10, 0.05]; N_FIELDS],
        );
        let table = build_weight_table(&params);
        assert_eq!(table.len(), N_FIELDS * N_LEVELS);
        for v in &table {
            assert!(v.is_finite(), "weight table must not contain NaN/Inf: {v}");
        }

        // Every field uses the same m/u rows, so each row-major chunk must look the same.
        let ln14 = 14.0_f32.ln(); // ln(0.70 / 0.05)
        for (f, row) in table.chunks_exact(N_LEVELS).enumerate() {
            let exact = row[N_LEVELS - 1];
            let none = row[0];
            assert!(exact > 0.0, "exact match should have positive weight (field {f})");
            assert!(none < 0.0, "none match should have negative weight (field {f})");
            assert!((exact - ln14).abs() < 1e-5, "field {f}: expected ln(14), got {exact}");
            assert!((none + ln14).abs() < 1e-5, "field {f}: expected -ln(14), got {none}");
        }
    }

    #[test]
    fn zero_probabilities_are_clamped_to_finite_weights() {
        let params = params_with(vec![vec![0.0, 1.0]], vec![vec![1.0, 0.0]]);
        let table = build_weight_table(&params);
        assert_eq!(table.len(), 2);
        assert!(table.iter().all(|v| v.is_finite()), "clamping must keep weights finite: {table:?}");
        assert!(table[0] < 0.0);
        assert!(table[1] > 0.0);
    }

    #[test]
    fn empty_params_give_empty_table() {
        let params = params_with(Vec::new(), Vec::new());
        assert!(build_weight_table(&params).is_empty());
    }
}