1extern crate alloc;
40
41use alloc::vec::Vec;
42use core::f64::consts::PI;
43
44use nalgebra::Cholesky;
45use rand::prelude::*;
46use rand::SeedableRng;
47use rand_xoshiro::Xoshiro256PlusPlus;
48
49use crate::constants::{B_TAIL, ONES};
50use crate::math;
51use crate::types::{Matrix2, Matrix9, Matrix9x2, Vector2, Vector9};
52
/// Number of Monte Carlo draws taken by [`compute_max_effect_ci`]; also the
/// divisor for its mean and the basis for its quantile indices.
const N_MONTE_CARLO: usize = 1000;
55
/// Result of Bayesian inference on a 9-element difference vector.
///
/// When built by [`compute_bayes_gibbs`], every field except `sigma_n` and
/// `is_clamped` is copied verbatim from the same-named field returned by
/// `run_gibbs_inference`; the precise semantics of those fields live in the
/// `gibbs` module.
#[derive(Debug, Clone)]
pub struct BayesResult {
    /// Leak probability reported by the Gibbs run (presumably a posterior
    /// probability that the effect exceeds `theta` — confirm in `gibbs`).
    pub leak_probability: f64,

    /// 9-element effect vector estimated by the sampler (presumably the
    /// posterior mean — confirm in `gibbs`).
    pub delta_post: Vector9,

    /// 9×9 matrix paired with `delta_post` (presumably its posterior
    /// covariance — confirm in `gibbs`).
    pub lambda_post: Matrix9,

    /// Coefficients on the two columns of `build_design_matrix` (column 0 is
    /// `ONES`, column 1 is `B_TAIL`); presumably computed as in
    /// `compute_2d_projection` — confirm.
    pub beta_proj: Vector2,

    /// 2×2 covariance associated with `beta_proj`.
    pub beta_proj_cov: Matrix2,

    /// Per-sample 2-D coefficient vectors retained from the Gibbs run
    /// (NOTE(review): count and thinning are decided in `gibbs`).
    pub beta_draws: Vec<Vector2>,

    /// Mismatch statistic of the 2-D projection; `compute_2d_projection`
    /// computes the analogous value as `rᵀ Σ⁻¹ r` for the fit residual `r` —
    /// presumably the same definition here, confirm.
    pub projection_mismatch_q: f64,

    /// `(lower, upper)` interval on the effect magnitude reported by the
    /// sampler; the coverage level is not visible here — confirm in `gibbs`.
    pub effect_magnitude_ci: (f64, f64),

    /// Clamping flag; always `false` when produced by `compute_bayes_gibbs`.
    pub is_clamped: bool,

    /// Noise covariance after `add_jitter` regularization, i.e. exactly the
    /// matrix that was passed to the sampler.
    pub sigma_n: Matrix9,

    /// Mean of the sampled scalar λ chain (as reported by `gibbs`).
    pub lambda_mean: f64,

    /// Standard deviation of the sampled λ chain.
    pub lambda_sd: f64,

    /// Coefficient of variation of the λ chain (presumably `sd / mean` — confirm).
    pub lambda_cv: f64,

    /// Effective sample size of the λ chain.
    pub lambda_ess: f64,

    /// Mixing-diagnostic flag for the λ chain; criterion defined in `gibbs`.
    pub lambda_mixing_ok: bool,

    /// Mean of the sampled scalar κ chain (as reported by `gibbs`).
    pub kappa_mean: f64,

    /// Standard deviation of the sampled κ chain.
    pub kappa_sd: f64,

    /// Coefficient of variation of the κ chain (presumably `sd / mean` — confirm).
    pub kappa_cv: f64,

    /// Effective sample size of the κ chain.
    pub kappa_ess: f64,

    /// Mixing-diagnostic flag for the κ chain; criterion defined in `gibbs`.
    pub kappa_mixing_ok: bool,
}
130
131pub fn compute_bayes_gibbs(
150 delta: &Vector9,
151 sigma_n: &Matrix9,
152 sigma_t: f64,
153 l_r: &Matrix9,
154 theta: f64,
155 seed: Option<u64>,
156) -> BayesResult {
157 use super::gibbs::run_gibbs_inference;
158
159 let regularized = add_jitter(*sigma_n);
160 let actual_seed = seed.unwrap_or(crate::constants::DEFAULT_SEED);
161
162 let gibbs_result = run_gibbs_inference(delta, ®ularized, sigma_t, l_r, theta, actual_seed);
164
165 BayesResult {
166 leak_probability: gibbs_result.leak_probability,
167 delta_post: gibbs_result.delta_post,
168 lambda_post: gibbs_result.lambda_post,
169 beta_proj: gibbs_result.beta_proj,
170 beta_proj_cov: gibbs_result.beta_proj_cov,
171 beta_draws: gibbs_result.beta_draws,
172 projection_mismatch_q: gibbs_result.projection_mismatch_q,
173 effect_magnitude_ci: gibbs_result.effect_magnitude_ci,
174 is_clamped: false,
175 sigma_n: regularized,
176 lambda_mean: gibbs_result.lambda_mean,
178 lambda_sd: gibbs_result.lambda_sd,
179 lambda_cv: gibbs_result.lambda_cv,
180 lambda_ess: gibbs_result.lambda_ess,
181 lambda_mixing_ok: gibbs_result.lambda_mixing_ok,
182 kappa_mean: gibbs_result.kappa_mean,
184 kappa_sd: gibbs_result.kappa_sd,
185 kappa_cv: gibbs_result.kappa_cv,
186 kappa_ess: gibbs_result.kappa_ess,
187 kappa_mixing_ok: gibbs_result.kappa_mixing_ok,
188 }
189}
190
191pub fn compute_2d_projection(
198 delta_post: &Vector9,
199 lambda_post: &Matrix9,
200 sigma_n: &Matrix9,
201) -> (Vector2, Matrix2, f64) {
202 let design = build_design_matrix();
203
204 let sigma_n_chol = match Cholesky::new(*sigma_n) {
206 Some(c) => c,
207 None => {
208 return (Vector2::zeros(), Matrix2::identity() * 1e6, 0.0);
210 }
211 };
212
213 let mut sigma_n_inv_x = Matrix9x2::zeros();
215 for j in 0..2 {
216 let col = design.column(j).into_owned();
217 let solved = sigma_n_chol.solve(&col);
218 for i in 0..9 {
219 sigma_n_inv_x[(i, j)] = solved[i];
220 }
221 }
222
223 let xt_sigma_n_inv_x = design.transpose() * sigma_n_inv_x;
225
226 let xt_chol = match Cholesky::new(xt_sigma_n_inv_x) {
227 Some(c) => c,
228 None => {
229 return (Vector2::zeros(), Matrix2::identity() * 1e6, 0.0);
230 }
231 };
232
233 let sigma_n_inv_delta = sigma_n_chol.solve(delta_post);
235 let xt_sigma_n_inv_delta = design.transpose() * sigma_n_inv_delta;
236
237 let beta_proj = xt_chol.solve(&xt_sigma_n_inv_delta);
239
240 let a_matrix = xt_chol.inverse() * sigma_n_inv_x.transpose();
246 let beta_proj_cov = a_matrix * lambda_post * a_matrix.transpose();
247
248 let delta_proj = design * beta_proj;
250 let r_proj = delta_post - delta_proj;
251 let sigma_n_inv_r = sigma_n_chol.solve(&r_proj);
252 let q_proj = r_proj.dot(&sigma_n_inv_r);
253
254 (beta_proj, beta_proj_cov, q_proj)
255}
256
257fn sample_standard_normal<R: Rng>(rng: &mut R) -> f64 {
259 let u1: f64 = rng.random();
260 let u2: f64 = rng.random();
261 math::sqrt(-2.0 * math::ln(u1)) * math::cos(2.0 * PI * u2)
262}
263
/// Monte Carlo summary of the maximum absolute effect `max_k |δ_k|`, as
/// produced by [`compute_max_effect_ci`].
#[derive(Debug, Clone)]
pub struct MaxEffectCI {
    /// Sample mean of the per-draw maximum absolute effect.
    pub mean: f64,

    /// `(lower, upper)` bounds: the sorted per-draw maxima at the 2.5% and
    /// 97.5% positions. `mean` and both bounds are `0.0` when the covariance
    /// could not be factorized.
    pub ci: (f64, f64),
}
272
273pub fn compute_max_effect_ci(
277 delta_post: &Vector9,
278 lambda_post: &Matrix9,
279 seed: u64,
280) -> MaxEffectCI {
281 let chol = match Cholesky::new(*lambda_post) {
282 Some(c) => c,
283 None => {
284 let jittered = add_jitter(*lambda_post);
285 match Cholesky::new(jittered) {
286 Some(c) => c,
287 None => {
288 return MaxEffectCI {
289 mean: 0.0,
290 ci: (0.0, 0.0),
291 };
292 }
293 }
294 }
295 };
296 let l = chol.l();
297
298 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
299 let mut max_effects = Vec::with_capacity(N_MONTE_CARLO);
300 let mut sum = 0.0;
301
302 for _ in 0..N_MONTE_CARLO {
303 let mut z = Vector9::zeros();
304 for i in 0..9 {
305 z[i] = sample_standard_normal(&mut rng);
306 }
307
308 let delta_sample = delta_post + l * z;
309 let max_effect = delta_sample.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
310 max_effects.push(max_effect);
311 sum += max_effect;
312 }
313
314 let mean = sum / N_MONTE_CARLO as f64;
315
316 max_effects.sort_by(|a, b| a.total_cmp(b));
317 let lo_idx = math::round((N_MONTE_CARLO as f64) * 0.025) as usize;
318 let hi_idx = math::round((N_MONTE_CARLO as f64) * 0.975) as usize;
319 let ci = (
320 max_effects[lo_idx.min(N_MONTE_CARLO - 1)],
321 max_effects[hi_idx.min(N_MONTE_CARLO - 1)],
322 );
323
324 MaxEffectCI { mean, ci }
325}
326
327pub fn compute_quantile_exceedances(
331 delta_post: &Vector9,
332 lambda_post: &Matrix9,
333 theta: f64,
334) -> [f64; 9] {
335 let mut exceedances = [0.0; 9];
336 for k in 0..9 {
337 let mu = delta_post[k];
338 let sigma = math::sqrt(lambda_post[(k, k)].max(1e-12));
339 exceedances[k] = compute_single_quantile_exceedance(mu, sigma, theta);
340 }
341 exceedances
342}
343
344fn compute_single_quantile_exceedance(mu: f64, sigma: f64, theta: f64) -> f64 {
346 if sigma < 1e-12 {
347 return if mu.abs() > theta { 1.0 } else { 0.0 };
349 }
350 let phi_upper = math::normal_cdf((theta - mu) / sigma);
351 let phi_lower = math::normal_cdf((-theta - mu) / sigma);
352 1.0 - (phi_upper - phi_lower)
353}
354
355pub fn build_design_matrix() -> Matrix9x2 {
361 let mut x = Matrix9x2::zeros();
362 for i in 0..9 {
363 x[(i, 0)] = ONES[i];
364 x[(i, 1)] = B_TAIL[i];
365 }
366 x
367}
368
369fn add_jitter(mut sigma: Matrix9) -> Matrix9 {
373 let trace: f64 = (0..9).map(|i| sigma[(i, i)]).sum();
374 let mean_var = trace / 9.0;
375
376 let min_var = (0.01 * mean_var).max(1e-10);
377 let jitter = 1e-10 + mean_var * 1e-8;
378
379 for i in 0..9 {
380 sigma[(i, i)] = sigma[(i, i)].max(min_var) + jitter;
381 }
382 sigma
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_2d_projection_uniform_shift() {
        // The same offset on every component should land entirely on the
        // shift column, with ~no tail coefficient and a low mismatch.
        let delta_post = Vector9::from_element(50.0);
        let unit = Matrix9::identity();

        let (beta_proj, _, q_proj) = compute_2d_projection(&delta_post, &unit, &unit);
        let (shift, tail) = (beta_proj[0], beta_proj[1]);

        assert!((shift - 50.0).abs() < 1.0, "Shift should be ~50, got {}", shift);
        assert!(tail.abs() < 5.0, "Tail should be ~0, got {}", tail);
        assert!(q_proj < 1.0, "Uniform shift should have low Q, got {}", q_proj);
    }

    #[test]
    fn test_quantile_exceedance_computation() {
        // Mean sits five standard deviations above the threshold.
        let exceedance = compute_single_quantile_exceedance(100.0, 10.0, 50.0);

        assert!(
            exceedance > 0.99,
            "With μ=100, θ=50, exceedance should be ~1.0, got {}",
            exceedance
        );
    }

    #[test]
    fn test_quantile_exceedance_symmetric() {
        let (sigma, theta) = (10.0, 50.0);
        let [exc_pos, exc_neg] =
            [30.0, -30.0].map(|mu| compute_single_quantile_exceedance(mu, sigma, theta));

        assert!(
            (exc_pos - exc_neg).abs() < 0.01,
            "Exceedance should be symmetric, got {} vs {}",
            exc_pos,
            exc_neg
        );
    }

    #[test]
    fn test_gibbs_determinism() {
        use crate::adaptive::calibrate_t_prior_scale;

        // 5.0, 6.0, ..., 13.0
        let delta = Vector9::from_fn(|i, _| 5.0 + i as f64);
        let sigma_n = Matrix9::identity() * 100.0;
        let theta = 10.0;

        let (sigma_t, l_r) = calibrate_t_prior_scale(&(sigma_n * 1000.0), theta, 1000, false, 42);

        let run = || compute_bayes_gibbs(&delta, &sigma_n, sigma_t, &l_r, theta, Some(42));
        let (result1, result2) = (run(), run());

        assert_eq!(
            result1.leak_probability, result2.leak_probability,
            "Same seed should give same result"
        );
    }
}