1extern crate alloc;
14
15use alloc::vec::Vec;
16
17use crate::constants::DECILES;
18use crate::math;
19use crate::result::{EffectEstimate, TopQuantile};
20use crate::types::{Matrix9, Vector9};
21
/// Per-decile summary row used while ranking deciles, before conversion into
/// a `TopQuantile`.
///
/// Fields, in order:
/// 0. decile index `k` (`0..9`),
/// 1. the decile probability taken from `DECILES[k]`,
/// 2. the mean effect for that decile,
/// 3. the `(lower, upper)` bounds of the 95% interval,
/// 4. the probability that the absolute effect exceeds the threshold.
22type QuantileStats = (usize, f64, f64, (f64, f64), f64);
24
25pub fn compute_effect_estimate(delta_draws: &[Vector9], theta: f64) -> EffectEstimate {
41 if delta_draws.is_empty() {
42 return EffectEstimate::default();
43 }
44
45 let n = delta_draws.len();
46
47 let mut max_effects: Vec<f64> = Vec::with_capacity(n);
49 for delta in delta_draws {
50 let max_abs = delta.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
51 max_effects.push(max_abs);
52 }
53
54 let max_effect_ns = max_effects.iter().sum::<f64>() / n as f64;
56
57 max_effects.sort_by(|a, b| a.total_cmp(b));
59 let lo_idx = ((n as f64 * 0.025).round() as usize).min(n - 1);
60 let hi_idx = ((n as f64 * 0.975).round() as usize).min(n - 1);
61 let credible_interval_ns = (max_effects[lo_idx], max_effects[hi_idx]);
62
63 let top_quantiles = compute_top_quantiles(delta_draws, theta);
65
66 EffectEstimate {
67 max_effect_ns,
68 credible_interval_ns,
69 top_quantiles,
70 }
71}
72
73pub fn compute_top_quantiles(delta_draws: &[Vector9], theta: f64) -> Vec<TopQuantile> {
82 if delta_draws.is_empty() {
83 return Vec::new();
84 }
85
86 let n = delta_draws.len();
87
88 let mut quantile_stats: Vec<QuantileStats> = Vec::with_capacity(9);
90
91 for k in 0..9 {
92 let mut values: Vec<f64> = delta_draws.iter().map(|d| d[k]).collect();
94
95 let mean = values.iter().sum::<f64>() / n as f64;
97
98 values.sort_by(|a, b| a.total_cmp(b));
100 let lo_idx = ((n as f64 * 0.025).round() as usize).min(n - 1);
101 let hi_idx = ((n as f64 * 0.975).round() as usize).min(n - 1);
102 let ci = (values[lo_idx], values[hi_idx]);
103
104 let exceed_count = delta_draws.iter().filter(|d| d[k].abs() > theta).count();
106 let exceed_prob = exceed_count as f64 / n as f64;
107
108 quantile_stats.push((k, DECILES[k], mean, ci, exceed_prob));
109 }
110
111 quantile_stats.sort_by(|a, b| b.4.total_cmp(&a.4));
113
114 quantile_stats
116 .into_iter()
117 .filter(|(_, _, _, _, exceed_prob)| *exceed_prob > 0.5)
118 .take(3)
119 .map(
120 |(_, quantile_p, mean_ns, ci95_ns, exceed_prob)| TopQuantile {
121 quantile_p,
122 mean_ns,
123 ci95_ns,
124 exceed_prob,
125 },
126 )
127 .collect()
128}
129
130pub fn compute_effect_estimate_analytical(
145 delta_post: &Vector9,
146 lambda_post: &Matrix9,
147 theta: f64,
148) -> EffectEstimate {
149 let max_effect_ns = delta_post.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
151
152 let max_k = delta_post
155 .iter()
156 .enumerate()
157 .max_by(|(_, a), (_, b)| a.abs().total_cmp(&b.abs()))
158 .map(|(k, _)| k)
159 .unwrap_or(0);
160
161 let se = math::sqrt(lambda_post[(max_k, max_k)].max(1e-12));
162 let ci_low = (max_effect_ns - 1.96 * se).max(0.0);
163 let ci_high = max_effect_ns + 1.96 * se;
164
165 let top_quantiles = compute_top_quantiles_analytical(delta_post, lambda_post, theta);
167
168 EffectEstimate {
169 max_effect_ns,
170 credible_interval_ns: (ci_low, ci_high),
171 top_quantiles,
172 }
173}
174
175fn compute_top_quantiles_analytical(
177 delta_post: &Vector9,
178 lambda_post: &Matrix9,
179 theta: f64,
180) -> Vec<TopQuantile> {
181 let mut quantile_stats: Vec<QuantileStats> = Vec::with_capacity(9);
182
183 for k in 0..9 {
184 let mean = delta_post[k];
185 let se = math::sqrt(lambda_post[(k, k)].max(1e-12));
186
187 let ci = (mean - 1.96 * se, mean + 1.96 * se);
189
190 let exceed_prob = compute_exceedance_prob(mean, se, theta);
193
194 quantile_stats.push((k, DECILES[k], mean, ci, exceed_prob));
195 }
196
197 quantile_stats.sort_by(|a, b| b.4.total_cmp(&a.4));
199
200 quantile_stats
202 .into_iter()
203 .filter(|(_, _, _, _, exceed_prob)| *exceed_prob > 0.5)
204 .take(3)
205 .map(
206 |(_, quantile_p, mean_ns, ci95_ns, exceed_prob)| TopQuantile {
207 quantile_p,
208 mean_ns,
209 ci95_ns,
210 exceed_prob,
211 },
212 )
213 .collect()
214}
215
216fn compute_exceedance_prob(mu: f64, sigma: f64, theta: f64) -> f64 {
218 if sigma < 1e-12 {
219 return if mu.abs() > theta { 1.0 } else { 0.0 };
221 }
222 let phi_upper = math::normal_cdf((theta - mu) / sigma);
223 let phi_lower = math::normal_cdf((-theta - mu) / sigma);
224 1.0 - (phi_upper - phi_lower)
225}
226
227pub fn regularize_covariance(sigma: &Matrix9) -> Matrix9 {
236 let trace: f64 = (0..9).map(|i| sigma[(i, i)]).sum();
237 let mean_var = trace / 9.0;
238
239 let min_var = (0.01 * mean_var).max(1e-10);
241
242 let jitter = 1e-10 + mean_var * 1e-8;
244
245 let mut regularized = *sigma;
246 for i in 0..9 {
247 regularized[(i, i)] = regularized[(i, i)].max(min_var) + jitter;
248 }
249 regularized
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws ramp linearly from 0.0 to 9.9 with all nine components equal, so
    /// the mean of the per-draw maxima lies strictly inside the interval.
    #[test]
    fn test_effect_estimate_basic() {
        let draws: Vec<Vector9> = (0..100)
            .map(|i| Vector9::from_row_slice(&[(i as f64) * 0.1; 9]))
            .collect();

        let estimate = compute_effect_estimate(&draws, 5.0);
        let (ci_low, ci_high) = estimate.credible_interval_ns;

        assert!(
            estimate.max_effect_ns > 4.0,
            "max effect should be significant"
        );
        assert!(
            ci_low < estimate.max_effect_ns,
            "CI lower should be below mean"
        );
        assert!(
            ci_high > estimate.max_effect_ns,
            "CI upper should be above mean"
        );
    }

    /// An empty draw set yields the default estimate.
    #[test]
    fn test_effect_estimate_empty() {
        let no_draws: &[Vector9] = &[];
        let estimate = compute_effect_estimate(no_draws, 5.0);

        assert_eq!(estimate.max_effect_ns, 0.0);
        assert!(estimate.top_quantiles.is_empty());
    }

    /// Only the last decile carries an effect above the threshold, so it must
    /// be ranked first with near-certain exceedance.
    #[test]
    fn test_top_quantiles_threshold() {
        let mut row = [0.0_f64; 9];
        row[8] = 10.0;
        let draws: Vec<Vector9> = (0..100).map(|_| Vector9::from_row_slice(&row)).collect();

        let top = compute_top_quantiles(&draws, 5.0);

        assert!(!top.is_empty());
        let best = &top[0];
        assert!((best.quantile_p - 0.9).abs() < 0.01);
        assert!(best.exceed_prob > 0.99);
    }

    /// A zero variance on the diagonal must come out strictly positive.
    #[test]
    fn test_regularize_covariance() {
        // `zeros()` already leaves entry (0, 0) at 0.0; the rest are set to 1.0.
        let mut sigma = Matrix9::zeros();
        for i in 1..9 {
            sigma[(i, i)] = 1.0;
        }

        let regularized = regularize_covariance(&sigma);

        (0..9).for_each(|i| {
            assert!(
                regularized[(i, i)] > 0.0,
                "diagonal {} should be positive",
                i
            );
        });
    }

    /// Spot checks of the analytical exceedance probability.
    #[test]
    fn test_exceedance_prob() {
        let far_above = compute_exceedance_prob(100.0, 10.0, 50.0);
        assert!(far_above > 0.99, "large mean should exceed threshold");

        let far_below = compute_exceedance_prob(1.0, 1.0, 50.0);
        assert!(far_below < 0.01, "small mean should not exceed threshold");

        let at_two_sigma = compute_exceedance_prob(0.0, 1.0, 2.0);
        assert!(
            (at_two_sigma - 0.0455).abs() < 0.01,
            "2σ threshold should have ~4.5% exceedance"
        );
    }
}