Skip to main content

tacet_core/analysis/
effect.rs

1//! Effect estimation from posterior samples (spec §5.2).
2//!
3//! This module computes effect estimates from the 9D posterior over quantile
4//! differences. The primary metric is max_k |δ_k| - the maximum absolute
5//! effect across all deciles.
6//!
7//! ## Effect Reporting (spec §5.2)
8//!
9//! - `max_effect_ns`: Posterior mean of max_k |δ_k|
10//! - `credible_interval_ns`: 95% CI for max|δ|
11//! - `top_quantiles`: Top 2-3 quantiles by exceedance probability
12
13extern crate alloc;
14
15use alloc::vec::Vec;
16
17use crate::constants::DECILES;
18use crate::math;
19use crate::result::{EffectEstimate, TopQuantile};
20use crate::types::{Matrix9, Vector9};
21
/// Summary of a single decile, kept as a positional tuple while ranking.
///
/// Field order: (index, quantile_p, mean_ns, ci95_ns, exceed_prob).
type QuantileStats = (
    usize,      // index k of the decile within the 9D effect vector
    f64,        // quantile_p: the decile's probability level
    f64,        // mean_ns: mean of δ_k
    (f64, f64), // ci95_ns: (lower, upper) bounds of the 95% interval
    f64,        // exceed_prob: P(|δ_k| > θ)
);
24
25/// Compute effect estimate from delta draws (spec §5.2).
26///
27/// Takes posterior samples of the 9D effect vector δ and computes:
28/// - max_effect_ns: posterior mean of max_k |δ_k|
29/// - credible_interval_ns: 95% CI for max|δ|
30/// - top_quantiles: top 2-3 quantiles by exceedance probability
31///
32/// # Arguments
33///
34/// * `delta_draws` - Posterior samples of δ ∈ ℝ⁹
35/// * `theta` - Threshold for exceedance probability computation
36///
37/// # Returns
38///
39/// An `EffectEstimate` with max effect and top quantiles.
40pub fn compute_effect_estimate(delta_draws: &[Vector9], theta: f64) -> EffectEstimate {
41    if delta_draws.is_empty() {
42        return EffectEstimate::default();
43    }
44
45    let n = delta_draws.len();
46
47    // Compute max|δ| for each draw
48    let mut max_effects: Vec<f64> = Vec::with_capacity(n);
49    for delta in delta_draws {
50        let max_abs = delta.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
51        max_effects.push(max_abs);
52    }
53
54    // Posterior mean of max|δ|
55    let max_effect_ns = max_effects.iter().sum::<f64>() / n as f64;
56
57    // 95% credible interval (2.5th and 97.5th percentiles)
58    max_effects.sort_by(|a, b| a.total_cmp(b));
59    let lo_idx = ((n as f64 * 0.025).round() as usize).min(n - 1);
60    let hi_idx = ((n as f64 * 0.975).round() as usize).min(n - 1);
61    let credible_interval_ns = (max_effects[lo_idx], max_effects[hi_idx]);
62
63    // Compute top quantiles by exceedance probability
64    let top_quantiles = compute_top_quantiles(delta_draws, theta);
65
66    EffectEstimate {
67        max_effect_ns,
68        credible_interval_ns,
69        top_quantiles,
70    }
71}
72
73/// Compute top 2-3 quantiles by exceedance probability.
74///
75/// For each quantile k, computes:
76/// - mean_ns: posterior mean δ_k
77/// - ci95_ns: 95% marginal CI for δ_k
78/// - exceed_prob: P(|δ_k| > θ | data)
79///
80/// Returns the top quantiles (up to 3) with exceed_prob > 0.5.
81pub fn compute_top_quantiles(delta_draws: &[Vector9], theta: f64) -> Vec<TopQuantile> {
82    if delta_draws.is_empty() {
83        return Vec::new();
84    }
85
86    let n = delta_draws.len();
87
88    // Compute per-quantile statistics
89    let mut quantile_stats: Vec<QuantileStats> = Vec::with_capacity(9);
90
91    for k in 0..9 {
92        // Extract draws for quantile k
93        let mut values: Vec<f64> = delta_draws.iter().map(|d| d[k]).collect();
94
95        // Mean
96        let mean = values.iter().sum::<f64>() / n as f64;
97
98        // 95% CI
99        values.sort_by(|a, b| a.total_cmp(b));
100        let lo_idx = ((n as f64 * 0.025).round() as usize).min(n - 1);
101        let hi_idx = ((n as f64 * 0.975).round() as usize).min(n - 1);
102        let ci = (values[lo_idx], values[hi_idx]);
103
104        // Exceedance probability: P(|δ_k| > θ | data)
105        let exceed_count = delta_draws.iter().filter(|d| d[k].abs() > theta).count();
106        let exceed_prob = exceed_count as f64 / n as f64;
107
108        quantile_stats.push((k, DECILES[k], mean, ci, exceed_prob));
109    }
110
111    // Sort by exceedance probability (descending)
112    quantile_stats.sort_by(|a, b| b.4.total_cmp(&a.4));
113
114    // Take top 2-3 with exceed_prob > 0.5
115    quantile_stats
116        .into_iter()
117        .filter(|(_, _, _, _, exceed_prob)| *exceed_prob > 0.5)
118        .take(3)
119        .map(
120            |(_, quantile_p, mean_ns, ci95_ns, exceed_prob)| TopQuantile {
121                quantile_p,
122                mean_ns,
123                ci95_ns,
124                exceed_prob,
125            },
126        )
127        .collect()
128}
129
130/// Compute effect estimate from posterior mean and covariance (analytical).
131///
132/// This is a faster alternative to `compute_effect_estimate` when only the
133/// posterior mean and covariance are available (no draws).
134///
135/// # Arguments
136///
137/// * `delta_post` - Posterior mean δ_post
138/// * `lambda_post` - Posterior covariance Λ_post
139/// * `theta` - Threshold for exceedance probability
140///
141/// # Returns
142///
143/// An `EffectEstimate` with approximate max effect (uses mean of |δ_post|).
144pub fn compute_effect_estimate_analytical(
145    delta_post: &Vector9,
146    lambda_post: &Matrix9,
147    theta: f64,
148) -> EffectEstimate {
149    // Max absolute effect from posterior mean
150    let max_effect_ns = delta_post.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
151
152    // Approximate CI using marginal variances
153    // This is a rough approximation - the true CI requires sampling
154    let max_k = delta_post
155        .iter()
156        .enumerate()
157        .max_by(|(_, a), (_, b)| a.abs().total_cmp(&b.abs()))
158        .map(|(k, _)| k)
159        .unwrap_or(0);
160
161    let se = math::sqrt(lambda_post[(max_k, max_k)].max(1e-12));
162    let ci_low = (max_effect_ns - 1.96 * se).max(0.0);
163    let ci_high = max_effect_ns + 1.96 * se;
164
165    // Compute top quantiles analytically
166    let top_quantiles = compute_top_quantiles_analytical(delta_post, lambda_post, theta);
167
168    EffectEstimate {
169        max_effect_ns,
170        credible_interval_ns: (ci_low, ci_high),
171        top_quantiles,
172    }
173}
174
175/// Compute top quantiles analytically from posterior mean and covariance.
176fn compute_top_quantiles_analytical(
177    delta_post: &Vector9,
178    lambda_post: &Matrix9,
179    theta: f64,
180) -> Vec<TopQuantile> {
181    let mut quantile_stats: Vec<QuantileStats> = Vec::with_capacity(9);
182
183    for k in 0..9 {
184        let mean = delta_post[k];
185        let se = math::sqrt(lambda_post[(k, k)].max(1e-12));
186
187        // 95% CI
188        let ci = (mean - 1.96 * se, mean + 1.96 * se);
189
190        // Exceedance probability: P(|δ_k| > θ)
191        // For Gaussian N(μ, σ²): P(|X| > θ) = 1 - Φ((θ-μ)/σ) + Φ((-θ-μ)/σ)
192        let exceed_prob = compute_exceedance_prob(mean, se, theta);
193
194        quantile_stats.push((k, DECILES[k], mean, ci, exceed_prob));
195    }
196
197    // Sort by exceedance probability (descending)
198    quantile_stats.sort_by(|a, b| b.4.total_cmp(&a.4));
199
200    // Take top 2-3 with exceed_prob > 0.5
201    quantile_stats
202        .into_iter()
203        .filter(|(_, _, _, _, exceed_prob)| *exceed_prob > 0.5)
204        .take(3)
205        .map(
206            |(_, quantile_p, mean_ns, ci95_ns, exceed_prob)| TopQuantile {
207                quantile_p,
208                mean_ns,
209                ci95_ns,
210                exceed_prob,
211            },
212        )
213        .collect()
214}
215
216/// Compute P(|X| > θ) for X ~ N(μ, σ²).
217fn compute_exceedance_prob(mu: f64, sigma: f64, theta: f64) -> f64 {
218    if sigma < 1e-12 {
219        // Degenerate case
220        return if mu.abs() > theta { 1.0 } else { 0.0 };
221    }
222    let phi_upper = math::normal_cdf((theta - mu) / sigma);
223    let phi_lower = math::normal_cdf((-theta - mu) / sigma);
224    1.0 - (phi_upper - phi_lower)
225}
226
227/// Apply variance floor regularization for numerical stability (spec §3.3.2).
228///
229/// When some quantiles have zero or near-zero variance (common in discrete mode
230/// with ties), the covariance matrix becomes ill-conditioned.
231///
232/// Formula (spec §3.3.2):
233///   σ²ᵢ ← max(σ²ᵢ, 0.01 × σ̄²) + ε
234/// where σ̄² = tr(Σ)/9 and ε = 10⁻¹⁰ + σ̄² × 10⁻⁸
235pub fn regularize_covariance(sigma: &Matrix9) -> Matrix9 {
236    let trace: f64 = (0..9).map(|i| sigma[(i, i)]).sum();
237    let mean_var = trace / 9.0;
238
239    // Use 1% of mean variance as floor, with absolute minimum of 1e-10
240    let min_var = (0.01 * mean_var).max(1e-10);
241
242    // Also add small jitter proportional to scale for numerical stability
243    let jitter = 1e-10 + mean_var * 1e-8;
244
245    let mut regularized = *sigma;
246    for i in 0..9 {
247        regularized[(i, i)] = regularized[(i, i)].max(min_var) + jitter;
248    }
249    regularized
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a draw whose nine components all equal `value`.
    fn constant_draw(value: f64) -> Vector9 {
        Vector9::from_row_slice(&[value; 9])
    }

    #[test]
    fn test_effect_estimate_basic() {
        // Draw i has every component equal to 0.1·i, so max|δ| takes the values
        // 0.0, 0.1, …, 9.9 across the 100 draws.
        let draws: Vec<Vector9> = (0..100).map(|i| constant_draw(i as f64 * 0.1)).collect();

        let estimate = compute_effect_estimate(&draws, 5.0);

        // The reported effect is the posterior MEAN of max|δ|: 0.1 · 4950 / 100 = 4.95.
        assert!(
            (estimate.max_effect_ns - 4.95).abs() < 1e-9,
            "max effect should be the mean of the per-draw maxima"
        );

        let (ci_low, ci_high) = estimate.credible_interval_ns;
        assert!(
            ci_low < estimate.max_effect_ns,
            "CI lower should be below mean"
        );
        assert!(
            ci_high > estimate.max_effect_ns,
            "CI upper should be above mean"
        );
        assert!(
            ci_low >= 0.0 && ci_high <= 9.9 + 1e-9,
            "CI bounds should lie within the range of the draws"
        );
    }

    #[test]
    fn test_effect_estimate_empty() {
        let estimate = compute_effect_estimate(&[], 5.0);
        assert_eq!(estimate.max_effect_ns, 0.0);
        assert!(estimate.top_quantiles.is_empty());
    }

    #[test]
    fn test_top_quantiles_threshold() {
        // Create draws where only the 90th percentile exceeds threshold
        let draws: Vec<Vector9> = (0..100)
            .map(|_| Vector9::from_row_slice(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0]))
            .collect();

        let top = compute_top_quantiles(&draws, 5.0);

        // Exactly one entry: the 90th percentile (index 8, quantile_p = 0.9)
        assert_eq!(top.len(), 1);
        assert!((top[0].quantile_p - 0.9).abs() < 0.01);
        assert!(top[0].exceed_prob > 0.99);

        // Every draw is identical, so mean and interval collapse onto 10.0.
        assert_eq!(top[0].mean_ns, 10.0);
        assert_eq!(top[0].ci95_ns, (10.0, 10.0));
    }

    #[test]
    fn test_top_quantiles_empty() {
        assert!(compute_top_quantiles(&[], 5.0).is_empty());
    }

    #[test]
    fn test_effect_estimate_analytical() {
        // Posterior mean concentrated on the 90th percentile, unit marginal variances.
        let delta_post =
            Vector9::from_row_slice(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0]);
        let mut lambda_post = Matrix9::zeros();
        for i in 0..9 {
            lambda_post[(i, i)] = 1.0;
        }

        let estimate = compute_effect_estimate_analytical(&delta_post, &lambda_post, 5.0);

        assert_eq!(estimate.max_effect_ns, 10.0);

        // Interval is max ± 1.96·se with se = 1.
        let (ci_low, ci_high) = estimate.credible_interval_ns;
        assert!((ci_low - 8.04).abs() < 1e-6, "CI lower should be 10 - 1.96");
        assert!((ci_high - 11.96).abs() < 1e-6, "CI upper should be 10 + 1.96");

        // Only the 90th percentile has a meaningful exceedance probability.
        assert_eq!(estimate.top_quantiles.len(), 1);
        assert!((estimate.top_quantiles[0].quantile_p - 0.9).abs() < 0.01);
        assert!(estimate.top_quantiles[0].exceed_prob > 0.99);
    }

    #[test]
    fn test_regularize_covariance() {
        let mut sigma = Matrix9::zeros();
        for i in 0..9 {
            sigma[(i, i)] = if i == 0 { 0.0 } else { 1.0 }; // First diagonal is zero
        }

        let regularized = regularize_covariance(&sigma);

        // All diagonal elements should be positive
        for i in 0..9 {
            assert!(
                regularized[(i, i)] > 0.0,
                "diagonal {} should be positive",
                i
            );
        }

        // The zero entry is lifted to at least 1% of the mean variance (8/9).
        let mean_var = 8.0 / 9.0;
        assert!(
            regularized[(0, 0)] >= 0.01 * mean_var,
            "zero variance should be raised to the floor"
        );

        // Entries already above the floor only gain the (positive) jitter.
        for i in 1..9 {
            assert!(
                regularized[(i, i)] >= 1.0 && regularized[(i, i)] < 1.0 + 1e-6,
                "diagonal {} should stay close to its original value",
                i
            );
        }

        // Off-diagonal entries are left untouched.
        assert_eq!(regularized[(0, 1)], 0.0);
        assert_eq!(regularized[(3, 7)], 0.0);
    }

    #[test]
    fn test_exceedance_prob() {
        // Large mean should have high exceedance
        let prob_high = compute_exceedance_prob(100.0, 10.0, 50.0);
        assert!(prob_high > 0.99, "large mean should exceed threshold");

        // Small mean should have low exceedance for high threshold
        let prob_low = compute_exceedance_prob(1.0, 1.0, 50.0);
        assert!(prob_low < 0.01, "small mean should not exceed threshold");

        // Zero mean, threshold equals 2σ -> ~4.55% exceedance
        let prob_2sigma = compute_exceedance_prob(0.0, 1.0, 2.0);
        assert!(
            (prob_2sigma - 0.0455).abs() < 0.01,
            "2σ threshold should have ~4.5% exceedance"
        );

        // Degenerate σ: the distribution is a point mass at μ
        assert_eq!(compute_exceedance_prob(3.0, 0.0, 2.0), 1.0);
        assert_eq!(compute_exceedance_prob(1.0, 0.0, 2.0), 0.0);
    }
}
341}