// shape_runtime/multiple_testing.rs

1//! Multiple testing corrections and warnings
2//!
3//! This module provides tools to track the number of parameter combinations
4//! tested during optimization and warn about overfitting risks.
5
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8
9/// Multiple testing correction methods
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
11pub enum CorrectionMethod {
12    /// Bonferroni correction (most conservative)
13    #[default]
14    Bonferroni,
15    /// Holm-Bonferroni step-down procedure
16    HolmBonferroni,
17    /// Benjamini-Hochberg False Discovery Rate
18    BenjaminiHochberg,
19    /// No correction applied
20    None,
21}
22
23/// Warning severity level
24#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
25pub enum WarningLevel {
26    /// No warning needed
27    None = 0,
28    /// Informational (< 50 combinations)
29    Info = 1,
30    /// Caution advised (50-199 combinations)
31    Caution = 2,
32    /// Warning (200-999 combinations)
33    Warning = 3,
34    /// Critical overfitting risk (1000+ combinations)
35    Critical = 4,
36}
37
38impl WarningLevel {
39    /// Get a human-readable description
40    pub fn description(&self) -> &'static str {
41        match self {
42            WarningLevel::None => "No multiple testing concerns",
43            WarningLevel::Info => "Low risk - consider walk-forward validation",
44            WarningLevel::Caution => "Moderate risk - walk-forward analysis recommended",
45            WarningLevel::Warning => "High risk - walk-forward analysis strongly recommended",
46            WarningLevel::Critical => {
47                "Critical overfitting risk - results may be meaningless without validation"
48            }
49        }
50    }
51}
52
53/// Statistics about multiple testing
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct MultipleTestingStats {
56    /// Number of parameter combinations tested
57    pub n_tests: usize,
58
59    /// Original significance level (alpha)
60    pub alpha: f64,
61
62    /// Adjusted significance level after correction
63    pub adjusted_alpha: f64,
64
65    /// Correction method applied
66    pub method: CorrectionMethod,
67
68    /// Warning level based on number of tests
69    pub warning_level: WarningLevel,
70
71    /// Human-readable warning message (if any)
72    pub warning_message: Option<String>,
73
74    /// Whether the user explicitly accepted overfitting risk
75    pub risk_accepted: bool,
76}
77
78impl MultipleTestingStats {
79    /// Convert to a ValueWord TypedObject for Shape
80    pub fn to_value(&self) -> shape_value::ValueWord {
81        use shape_value::ValueWord;
82
83        let warning_msg = self
84            .warning_message
85            .clone()
86            .map(|s| ValueWord::from_string(Arc::new(s)))
87            .unwrap_or(ValueWord::none());
88
89        crate::type_schema::typed_object_from_nb_pairs(&[
90            ("n_tests", ValueWord::from_f64(self.n_tests as f64)),
91            ("alpha", ValueWord::from_f64(self.alpha)),
92            ("adjusted_alpha", ValueWord::from_f64(self.adjusted_alpha)),
93            (
94                "method",
95                ValueWord::from_string(Arc::new(format!("{:?}", self.method))),
96            ),
97            (
98                "warning_level",
99                ValueWord::from_string(Arc::new(format!("{:?}", self.warning_level))),
100            ),
101            ("warning_message", warning_msg),
102            ("risk_accepted", ValueWord::from_bool(self.risk_accepted)),
103        ])
104    }
105}
106
107/// Guard that tracks and warns about multiple testing
108#[derive(Debug, Clone)]
109pub struct MultipleTestingGuard {
110    /// Number of combinations tested so far
111    combinations_tested: usize,
112
113    /// Base significance level
114    alpha: f64,
115
116    /// Correction method to use
117    method: CorrectionMethod,
118
119    /// Whether user has explicitly accepted overfitting risk
120    accept_overfitting_risk: bool,
121
122    /// Threshold for caution warning
123    _caution_threshold: usize,
124
125    /// Threshold for warning
126    _warning_threshold: usize,
127
128    /// Threshold for critical warning
129    _critical_threshold: usize,
130}
131
132impl Default for MultipleTestingGuard {
133    fn default() -> Self {
134        Self::new(0.05)
135    }
136}
137
138impl MultipleTestingGuard {
139    /// Create a new guard with given significance level
140    pub fn new(alpha: f64) -> Self {
141        Self {
142            combinations_tested: 0,
143            alpha,
144            method: CorrectionMethod::Bonferroni,
145            accept_overfitting_risk: false,
146            _caution_threshold: 50,
147            _warning_threshold: 200,
148            _critical_threshold: 1000,
149        }
150    }
151
152    /// Set the correction method
153    pub fn with_method(mut self, method: CorrectionMethod) -> Self {
154        self.method = method;
155        self
156    }
157
158    /// Record that N combinations were tested
159    pub fn record_tests(&mut self, n: usize) {
160        self.combinations_tested += n;
161    }
162
163    /// Get the number of combinations tested
164    pub fn combinations_tested(&self) -> usize {
165        self.combinations_tested
166    }
167
168    /// Suppress warnings (user explicitly accepts risk)
169    pub fn accept_risk(&mut self) {
170        self.accept_overfitting_risk = true;
171    }
172
173    /// Check if risk has been accepted
174    pub fn is_risk_accepted(&self) -> bool {
175        self.accept_overfitting_risk
176    }
177
178    /// Calculate the adjusted alpha based on correction method
179    pub fn adjusted_alpha(&self) -> f64 {
180        if self.combinations_tested == 0 {
181            return self.alpha;
182        }
183
184        match self.method {
185            CorrectionMethod::Bonferroni => self.alpha / self.combinations_tested as f64,
186            CorrectionMethod::HolmBonferroni => {
187                // For Holm-Bonferroni, the adjusted alpha for the first test
188                // is alpha/n, for the second alpha/(n-1), etc.
189                // We return the most stringent (first test) here
190                self.alpha / self.combinations_tested as f64
191            }
192            CorrectionMethod::BenjaminiHochberg => {
193                // FDR control - less conservative than Bonferroni
194                // Roughly alpha * (k/n) for the k-th smallest p-value
195                // We return a representative value
196                self.alpha * 0.5 / self.combinations_tested as f64
197            }
198            CorrectionMethod::None => self.alpha,
199        }
200    }
201
202    /// Determine warning level based on combinations tested
203    pub fn warning_level(&self) -> WarningLevel {
204        match self.combinations_tested {
205            0..=49 => WarningLevel::None,
206            50..=199 => WarningLevel::Info,
207            200..=999 => WarningLevel::Caution,
208            _ => WarningLevel::Critical,
209        }
210    }
211
212    /// Get current statistics and warnings
213    pub fn get_stats(&self) -> MultipleTestingStats {
214        let warning_level = self.warning_level();
215        let adjusted_alpha = self.adjusted_alpha();
216
217        let warning_message = if self.accept_overfitting_risk {
218            None
219        } else {
220            self.generate_warning_message(warning_level, adjusted_alpha)
221        };
222
223        MultipleTestingStats {
224            n_tests: self.combinations_tested,
225            alpha: self.alpha,
226            adjusted_alpha,
227            method: self.method,
228            warning_level,
229            warning_message,
230            risk_accepted: self.accept_overfitting_risk,
231        }
232    }
233
234    /// Generate a warning message based on severity
235    fn generate_warning_message(&self, level: WarningLevel, adjusted_alpha: f64) -> Option<String> {
236        match level {
237            WarningLevel::None => None,
238            WarningLevel::Info => Some(format!(
239                "INFO: {} parameter combinations tested. Consider walk-forward validation.",
240                self.combinations_tested
241            )),
242            WarningLevel::Caution => Some(format!(
243                "CAUTION: {} parameter combinations tested. \
244                 Bonferroni-adjusted alpha: {:.6}. \
245                 Walk-forward analysis recommended.",
246                self.combinations_tested, adjusted_alpha
247            )),
248            WarningLevel::Warning | WarningLevel::Critical => Some(format!(
249                "WARNING: {} parameter combinations tested without walk-forward analysis.\n\
250                 Bonferroni-adjusted alpha: {:.2e}\n\
251                 This many tests dramatically increases false discovery risk.\n\n\
252                 To address this:\n\
253                 1. Use walk-forward analysis: `walk_forward: {{ ... }}`\n\
254                 2. Or explicitly accept risk: `accept_overfitting_risk: true`",
255                self.combinations_tested, adjusted_alpha
256            )),
257        }
258    }
259
260    /// Print warning to stderr if needed
261    pub fn emit_warning_if_needed(&self) {
262        if self.accept_overfitting_risk {
263            return;
264        }
265
266        let stats = self.get_stats();
267        if let Some(msg) = &stats.warning_message {
268            if stats.warning_level >= WarningLevel::Caution {
269                eprintln!("\n{}\n", msg);
270            }
271        }
272    }
273}
274
275#[cfg(test)]
276mod tests {
277    use super::*;
278
279    #[test]
280    fn test_warning_levels() {
281        let mut guard = MultipleTestingGuard::new(0.05);
282
283        assert_eq!(guard.warning_level(), WarningLevel::None);
284
285        guard.record_tests(50);
286        assert_eq!(guard.warning_level(), WarningLevel::Info);
287
288        guard.record_tests(150);
289        assert_eq!(guard.warning_level(), WarningLevel::Caution);
290
291        guard.record_tests(800);
292        assert_eq!(guard.warning_level(), WarningLevel::Critical);
293    }
294
295    #[test]
296    fn test_bonferroni_correction() {
297        let mut guard = MultipleTestingGuard::new(0.05);
298        guard.record_tests(100);
299
300        let adjusted = guard.adjusted_alpha();
301        assert!((adjusted - 0.0005).abs() < 1e-10);
302    }
303
304    #[test]
305    fn test_accept_risk_suppresses_warning() {
306        let mut guard = MultipleTestingGuard::new(0.05);
307        guard.record_tests(500);
308        guard.accept_risk();
309
310        let stats = guard.get_stats();
311        assert!(stats.warning_message.is_none());
312        assert!(stats.risk_accepted);
313    }
314}