1extern crate alloc;
40
41use alloc::vec::Vec;
42use core::f64::consts::PI;
43
44use nalgebra::Cholesky;
45use rand::prelude::*;
46use rand::SeedableRng;
47use rand_xoshiro::Xoshiro256PlusPlus;
48
49use crate::constants::{B_TAIL, ONES};
50use crate::math;
51use crate::types::{Matrix2, Matrix9, Matrix9x2, Vector2, Vector9};
52
/// Number of posterior draws used by `compute_max_effect_ci`.
const N_MONTE_CARLO: usize = 1_000;
55
/// Output of [`compute_bayes_gibbs`]: posterior summaries of the 9-element effect
/// vector, its 2-D projection, and diagnostics copied from the Gibbs sampler.
#[derive(Debug, Clone)]
pub struct BayesResult {
    /// Leak probability as reported by the Gibbs sampler.
    pub leak_probability: f64,

    /// Posterior effect vector (this module treats it as the mean of the
    /// Gaussian `N(delta_post, lambda_post)`, e.g. in `compute_max_effect_ci`).
    pub delta_post: Vector9,

    /// Posterior covariance paired with `delta_post`.
    pub lambda_post: Matrix9,

    /// Coefficients of the 2-D projection onto the design matrix from
    /// `build_design_matrix` (column 0 = uniform shift, column 1 = tail).
    pub beta_proj: Vector2,

    /// Covariance of `beta_proj`, as reported by the sampler.
    pub beta_proj_cov: Matrix2,

    /// Sampler draws of the 2-D projection coefficients
    /// (presumably one per retained Gibbs iteration — TODO confirm in `gibbs`).
    pub beta_draws: Vec<Vector2>,

    /// Projection mismatch statistic reported by the sampler (presumably the
    /// Σₙ-weighted squared residual, like `q_proj` in `compute_2d_projection`
    /// — TODO confirm).
    pub projection_mismatch_q: f64,

    /// `(lower, upper)` interval for the effect magnitude as reported by the
    /// sampler (credible level not visible here — TODO confirm).
    pub effect_magnitude_ci: (f64, f64),

    /// Clamping flag; always `false` when produced by `compute_bayes_gibbs`.
    pub is_clamped: bool,

    /// Noise covariance after `add_jitter` regularisation, i.e. the matrix that
    /// was actually passed to the sampler.
    pub sigma_n: Matrix9,

    /// Sampler mean of the λ parameter.
    pub lambda_mean: f64,

    /// Sampler standard deviation of λ.
    pub lambda_sd: f64,

    /// Coefficient of variation of λ (presumably `lambda_sd / lambda_mean`).
    pub lambda_cv: f64,

    /// Presumably the effective sample size of the λ chain — TODO confirm.
    pub lambda_ess: f64,

    /// Sampler's verdict on whether the λ chain mixed acceptably.
    pub lambda_mixing_ok: bool,

    /// Sampler mean of the κ parameter.
    pub kappa_mean: f64,

    /// Sampler standard deviation of κ.
    pub kappa_sd: f64,

    /// Coefficient of variation of κ (presumably `kappa_sd / kappa_mean`).
    pub kappa_cv: f64,

    /// Presumably the effective sample size of the κ chain — TODO confirm.
    pub kappa_ess: f64,

    /// Sampler's verdict on whether the κ chain mixed acceptably.
    pub kappa_mixing_ok: bool,
}
132
133pub fn compute_bayes_gibbs(
152 delta: &Vector9,
153 sigma_n: &Matrix9,
154 sigma_t: f64,
155 l_r: &Matrix9,
156 theta: f64,
157 seed: Option<u64>,
158) -> BayesResult {
159 use super::gibbs::run_gibbs_inference;
160
161 let regularized = add_jitter(*sigma_n);
162 let actual_seed = seed.unwrap_or(crate::constants::DEFAULT_SEED);
163
164 let gibbs_result = run_gibbs_inference(delta, ®ularized, sigma_t, l_r, theta, actual_seed);
166
167 BayesResult {
168 leak_probability: gibbs_result.leak_probability,
169 delta_post: gibbs_result.delta_post,
170 lambda_post: gibbs_result.lambda_post,
171 beta_proj: gibbs_result.beta_proj,
172 beta_proj_cov: gibbs_result.beta_proj_cov,
173 beta_draws: gibbs_result.beta_draws,
174 projection_mismatch_q: gibbs_result.projection_mismatch_q,
175 effect_magnitude_ci: gibbs_result.effect_magnitude_ci,
176 is_clamped: false,
177 sigma_n: regularized,
178 lambda_mean: gibbs_result.lambda_mean,
180 lambda_sd: gibbs_result.lambda_sd,
181 lambda_cv: gibbs_result.lambda_cv,
182 lambda_ess: gibbs_result.lambda_ess,
183 lambda_mixing_ok: gibbs_result.lambda_mixing_ok,
184 kappa_mean: gibbs_result.kappa_mean,
186 kappa_sd: gibbs_result.kappa_sd,
187 kappa_cv: gibbs_result.kappa_cv,
188 kappa_ess: gibbs_result.kappa_ess,
189 kappa_mixing_ok: gibbs_result.kappa_mixing_ok,
190 }
191}
192
193pub fn compute_2d_projection(
200 delta_post: &Vector9,
201 lambda_post: &Matrix9,
202 sigma_n: &Matrix9,
203) -> (Vector2, Matrix2, f64) {
204 let design = build_design_matrix();
205
206 let sigma_n_chol = match Cholesky::new(*sigma_n) {
208 Some(c) => c,
209 None => {
210 return (Vector2::zeros(), Matrix2::identity() * 1e6, 0.0);
212 }
213 };
214
215 let mut sigma_n_inv_x = Matrix9x2::zeros();
217 for j in 0..2 {
218 let col = design.column(j).into_owned();
219 let solved = sigma_n_chol.solve(&col);
220 for i in 0..9 {
221 sigma_n_inv_x[(i, j)] = solved[i];
222 }
223 }
224
225 let xt_sigma_n_inv_x = design.transpose() * sigma_n_inv_x;
227
228 let xt_chol = match Cholesky::new(xt_sigma_n_inv_x) {
229 Some(c) => c,
230 None => {
231 return (Vector2::zeros(), Matrix2::identity() * 1e6, 0.0);
232 }
233 };
234
235 let sigma_n_inv_delta = sigma_n_chol.solve(delta_post);
237 let xt_sigma_n_inv_delta = design.transpose() * sigma_n_inv_delta;
238
239 let beta_proj = xt_chol.solve(&xt_sigma_n_inv_delta);
241
242 let a_matrix = xt_chol.inverse() * sigma_n_inv_x.transpose();
248 let beta_proj_cov = a_matrix * lambda_post * a_matrix.transpose();
249
250 let delta_proj = design * beta_proj;
252 let r_proj = delta_post - delta_proj;
253 let sigma_n_inv_r = sigma_n_chol.solve(&r_proj);
254 let q_proj = r_proj.dot(&sigma_n_inv_r);
255
256 (beta_proj, beta_proj_cov, q_proj)
257}
258
259fn sample_standard_normal<R: Rng>(rng: &mut R) -> f64 {
261 let u1: f64 = rng.random();
262 let u2: f64 = rng.random();
263 math::sqrt(-2.0 * math::ln(u1)) * math::cos(2.0 * PI * u2)
264}
265
/// Monte Carlo summary of `max_k |δ_k|` produced by [`compute_max_effect_ci`].
#[derive(Debug, Clone)]
pub struct MaxEffectCI {
    /// Sample mean of the per-draw maximum absolute effect.
    pub mean: f64,

    /// `(lower, upper)` interval taken from the sorted draws at the 2.5% and
    /// 97.5% positions.
    pub ci: (f64, f64),
}
274
275pub fn compute_max_effect_ci(
279 delta_post: &Vector9,
280 lambda_post: &Matrix9,
281 seed: u64,
282) -> MaxEffectCI {
283 let chol = match Cholesky::new(*lambda_post) {
284 Some(c) => c,
285 None => {
286 let jittered = add_jitter(*lambda_post);
287 match Cholesky::new(jittered) {
288 Some(c) => c,
289 None => {
290 return MaxEffectCI {
291 mean: 0.0,
292 ci: (0.0, 0.0),
293 };
294 }
295 }
296 }
297 };
298 let l = chol.l();
299
300 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
301 let mut max_effects = Vec::with_capacity(N_MONTE_CARLO);
302 let mut sum = 0.0;
303
304 for _ in 0..N_MONTE_CARLO {
305 let mut z = Vector9::zeros();
306 for i in 0..9 {
307 z[i] = sample_standard_normal(&mut rng);
308 }
309
310 let delta_sample = delta_post + l * z;
311 let max_effect = delta_sample.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
312 max_effects.push(max_effect);
313 sum += max_effect;
314 }
315
316 let mean = sum / N_MONTE_CARLO as f64;
317
318 max_effects.sort_by(|a, b| a.total_cmp(b));
319 let lo_idx = math::round((N_MONTE_CARLO as f64) * 0.025) as usize;
320 let hi_idx = math::round((N_MONTE_CARLO as f64) * 0.975) as usize;
321 let ci = (
322 max_effects[lo_idx.min(N_MONTE_CARLO - 1)],
323 max_effects[hi_idx.min(N_MONTE_CARLO - 1)],
324 );
325
326 MaxEffectCI { mean, ci }
327}
328
329pub fn compute_quantile_exceedances(
333 delta_post: &Vector9,
334 lambda_post: &Matrix9,
335 theta: f64,
336) -> [f64; 9] {
337 let mut exceedances = [0.0; 9];
338 for k in 0..9 {
339 let mu = delta_post[k];
340 let sigma = math::sqrt(lambda_post[(k, k)].max(1e-12));
341 exceedances[k] = compute_single_quantile_exceedance(mu, sigma, theta);
342 }
343 exceedances
344}
345
346fn compute_single_quantile_exceedance(mu: f64, sigma: f64, theta: f64) -> f64 {
348 if sigma < 1e-12 {
349 return if mu.abs() > theta { 1.0 } else { 0.0 };
351 }
352 let phi_upper = math::normal_cdf((theta - mu) / sigma);
353 let phi_lower = math::normal_cdf((-theta - mu) / sigma);
354 1.0 - (phi_upper - phi_lower)
355}
356
357pub fn build_design_matrix() -> Matrix9x2 {
363 let mut x = Matrix9x2::zeros();
364 for i in 0..9 {
365 x[(i, 0)] = ONES[i];
366 x[(i, 1)] = B_TAIL[i];
367 }
368 x
369}
370
371fn add_jitter(mut sigma: Matrix9) -> Matrix9 {
375 let trace: f64 = (0..9).map(|i| sigma[(i, i)]).sum();
376 let mean_var = trace / 9.0;
377
378 let min_var = (0.01 * mean_var).max(1e-10);
379 let jitter = 1e-10 + mean_var * 1e-8;
380
381 for i in 0..9 {
382 sigma[(i, i)] = sigma[(i, i)].max(min_var) + jitter;
383 }
384 sigma
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_2d_projection_uniform_shift() {
        // A constant effect on every coordinate should load almost entirely on
        // the shift coefficient, leaving the tail coefficient and Q near zero.
        let delta_post = Vector9::from_element(50.0);
        let lambda_post = Matrix9::identity();
        let sigma_n = Matrix9::identity();

        let (beta_proj, _, q_proj) = compute_2d_projection(&delta_post, &lambda_post, &sigma_n);
        let (shift, tail) = (beta_proj[0], beta_proj[1]);

        assert!((shift - 50.0).abs() < 1.0, "Shift should be ~50, got {}", shift);
        assert!(tail.abs() < 5.0, "Tail should be ~0, got {}", tail);
        assert!(q_proj < 1.0, "Uniform shift should have low Q, got {}", q_proj);
    }

    #[test]
    fn test_quantile_exceedance_computation() {
        // Mean five standard deviations beyond the threshold.
        let (mu, sigma, theta) = (100.0, 10.0, 50.0);

        let exceedance = compute_single_quantile_exceedance(mu, sigma, theta);

        assert!(
            exceedance > 0.99,
            "With μ=100, θ=50, exceedance should be ~1.0, got {}",
            exceedance
        );
    }

    #[test]
    fn test_quantile_exceedance_symmetric() {
        let (sigma, theta) = (10.0, 50.0);

        let [exc_pos, exc_neg] =
            [30.0, -30.0].map(|mu| compute_single_quantile_exceedance(mu, sigma, theta));

        assert!(
            (exc_pos - exc_neg).abs() < 0.01,
            "Exceedance should be symmetric, got {} vs {}",
            exc_pos,
            exc_neg
        );
    }

    #[test]
    fn test_gibbs_determinism() {
        use crate::adaptive::calibrate_t_prior_scale;

        let delta = Vector9::from_row_slice(&[5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0]);
        let sigma_n = Matrix9::identity() * 100.0;
        let sigma_rate = sigma_n * 1000.0;
        let theta = 10.0;

        let (sigma_t, l_r) = calibrate_t_prior_scale(&sigma_rate, theta, 1000, false, 42);

        // Two independent runs with the same explicit seed.
        let run = || compute_bayes_gibbs(&delta, &sigma_n, sigma_t, &l_r, theta, Some(42));
        let first = run();
        let second = run();

        assert_eq!(
            first.leak_probability, second.leak_probability,
            "Same seed should give same result"
        );
    }
}