Skip to main content

second_brain_api/eval/
bootstrap.rs

1use rand::Rng;
2use rand::SeedableRng;
3use rand::rngs::StdRng;
4
5pub fn paired_bootstrap_ci(deltas: &[f32], n_resamples: usize, ci: f64, seed: u64) -> (f32, f32) {
6    if deltas.is_empty() || n_resamples == 0 {
7        return (0.0, 0.0);
8    }
9
10    let mut rng = StdRng::seed_from_u64(seed);
11    let n = deltas.len();
12    let mut means: Vec<f32> = Vec::with_capacity(n_resamples);
13
14    for _ in 0..n_resamples {
15        let mut sum = 0.0_f64;
16        for _ in 0..n {
17            let idx = rng.random_range(0..n);
18            sum += deltas[idx] as f64;
19        }
20        means.push((sum / n as f64) as f32);
21    }
22
23    means.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
24
25    let tail = (1.0 - ci) / 2.0;
26    let lo = percentile(&means, tail);
27    let hi = percentile(&means, 1.0 - tail);
28    (lo, hi)
29}
30
/// Returns the nearest-rank percentile of a slice that is already sorted ascending.
///
/// `p` is a fraction of the way through the slice: `0.0` selects the first element and
/// `1.0` the last. The rank is `round(p * (len - 1))`, capped at the last index.
/// Because float-to-`usize` casts saturate, a negative or NaN `p` selects the first
/// element and a `p` above `1.0` selects the last.
///
/// An empty slice yields `0.0` instead of panicking, mirroring the `(0.0, 0.0)` that
/// `paired_bootstrap_ci` returns for degenerate input.
fn percentile(sorted: &[f32], p: f64) -> f32 {
    if sorted.is_empty() {
        // `len() - 1` below would underflow; there is no element to report.
        return 0.0;
    }
    let last = sorted.len() - 1;
    let rank = (p * last as f64).round() as usize;
    sorted[rank.min(last)]
}
36
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of bootstrap resamples used by every test below.
    const RESAMPLES: usize = 2000;
    /// Confidence level used by every test below.
    const LEVEL: f64 = 0.95;

    /// A sample made only of ones has a single possible resample mean, so both bounds sit at 1.
    #[test]
    fn all_positive_deltas_collapse_to_one() {
        let deltas = [1.0_f32; 50];
        let (lo, hi) = paired_bootstrap_ci(&deltas, RESAMPLES, LEVEL, 7);
        assert!((lo - 1.0).abs() < 1e-6, "lo was {lo}");
        assert!((hi - 1.0).abs() < 1e-6, "hi was {hi}");
    }

    /// A sample made only of zeros must produce an interval that is exactly zero at both ends.
    #[test]
    fn all_zero_deltas_give_zero_interval() {
        let deltas = [0.0_f32; 50];
        let (lo, hi) = paired_bootstrap_ci(&deltas, RESAMPLES, LEVEL, 7);
        assert_eq!(lo, 0.0);
        assert_eq!(hi, 0.0);
    }

    /// Alternating gains and losses give a non-degenerate interval containing the sample mean.
    #[test]
    fn mixed_deltas_bracket_the_true_mean() {
        // Even positions hold 1.0 and odd positions hold -0.5, forty values in total.
        let deltas: Vec<f32> = [1.0_f32, -0.5].iter().copied().cycle().take(40).collect();
        let mean = deltas.iter().sum::<f32>() / deltas.len() as f32;
        let (lo, hi) = paired_bootstrap_ci(&deltas, RESAMPLES, LEVEL, 7);
        assert!(lo < hi, "expected lo < hi, got ({lo}, {hi})");
        assert!((lo..=hi).contains(&mean), "mean {mean} not in [{lo}, {hi}]");
    }

    /// Two runs with identical inputs and seed must return identical bounds.
    #[test]
    fn is_deterministic_for_a_fixed_seed() {
        let deltas: Vec<f32> = (0..30).map(|i| (i as f32) * 0.01 - 0.15).collect();
        let first = paired_bootstrap_ci(&deltas, RESAMPLES, LEVEL, 42);
        let second = paired_bootstrap_ci(&deltas, RESAMPLES, LEVEL, 42);
        assert_eq!(first, second);
    }
}