shadow_core/diff/
bootstrap.rs1use rand::rngs::StdRng;
12use rand::seq::SliceRandom;
13use rand::SeedableRng;
14
15#[derive(Debug, Clone, Copy, PartialEq)]
17pub struct CiResult {
18 pub low: f64,
20 pub median: f64,
22 pub high: f64,
24}
25
26pub const DEFAULT_ITERATIONS: usize = 1000;
30
31pub fn paired_ci<F>(
41 baseline: &[f64],
42 candidate: &[f64],
43 statistic: F,
44 iterations: usize,
45 seed: Option<u64>,
46) -> CiResult
47where
48 F: Fn(&[f64], &[f64]) -> f64,
49{
50 let n = baseline.len();
51 #[allow(clippy::panic)]
57 {
58 if n != candidate.len() {
59 panic!("baseline and candidate must have equal length");
60 }
61 }
62 if n == 0 {
63 return CiResult {
64 low: 0.0,
65 median: 0.0,
66 high: 0.0,
67 };
68 }
69 let iterations = if iterations == 0 {
70 DEFAULT_ITERATIONS
71 } else {
72 iterations
73 };
74
75 let mut rng: StdRng = match seed {
76 Some(s) => StdRng::seed_from_u64(s),
77 None => StdRng::from_entropy(),
78 };
79
80 let mut samples = Vec::with_capacity(iterations);
81 let indices: Vec<usize> = (0..n).collect();
82 let mut b_buf = Vec::with_capacity(n);
83 let mut c_buf = Vec::with_capacity(n);
84 for _ in 0..iterations {
85 b_buf.clear();
86 c_buf.clear();
87 for _ in 0..n {
88 let i = *indices.choose(&mut rng).unwrap_or(&0);
92 b_buf.push(baseline[i]);
93 c_buf.push(candidate[i]);
94 }
95 samples.push(statistic(&b_buf, &c_buf));
96 }
97 samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
98 CiResult {
99 low: percentile(&samples, 2.5),
100 median: percentile(&samples, 50.0),
101 high: percentile(&samples, 97.5),
102 }
103}
104
105pub fn median(xs: &[f64]) -> f64 {
109 if xs.is_empty() {
110 return 0.0;
111 }
112 let mut sorted: Vec<f64> = xs.to_vec();
113 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
114 let n = sorted.len();
115 if n % 2 == 1 {
116 sorted[n / 2]
117 } else {
118 (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0
119 }
120}
121
122fn percentile(sorted: &[f64], p: f64) -> f64 {
124 if sorted.is_empty() {
125 return 0.0;
126 }
127 let n = sorted.len() as f64;
128 let rank = (p / 100.0) * (n - 1.0);
129 let lo = rank.floor() as usize;
130 let hi = rank.ceil() as usize;
131 if lo == hi {
132 sorted[lo]
133 } else {
134 let frac = rank - lo as f64;
135 sorted[lo] * (1.0 - frac) + sorted[hi] * frac
136 }
137}
138
139#[cfg(test)]
140mod tests {
141 use super::*;
142
143 #[test]
144 fn median_odd_length() {
145 assert_eq!(median(&[1.0, 3.0, 2.0]), 2.0);
146 }
147
148 #[test]
149 fn median_even_length_averages_two_middle() {
150 assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]), 2.5);
151 }
152
153 #[test]
154 fn median_empty_is_zero() {
155 assert_eq!(median(&[]), 0.0);
156 }
157
158 #[test]
159 fn percentile_exact() {
160 let sorted: Vec<f64> = (0..=100).map(|i| i as f64).collect();
161 assert!((percentile(&sorted, 50.0) - 50.0).abs() < 1e-9);
162 assert!((percentile(&sorted, 2.5) - 2.5).abs() < 1e-9);
163 assert!((percentile(&sorted, 97.5) - 97.5).abs() < 1e-9);
164 }
165
166 #[test]
167 fn paired_ci_zero_on_equal_samples() {
168 let baseline: Vec<f64> = (0..100).map(|i| i as f64).collect();
169 let candidate = baseline.clone();
170 let result = paired_ci(
171 &baseline,
172 &candidate,
173 |b, c| median(c) - median(b),
174 200,
175 Some(42),
176 );
177 assert!(result.low.abs() < 1e-9);
178 assert!(result.median.abs() < 1e-9);
179 assert!(result.high.abs() < 1e-9);
180 }
181
182 #[test]
183 fn paired_ci_detects_consistent_shift() {
184 let baseline: Vec<f64> = (0..100).map(|i| i as f64).collect();
186 let candidate: Vec<f64> = baseline.iter().map(|x| x + 10.0).collect();
187 let result = paired_ci(
188 &baseline,
189 &candidate,
190 |b, c| median(c) - median(b),
191 500,
192 Some(7),
193 );
194 assert!(result.low > 5.0);
196 assert!(result.high < 15.0);
197 assert!((result.median - 10.0).abs() < 2.0);
198 }
199
200 #[test]
201 fn paired_ci_empty_is_zero() {
202 let r = paired_ci(&[], &[], |_, _| 0.0, 100, Some(1));
203 assert_eq!(r.low, 0.0);
204 assert_eq!(r.median, 0.0);
205 assert_eq!(r.high, 0.0);
206 }
207
208 #[test]
209 fn paired_ci_is_deterministic_with_seed() {
210 let baseline: Vec<f64> = (0..50).map(|i| (i as f64) * 1.5).collect();
211 let candidate: Vec<f64> = (0..50).map(|i| (i as f64) * 1.5 + 3.0).collect();
212 let a = paired_ci(
213 &baseline,
214 &candidate,
215 |b, c| median(c) - median(b),
216 200,
217 Some(123),
218 );
219 let b = paired_ci(
220 &baseline,
221 &candidate,
222 |b, c| median(c) - median(b),
223 200,
224 Some(123),
225 );
226 assert_eq!(a, b);
227 }
228
229 #[test]
230 #[should_panic(expected = "must have equal length")]
231 fn paired_ci_panics_on_length_mismatch() {
232 paired_ci(&[1.0, 2.0], &[1.0], |_, _| 0.0, 100, Some(1));
233 }
234
235 #[test]
236 fn default_iterations_is_used_when_zero_passed() {
237 let baseline: Vec<f64> = (0..50).map(|i| i as f64).collect();
238 let candidate = baseline.clone();
239 let r = paired_ci(
240 &baseline,
241 &candidate,
242 |b, c| median(c) - median(b),
243 0,
244 Some(1),
245 );
246 assert!(r.low <= r.median);
248 assert!(r.median <= r.high);
249 }
250}