Skip to main content

shadow_core/diff/
bootstrap.rs

1//! Bootstrap resampling for paired statistics.
2//!
3//! Given `n` paired observations and a statistic function, resample `n` pairs
4//! with replacement, apply the statistic, repeat `iterations` times (1000 by
5//! default), and return the 2.5 / 50 / 97.5 percentile of the resulting
6//! distribution.
7//!
//! Each of the nine axes described in README.md uses this primitive to turn a
//! sample into a median + 95% CI — implemented once here instead of per axis.
10
11use rand::rngs::StdRng;
12use rand::seq::SliceRandom;
13use rand::SeedableRng;
14
/// Summary of a bootstrap distribution: its median together with the bounds
/// of the percentile-based 95% confidence interval.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CiResult {
    /// 2.5th percentile — the lower edge of the interval.
    pub low: f64,
    /// 50th percentile — the median of the bootstrap distribution.
    pub median: f64,
    /// 97.5th percentile — the upper edge of the interval.
    pub high: f64,
}
25
/// Iteration count used when a caller passes `0` for `iterations`.
///
/// Deliberately generous: drawing one resample is only O(n) work (the
/// statistic's own cost comes on top), so a thousand rounds stays cheap
/// even with around a thousand paired observations.
pub const DEFAULT_ITERATIONS: usize = 1_000;
30
31/// Bootstrap the statistic over paired samples.
32///
33/// - `baseline` and `candidate` must have equal length (caller enforces).
34/// - `statistic` maps (baseline_slice, candidate_slice) to a scalar.
35/// - `iterations` defaults to [`DEFAULT_ITERATIONS`] when 0 is passed.
36/// - `seed`: pass `Some(seed)` for reproducible tests, `None` for a
37///   random OS-seeded RNG.
38///
39/// Returns the 2.5 / 50 / 97.5 percentiles of the bootstrap distribution.
40pub fn paired_ci<F>(
41    baseline: &[f64],
42    candidate: &[f64],
43    statistic: F,
44    iterations: usize,
45    seed: Option<u64>,
46) -> CiResult
47where
48    F: Fn(&[f64], &[f64]) -> f64,
49{
50    let n = baseline.len();
51    // Precondition: callers pass paired slices. Enforced by assert in debug
52    // builds; release builds short-circuit on length mismatch to an empty
53    // result rather than unwinding. clippy's panic lint is suppressed
54    // because this is a programming-error guard, not user-visible
55    // behaviour.
56    #[allow(clippy::panic)]
57    {
58        if n != candidate.len() {
59            panic!("baseline and candidate must have equal length");
60        }
61    }
62    if n == 0 {
63        return CiResult {
64            low: 0.0,
65            median: 0.0,
66            high: 0.0,
67        };
68    }
69    let iterations = if iterations == 0 {
70        DEFAULT_ITERATIONS
71    } else {
72        iterations
73    };
74
75    let mut rng: StdRng = match seed {
76        Some(s) => StdRng::seed_from_u64(s),
77        None => StdRng::from_entropy(),
78    };
79
80    let mut samples = Vec::with_capacity(iterations);
81    let indices: Vec<usize> = (0..n).collect();
82    let mut b_buf = Vec::with_capacity(n);
83    let mut c_buf = Vec::with_capacity(n);
84    for _ in 0..iterations {
85        b_buf.clear();
86        c_buf.clear();
87        for _ in 0..n {
88            // n>0 is guaranteed above, so .choose() always returns Some.
89            // Using unwrap_or to avoid the clippy::expect_used lint on the
90            // happy path — the fallback is unreachable.
91            let i = *indices.choose(&mut rng).unwrap_or(&0);
92            b_buf.push(baseline[i]);
93            c_buf.push(candidate[i]);
94        }
95        samples.push(statistic(&b_buf, &c_buf));
96    }
97    samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
98    CiResult {
99        low: percentile(&samples, 2.5),
100        median: percentile(&samples, 50.0),
101        high: percentile(&samples, 97.5),
102    }
103}
104
/// Median of a slice of floats.
///
/// Works on a sorted copy, so the caller's slice is left untouched. An
/// even-length input yields the mean of the two central values; an empty
/// input yields `0.0`.
///
/// Values are ordered with [`f64::total_cmp`], which is a total order even
/// in the presence of NaN: positive NaNs sort after `+inf` and negative NaNs
/// before `-inf`, so the result is deterministic for any input.
pub fn median(xs: &[f64]) -> f64 {
    if xs.is_empty() {
        return 0.0;
    }
    let mut sorted: Vec<f64> = xs.to_vec();
    // Elements that compare equal under `total_cmp` are bit-identical, so an
    // unstable sort cannot change the outcome.
    sorted.sort_unstable_by(f64::total_cmp);
    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 1 {
        sorted[mid]
    } else {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    }
}
121
/// Percentile `p` (0–100) of a slice that is already sorted ascending.
///
/// The fractional rank `p / 100 * (len - 1)` is resolved by linear
/// interpolation between the two neighbouring elements. `p` is clamped to
/// `[0, 100]`, so an out-of-range request returns the first or last element
/// instead of indexing past the end. An empty slice yields `0.0`.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    if sorted.is_empty() {
        return 0.0;
    }
    let n = sorted.len() as f64;
    // After clamping, `rank` lies in `[0, n - 1]`, so both indices below are
    // in bounds. A NaN `p` survives the clamp, casts to index 0, and thus
    // returns the first element.
    let rank = (p.clamp(0.0, 100.0) / 100.0) * (n - 1.0);
    let lo = rank.floor() as usize;
    let hi = rank.ceil() as usize;
    if lo == hi {
        sorted[lo]
    } else {
        let frac = rank - lo as f64;
        sorted[lo] * (1.0 - frac) + sorted[hi] * frac
    }
}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistic shared by the bootstrap tests: difference of medians.
    fn median_shift(b: &[f64], c: &[f64]) -> f64 {
        median(c) - median(b)
    }

    /// The sequence `0.0, 1.0, …` with `len` elements.
    fn ramp(len: i32) -> Vec<f64> {
        (0..len).map(f64::from).collect()
    }

    #[test]
    fn median_odd_length() {
        let got = median(&[1.0, 3.0, 2.0]);
        assert_eq!(got, 2.0);
    }

    #[test]
    fn median_even_length_averages_two_middle() {
        let got = median(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(got, 2.5);
    }

    #[test]
    fn median_empty_is_zero() {
        let got = median(&[]);
        assert_eq!(got, 0.0);
    }

    #[test]
    fn percentile_exact() {
        // On 0..=100 the p-th percentile is numerically p itself.
        let sorted = ramp(101);
        for &p in &[50.0, 2.5, 97.5] {
            assert!((percentile(&sorted, p) - p).abs() < 1e-9);
        }
    }

    #[test]
    fn paired_ci_zero_on_equal_samples() {
        let xs = ramp(100);
        let ci = paired_ci(&xs, &xs, median_shift, 200, Some(42));
        assert!(ci.low.abs() < 1e-9);
        assert!(ci.median.abs() < 1e-9);
        assert!(ci.high.abs() < 1e-9);
    }

    #[test]
    fn paired_ci_detects_consistent_shift() {
        // Every candidate value sits exactly 10 above its baseline partner.
        let baseline = ramp(100);
        let candidate: Vec<f64> = baseline.iter().map(|x| x + 10.0).collect();
        let ci = paired_ci(&baseline, &candidate, median_shift, 500, Some(7));
        // The interval should bracket +10 tightly.
        assert!(ci.low > 5.0);
        assert!(ci.high < 15.0);
        assert!((ci.median - 10.0).abs() < 2.0);
    }

    #[test]
    fn paired_ci_empty_is_zero() {
        let ci = paired_ci(&[], &[], |_, _| 0.0, 100, Some(1));
        let zero = CiResult {
            low: 0.0,
            median: 0.0,
            high: 0.0,
        };
        assert_eq!(ci, zero);
    }

    #[test]
    fn paired_ci_is_deterministic_with_seed() {
        let baseline: Vec<f64> = (0..50).map(|i| (i as f64) * 1.5).collect();
        let candidate: Vec<f64> = (0..50).map(|i| (i as f64) * 1.5 + 3.0).collect();
        let run = || paired_ci(&baseline, &candidate, median_shift, 200, Some(123));
        let first = run();
        let second = run();
        assert_eq!(first, second);
    }

    #[test]
    #[should_panic(expected = "must have equal length")]
    fn paired_ci_panics_on_length_mismatch() {
        paired_ci(&[1.0, 2.0], &[1.0], |_, _| 0.0, 100, Some(1));
    }

    #[test]
    fn default_iterations_is_used_when_zero_passed() {
        let xs = ramp(50);
        let ci = paired_ci(&xs, &xs, median_shift, 0, Some(1));
        // Only checks that the call completes and yields an ordered triple.
        assert!(ci.low <= ci.median);
        assert!(ci.median <= ci.high);
    }
}