1extern crate alloc;
36
37use alloc::vec::Vec;
38use core::f64::consts::PI;
39
40use nalgebra::Cholesky;
41use rand::prelude::*;
42use rand::SeedableRng;
43use rand_xoshiro::Xoshiro256PlusPlus;
44
45use crate::math;
46use crate::types::{Matrix9, Vector9};
47
/// Number of Monte Carlo draws taken by [`compute_max_effect_ci`] when
/// estimating the distribution of the maximum absolute effect.
const N_MONTE_CARLO: usize = 1000;
50
/// Output of the Bayesian inference step, as assembled by [`compute_bayes_gibbs`]
/// from the Gibbs sampler's result.
#[derive(Debug, Clone)]
pub struct BayesResult {
    /// Leak probability reported by the Gibbs sampler.
    pub leak_probability: f64,

    /// Posterior estimate of the 9-dimensional effect vector reported by the
    /// sampler (presumably the posterior mean — TODO confirm against `gibbs`).
    pub delta_post: Vector9,

    /// Posterior matrix reported alongside `delta_post`; `compute_max_effect_ci`
    /// treats it as the covariance of `delta_post` when sampling.
    pub lambda_post: Matrix9,

    /// Effect-vector draws retained by the sampler.
    pub delta_draws: Vec<Vector9>,

    /// `(lower, upper)` interval for the effect magnitude, as reported by the sampler.
    pub effect_magnitude_ci: (f64, f64),

    /// Clamping flag; `compute_bayes_gibbs` always sets this to `false`.
    pub is_clamped: bool,

    /// The regularized covariance (output of `add_jitter`) that was actually
    /// passed to the sampler.
    pub sigma_n: Matrix9,

    /// Summary of the sampler's λ draws: mean.
    pub lambda_mean: f64,

    /// Summary of the sampler's λ draws: standard deviation (as reported by the sampler).
    pub lambda_sd: f64,

    /// Summary of the sampler's λ draws; presumably the coefficient of variation
    /// (`sd / mean`) — confirm against `gibbs`.
    pub lambda_cv: f64,

    /// Sampler diagnostic for λ; presumably the effective sample size — confirm.
    pub lambda_ess: f64,

    /// Sampler's mixing diagnostic flag for λ.
    pub lambda_mixing_ok: bool,

    /// Summary of the sampler's κ draws: mean.
    pub kappa_mean: f64,

    /// Summary of the sampler's κ draws: standard deviation (as reported by the sampler).
    pub kappa_sd: f64,

    /// Summary of the sampler's κ draws; presumably the coefficient of variation — confirm.
    pub kappa_cv: f64,

    /// Sampler diagnostic for κ; presumably the effective sample size — confirm.
    pub kappa_ess: f64,

    /// Sampler's mixing diagnostic flag for κ.
    pub kappa_mixing_ok: bool,
}
113
114pub fn compute_bayes_gibbs(
133 delta: &Vector9,
134 sigma_n: &Matrix9,
135 sigma_t: f64,
136 l_r: &Matrix9,
137 theta: f64,
138 seed: Option<u64>,
139) -> BayesResult {
140 use super::gibbs::run_gibbs_inference;
141
142 let regularized = add_jitter(*sigma_n);
143 let actual_seed = seed.unwrap_or(crate::constants::DEFAULT_SEED);
144
145 let gibbs_result = run_gibbs_inference(delta, ®ularized, sigma_t, l_r, theta, actual_seed);
147
148 BayesResult {
149 leak_probability: gibbs_result.leak_probability,
150 delta_post: gibbs_result.delta_post,
151 lambda_post: gibbs_result.lambda_post,
152 delta_draws: gibbs_result.delta_draws,
153 effect_magnitude_ci: gibbs_result.effect_magnitude_ci,
154 is_clamped: false,
155 sigma_n: regularized,
156 lambda_mean: gibbs_result.lambda_mean,
158 lambda_sd: gibbs_result.lambda_sd,
159 lambda_cv: gibbs_result.lambda_cv,
160 lambda_ess: gibbs_result.lambda_ess,
161 lambda_mixing_ok: gibbs_result.lambda_mixing_ok,
162 kappa_mean: gibbs_result.kappa_mean,
164 kappa_sd: gibbs_result.kappa_sd,
165 kappa_cv: gibbs_result.kappa_cv,
166 kappa_ess: gibbs_result.kappa_ess,
167 kappa_mixing_ok: gibbs_result.kappa_mixing_ok,
168 }
169}
170
171fn sample_standard_normal<R: Rng>(rng: &mut R) -> f64 {
173 let u1: f64 = rng.random();
174 let u2: f64 = rng.random();
175 math::sqrt(-2.0 * math::ln(u1)) * math::cos(2.0 * PI * u2)
176}
177
/// Monte Carlo summary of the largest absolute effect component,
/// `max_i |delta_i|`, as produced by [`compute_max_effect_ci`].
#[derive(Debug, Clone)]
pub struct MaxEffectCI {
    /// Sample mean of `max_i |delta_i|` over the Monte Carlo draws
    /// (`0.0` when the covariance could not be factorized).
    pub mean: f64,

    /// `(lower, upper)` empirical ~95% interval of `max_i |delta_i|`
    /// (`(0.0, 0.0)` when the covariance could not be factorized).
    pub ci: (f64, f64),
}
186
187pub fn compute_max_effect_ci(
191 delta_post: &Vector9,
192 lambda_post: &Matrix9,
193 seed: u64,
194) -> MaxEffectCI {
195 let chol = match Cholesky::new(*lambda_post) {
196 Some(c) => c,
197 None => {
198 let jittered = add_jitter(*lambda_post);
199 match Cholesky::new(jittered) {
200 Some(c) => c,
201 None => {
202 return MaxEffectCI {
203 mean: 0.0,
204 ci: (0.0, 0.0),
205 };
206 }
207 }
208 }
209 };
210 let l = chol.l();
211
212 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
213 let mut max_effects = Vec::with_capacity(N_MONTE_CARLO);
214 let mut sum = 0.0;
215
216 for _ in 0..N_MONTE_CARLO {
217 let mut z = Vector9::zeros();
218 for i in 0..9 {
219 z[i] = sample_standard_normal(&mut rng);
220 }
221
222 let delta_sample = delta_post + l * z;
223 let max_effect = delta_sample.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
224 max_effects.push(max_effect);
225 sum += max_effect;
226 }
227
228 let mean = sum / N_MONTE_CARLO as f64;
229
230 max_effects.sort_by(|a, b| a.total_cmp(b));
231 let lo_idx = math::round((N_MONTE_CARLO as f64) * 0.025) as usize;
232 let hi_idx = math::round((N_MONTE_CARLO as f64) * 0.975) as usize;
233 let ci = (
234 max_effects[lo_idx.min(N_MONTE_CARLO - 1)],
235 max_effects[hi_idx.min(N_MONTE_CARLO - 1)],
236 );
237
238 MaxEffectCI { mean, ci }
239}
240
241pub fn add_jitter(mut sigma: Matrix9) -> Matrix9 {
245 let trace: f64 = (0..9).map(|i| sigma[(i, i)]).sum();
246 let mean_var = trace / 9.0;
247
248 let min_var = (0.01 * mean_var).max(1e-10);
249 let jitter = 1e-10 + mean_var * 1e-8;
250
251 for i in 0..9 {
252 sigma[(i, i)] = sigma[(i, i)].max(min_var) + jitter;
253 }
254 sigma
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_effect_ci_basic() {
        let delta_post = Vector9::from_row_slice(&[10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let lambda_post = Matrix9::identity() * 1.0;

        let ci = compute_max_effect_ci(&delta_post, &lambda_post, 42);

        assert!(ci.mean > 8.0, "mean should be around 10, got {}", ci.mean);
        assert!(ci.ci.0 < ci.mean, "CI lower should be below mean");
        assert!(ci.ci.1 > ci.mean, "CI upper should be above mean");
    }

    #[test]
    fn test_max_effect_ci_deterministic_and_ordered() {
        let delta_post =
            Vector9::from_row_slice(&[1.0, -2.0, 3.0, 0.5, 0.0, -1.5, 2.5, 0.25, -0.75]);
        let lambda_post = Matrix9::identity() * 0.5;

        let a = compute_max_effect_ci(&delta_post, &lambda_post, 7);
        let b = compute_max_effect_ci(&delta_post, &lambda_post, 7);

        assert_eq!(a.mean, b.mean, "same seed should give same mean");
        assert_eq!(a.ci, b.ci, "same seed should give same interval");
        assert!(a.ci.0 <= a.ci.1, "interval bounds out of order: {:?}", a.ci);
    }

    #[test]
    fn test_add_jitter_inflates_diagonal_only() {
        let jittered = add_jitter(Matrix9::identity());

        for i in 0..9 {
            for j in 0..9 {
                if i == j {
                    assert!(jittered[(i, j)] > 1.0, "diagonal entry {} not inflated", i);
                } else {
                    assert_eq!(jittered[(i, j)], 0.0, "off-diagonal ({}, {}) changed", i, j);
                }
            }
        }
        assert!(Cholesky::new(jittered).is_some());
    }

    #[test]
    fn test_gibbs_determinism() {
        use crate::adaptive::calibrate_t_prior_scale;

        let delta = Vector9::from_row_slice(&[5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0]);
        let sigma_n = Matrix9::identity() * 100.0;
        let sigma_rate = sigma_n * 1000.0;
        let theta = 10.0;

        let (sigma_t, l_r) = calibrate_t_prior_scale(&sigma_rate, theta, 1000, false, 42);

        let result1 = compute_bayes_gibbs(&delta, &sigma_n, sigma_t, &l_r, theta, Some(42));
        let result2 = compute_bayes_gibbs(&delta, &sigma_n, sigma_t, &l_r, theta, Some(42));

        assert_eq!(
            result1.leak_probability, result2.leak_probability,
            "Same seed should give same result"
        );
        assert_eq!(
            result1.delta_post, result2.delta_post,
            "Same seed should give same posterior estimate"
        );
        assert_eq!(
            result1.effect_magnitude_ci, result2.effect_magnitude_ci,
            "Same seed should give same effect-magnitude interval"
        );
        assert_eq!(
            result1.delta_draws.len(),
            result2.delta_draws.len(),
            "Same seed should retain the same number of draws"
        );
    }
}