1extern crate alloc;
18
19use alloc::vec::Vec;
20
21use rand::SeedableRng;
22use rand_distr::{Distribution, StandardNormal};
23use rand_xoshiro::Xoshiro256PlusPlus;
24
25use crate::constants::{B_TAIL, ONES};
26use crate::math;
27use crate::result::MinDetectableEffect;
28use crate::types::{Matrix9, Vector9};
29
30use super::effect::decompose_effect;
31
/// Minimum detectable effect (MDE) for the shift and tail components.
///
/// Built either by [`estimate_mde`] (closed form, `n_simulations == 0`) or by
/// [`estimate_mde_monte_carlo`] (simulation-based).
#[derive(Debug, Clone)]
pub struct MdeEstimate {
    /// MDE of the shift component (nanoseconds, per the `_ns` naming).
    pub shift_ns: f64,
    /// MDE of the tail component (nanoseconds, per the `_ns` naming).
    pub tail_ns: f64,
    /// Number of Monte Carlo draws behind the estimate; `0` for the analytical path.
    pub n_simulations: usize,
}
42
43impl From<MdeEstimate> for MinDetectableEffect {
44 fn from(mde: MdeEstimate) -> Self {
45 MinDetectableEffect {
46 shift_ns: mde.shift_ns,
47 tail_ns: mde.tail_ns,
48 }
49 }
50}
51
52pub fn analytical_mde(covariance: &Matrix9, alpha: f64) -> (f64, f64) {
69 let chol = safe_cholesky(covariance);
71
72 let ones_vec = Vector9::from_iterator(ONES.iter().cloned());
74 let sigma_inv_ones = chol.solve(&ones_vec);
75 let precision_shift = ones_vec.dot(&sigma_inv_ones);
76 let var_shift = if precision_shift.abs() > 1e-12 {
77 1.0 / precision_shift
78 } else {
79 1e12
81 };
82
83 let b_tail_vec = Vector9::from_iterator(B_TAIL.iter().cloned());
85 let sigma_inv_b = chol.solve(&b_tail_vec);
86 let precision_tail = b_tail_vec.dot(&sigma_inv_b);
87 let var_tail = if precision_tail.abs() > 1e-12 {
88 1.0 / precision_tail
89 } else {
90 1e12
91 };
92
93 let z = probit(1.0 - alpha / 2.0);
95 let mde_shift = z * math::sqrt(var_shift);
96 let mde_tail = z * math::sqrt(var_tail);
97
98 (mde_shift, mde_tail)
99}
100
101fn probit(p: f64) -> f64 {
106 if p <= 0.0 {
107 return f64::NEG_INFINITY;
108 }
109 if p >= 1.0 {
110 return f64::INFINITY;
111 }
112
113 let (sign, q) = if p < 0.5 { (-1.0, 1.0 - p) } else { (1.0, p) };
115
116 const C0: f64 = 2.515517;
118 const C1: f64 = 0.802853;
119 const C2: f64 = 0.010328;
120 const D1: f64 = 1.432788;
121 const D2: f64 = 0.189269;
122 const D3: f64 = 0.001308;
123
124 let t = math::sqrt(-2.0 * math::ln(1.0 - q));
125 let z = t - (C0 + C1 * t + C2 * t * t) / (1.0 + D1 * t + D2 * t * t + D3 * t * t * t);
126
127 sign * z
128}
129
130#[allow(dead_code)]
135pub fn estimate_mde_monte_carlo(
136 covariance: &Matrix9,
137 n_simulations: usize,
138 prior_sigmas: (f64, f64),
139) -> MdeEstimate {
140 let mut rng = Xoshiro256PlusPlus::seed_from_u64(42);
141
142 let chol = match nalgebra::Cholesky::new(*covariance) {
144 Some(c) => c,
145 None => {
146 let regularized = covariance + Matrix9::identity() * 1e-10;
148 nalgebra::Cholesky::new(regularized)
149 .expect("Regularized covariance should be positive definite")
150 }
151 };
152
153 let mut shift_effects = Vec::with_capacity(n_simulations);
155 let mut tail_effects = Vec::with_capacity(n_simulations);
156
157 for _ in 0..n_simulations {
158 let z: Vector9 = Vector9::from_fn(|_, _| StandardNormal.sample(&mut rng));
160 let null_sample = chol.l() * z;
161
162 let decomp = decompose_effect(&null_sample, covariance, prior_sigmas);
164
165 shift_effects.push(decomp.posterior_mean[0].abs());
167 tail_effects.push(decomp.posterior_mean[1].abs());
168 }
169
170 let shift_mde = percentile(&mut shift_effects, 0.95);
172 let tail_mde = percentile(&mut tail_effects, 0.95);
173
174 MdeEstimate {
175 shift_ns: shift_mde,
176 tail_ns: tail_mde,
177 n_simulations,
178 }
179}
180
181pub fn estimate_mde(covariance: &Matrix9, alpha: f64) -> MdeEstimate {
192 let (shift_ns, tail_ns) = analytical_mde(covariance, alpha);
193
194 MdeEstimate {
195 shift_ns,
196 tail_ns,
197 n_simulations: 0, }
199}
200
201fn percentile(values: &mut [f64], p: f64) -> f64 {
205 if values.is_empty() {
206 return 0.0;
207 }
208
209 values.sort_by(|a, b| a.total_cmp(b));
210
211 let idx = math::round(p * (values.len() - 1) as f64) as usize;
212 let idx = idx.min(values.len() - 1);
213
214 values[idx]
215}
216
217fn safe_cholesky(matrix: &Matrix9) -> nalgebra::Cholesky<f64, nalgebra::Const<9>> {
221 if let Some(chol) = nalgebra::Cholesky::new(*matrix) {
223 return chol;
224 }
225
226 let trace = matrix.trace();
228 let base_jitter = 1e-10;
229 let adaptive_jitter = (trace / 9.0) * 1e-8;
230 let jitter = base_jitter + adaptive_jitter;
231
232 let mut regularized = *matrix;
233 for i in 0..9 {
234 regularized[(i, i)] += jitter;
235 }
236
237 nalgebra::Cholesky::new(regularized).expect("Cholesky failed even after regularization")
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_percentile_basic() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let median = percentile(&mut data, 0.5);
        assert!(close(median, 3.0, 0.01));
    }

    #[test]
    fn test_percentile_95() {
        let mut data: Vec<f64> = (1..=100).map(f64::from).collect();
        let upper = percentile(&mut data, 0.95);
        assert!(close(upper, 95.0, 1.0));
    }

    #[test]
    fn test_mde_positive() {
        let estimate = estimate_mde(&Matrix9::identity(), 0.05);

        assert!(estimate.shift_ns > 0.0, "MDE shift should be positive");
        assert!(estimate.tail_ns > 0.0, "MDE tail should be positive");
    }

    #[test]
    fn test_probit_accuracy() {
        assert!(close(probit(0.5), 0.0, 1e-3), "probit(0.5) should be 0");
        assert!(
            close(probit(0.975), 1.96, 1e-2),
            "probit(0.975) should be ~1.96"
        );
        assert!(
            close(probit(0.995), 2.576, 1e-2),
            "probit(0.995) should be ~2.576"
        );
        assert!(
            close(probit(0.025), -1.96, 1e-2),
            "probit(0.025) should be ~-1.96"
        );
    }

    #[test]
    fn test_analytical_mde_iid_sanity_check() {
        let (mde_shift, mde_tail) = analytical_mde(&Matrix9::identity(), 0.05);

        // With Σ = I the shift variance is 1/9; the tail expectation below
        // encodes ‖B_TAIL‖² = 0.9375.
        let z = 1.96;
        let expected_shift = z / 3.0;
        let expected_tail = z * (1.0 / 0.9375_f64).sqrt();

        assert!(
            close(mde_shift, expected_shift, 0.05),
            "shift MDE should be ~{:.3}, got {:.3}",
            expected_shift,
            mde_shift
        );
        assert!(
            close(mde_tail, expected_tail, 0.1),
            "tail MDE should be ~{:.3}, got {:.3}",
            expected_tail,
            mde_tail
        );
    }

    #[test]
    fn test_analytical_mde_alpha_scaling() {
        let cov = Matrix9::identity();
        let (loose, _) = analytical_mde(&cov, 0.05);
        let (strict, _) = analytical_mde(&cov, 0.01);

        assert!(
            strict > loose,
            "MDE at α=0.01 ({:.3}) should be larger than α=0.05 ({:.3})",
            strict,
            loose
        );
    }

    #[test]
    fn test_analytical_mde_diagonal_covariance() {
        // Heteroscedastic but uncorrelated: variances 1, 2, ..., 9.
        let cov = Matrix9::from_diagonal(&Vector9::from_fn(|i, _| (i + 1) as f64));

        let (mde_shift, mde_tail) = analytical_mde(&cov, 0.05);

        assert!(
            mde_shift.is_finite() && mde_shift > 0.0,
            "shift MDE not finite or positive: {}",
            mde_shift
        );
        assert!(
            mde_tail.is_finite() && mde_tail > 0.0,
            "tail MDE not finite or positive: {}",
            mde_tail
        );
    }

    #[test]
    fn test_analytical_mde_near_singular() {
        // One dominant variance, the rest six orders of magnitude smaller.
        let mut cov = Matrix9::identity() * 1e-6;
        cov[(0, 0)] = 1.0;

        let (mde_shift, mde_tail) = analytical_mde(&cov, 0.05);

        assert!(mde_shift.is_finite(), "shift MDE not finite: {}", mde_shift);
        assert!(mde_tail.is_finite(), "tail MDE not finite: {}", mde_tail);
    }

    #[test]
    #[cfg(feature = "std")]
    #[cfg_attr(debug_assertions, ignore)]
    fn test_analytical_mde_performance() {
        // Correlation decaying as exp(-|i - j| / 2).
        let cov = Matrix9::from_fn(|i, j| (-(i as f64 - j as f64).abs() / 2.0).exp());

        let start = std::time::Instant::now();
        for _ in 0..1000 {
            let _ = analytical_mde(&cov, 0.05);
        }
        let elapsed = start.elapsed();

        // 1000 calls within 3000 µs, i.e. 3 µs per call on average.
        assert!(
            elapsed.as_micros() < 3000,
            "analytical MDE too slow: {:.1}µs per call (threshold: 3µs)",
            elapsed.as_micros() as f64 / 1000.0
        );
    }
}