1use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8
/// Multiple-comparison correction strategy selectable on a
/// `MultipleTestingGuard`.
///
/// `Bonferroni` is the variant returned by `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum CorrectionMethod {
    /// Bonferroni correction (the `Default` variant).
    #[default]
    Bonferroni,
    /// Holm-Bonferroni correction.
    HolmBonferroni,
    /// Benjamini-Hochberg correction.
    BenjaminiHochberg,
    /// No correction: the base alpha is used unchanged.
    None,
}
22
/// Severity of the multiple-testing concern, from least to most severe.
///
/// The explicit discriminants increase with severity and the type derives
/// `PartialOrd`/`Ord`, so levels can be compared directly
/// (e.g. `level >= WarningLevel::Caution`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum WarningLevel {
    /// No concern.
    None = 0,
    /// Informational notice.
    Info = 1,
    /// Moderate concern.
    Caution = 2,
    /// High concern.
    Warning = 3,
    /// Most severe level.
    Critical = 4,
}
37
38impl WarningLevel {
39 pub fn description(&self) -> &'static str {
41 match self {
42 WarningLevel::None => "No multiple testing concerns",
43 WarningLevel::Info => "Low risk - consider walk-forward validation",
44 WarningLevel::Caution => "Moderate risk - walk-forward analysis recommended",
45 WarningLevel::Warning => "High risk - walk-forward analysis strongly recommended",
46 WarningLevel::Critical => {
47 "Critical overfitting risk - results may be meaningless without validation"
48 }
49 }
50 }
51}
52
/// Snapshot of a multiple-testing assessment, as produced by
/// `MultipleTestingGuard::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultipleTestingStats {
    /// Number of tests (parameter combinations) recorded so far.
    pub n_tests: usize,

    /// Base significance level before any correction.
    pub alpha: f64,

    /// Significance level after applying `method` for `n_tests` tests.
    pub adjusted_alpha: f64,

    /// Correction method that was configured when the snapshot was taken.
    pub method: CorrectionMethod,

    /// Severity assigned to the recorded number of tests.
    pub warning_level: WarningLevel,

    /// Warning text, if any. `get_stats` leaves this `None` when the risk
    /// has been accepted or when the level is `WarningLevel::None`.
    pub warning_message: Option<String>,

    /// Whether the overfitting risk was explicitly accepted.
    pub risk_accepted: bool,
}
77
78impl MultipleTestingStats {
79 pub fn to_value(&self) -> shape_value::ValueWord {
81 use shape_value::ValueWord;
82
83 let warning_msg = self
84 .warning_message
85 .clone()
86 .map(|s| ValueWord::from_string(Arc::new(s)))
87 .unwrap_or(ValueWord::none());
88
89 crate::type_schema::typed_object_from_nb_pairs(&[
90 ("n_tests", ValueWord::from_f64(self.n_tests as f64)),
91 ("alpha", ValueWord::from_f64(self.alpha)),
92 ("adjusted_alpha", ValueWord::from_f64(self.adjusted_alpha)),
93 (
94 "method",
95 ValueWord::from_string(Arc::new(format!("{:?}", self.method))),
96 ),
97 (
98 "warning_level",
99 ValueWord::from_string(Arc::new(format!("{:?}", self.warning_level))),
100 ),
101 ("warning_message", warning_msg),
102 ("risk_accepted", ValueWord::from_bool(self.risk_accepted)),
103 ])
104 }
105}
106
/// Tracks how many parameter combinations have been tested and derives an
/// adjusted significance level, a warning level and a warning message from
/// that count.
#[derive(Debug, Clone)]
pub struct MultipleTestingGuard {
    /// Running total of tests recorded via `record_tests`.
    combinations_tested: usize,

    /// Base significance level supplied to `new`.
    alpha: f64,

    /// Correction method used by `adjusted_alpha`.
    method: CorrectionMethod,

    /// When `true`, warning messages are suppressed.
    accept_overfitting_risk: bool,

    // The three thresholds below are initialised by `new` (50 / 200 / 1000)
    // but are not read by any code in this file; `warning_level` uses
    // literal ranges instead. The leading underscore silences the
    // unused-field lint.
    _caution_threshold: usize,

    _warning_threshold: usize,

    _critical_threshold: usize,
}
131
132impl Default for MultipleTestingGuard {
133 fn default() -> Self {
134 Self::new(0.05)
135 }
136}
137
138impl MultipleTestingGuard {
139 pub fn new(alpha: f64) -> Self {
141 Self {
142 combinations_tested: 0,
143 alpha,
144 method: CorrectionMethod::Bonferroni,
145 accept_overfitting_risk: false,
146 _caution_threshold: 50,
147 _warning_threshold: 200,
148 _critical_threshold: 1000,
149 }
150 }
151
152 pub fn with_method(mut self, method: CorrectionMethod) -> Self {
154 self.method = method;
155 self
156 }
157
158 pub fn record_tests(&mut self, n: usize) {
160 self.combinations_tested += n;
161 }
162
163 pub fn combinations_tested(&self) -> usize {
165 self.combinations_tested
166 }
167
168 pub fn accept_risk(&mut self) {
170 self.accept_overfitting_risk = true;
171 }
172
173 pub fn is_risk_accepted(&self) -> bool {
175 self.accept_overfitting_risk
176 }
177
178 pub fn adjusted_alpha(&self) -> f64 {
180 if self.combinations_tested == 0 {
181 return self.alpha;
182 }
183
184 match self.method {
185 CorrectionMethod::Bonferroni => self.alpha / self.combinations_tested as f64,
186 CorrectionMethod::HolmBonferroni => {
187 self.alpha / self.combinations_tested as f64
191 }
192 CorrectionMethod::BenjaminiHochberg => {
193 self.alpha * 0.5 / self.combinations_tested as f64
197 }
198 CorrectionMethod::None => self.alpha,
199 }
200 }
201
202 pub fn warning_level(&self) -> WarningLevel {
204 match self.combinations_tested {
205 0..=49 => WarningLevel::None,
206 50..=199 => WarningLevel::Info,
207 200..=999 => WarningLevel::Caution,
208 _ => WarningLevel::Critical,
209 }
210 }
211
212 pub fn get_stats(&self) -> MultipleTestingStats {
214 let warning_level = self.warning_level();
215 let adjusted_alpha = self.adjusted_alpha();
216
217 let warning_message = if self.accept_overfitting_risk {
218 None
219 } else {
220 self.generate_warning_message(warning_level, adjusted_alpha)
221 };
222
223 MultipleTestingStats {
224 n_tests: self.combinations_tested,
225 alpha: self.alpha,
226 adjusted_alpha,
227 method: self.method,
228 warning_level,
229 warning_message,
230 risk_accepted: self.accept_overfitting_risk,
231 }
232 }
233
234 fn generate_warning_message(&self, level: WarningLevel, adjusted_alpha: f64) -> Option<String> {
236 match level {
237 WarningLevel::None => None,
238 WarningLevel::Info => Some(format!(
239 "INFO: {} parameter combinations tested. Consider walk-forward validation.",
240 self.combinations_tested
241 )),
242 WarningLevel::Caution => Some(format!(
243 "CAUTION: {} parameter combinations tested. \
244 Bonferroni-adjusted alpha: {:.6}. \
245 Walk-forward analysis recommended.",
246 self.combinations_tested, adjusted_alpha
247 )),
248 WarningLevel::Warning | WarningLevel::Critical => Some(format!(
249 "WARNING: {} parameter combinations tested without walk-forward analysis.\n\
250 Bonferroni-adjusted alpha: {:.2e}\n\
251 This many tests dramatically increases false discovery risk.\n\n\
252 To address this:\n\
253 1. Use walk-forward analysis: `walk_forward: {{ ... }}`\n\
254 2. Or explicitly accept risk: `accept_overfitting_risk: true`",
255 self.combinations_tested, adjusted_alpha
256 )),
257 }
258 }
259
260 pub fn emit_warning_if_needed(&self) {
262 if self.accept_overfitting_risk {
263 return;
264 }
265
266 let stats = self.get_stats();
267 if let Some(msg) = &stats.warning_message {
268 if stats.warning_level >= WarningLevel::Caution {
269 eprintln!("\n{}\n", msg);
270 }
271 }
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// The warning level escalates as the cumulative test count crosses
    /// 50, 200 and 1000.
    #[test]
    fn test_warning_levels() {
        let mut guard = MultipleTestingGuard::new(0.05);
        assert_eq!(guard.warning_level(), WarningLevel::None);

        // (tests added in this step, level expected once they are recorded)
        let steps = [
            (50, WarningLevel::Info),      // running total: 50
            (150, WarningLevel::Caution),  // running total: 200
            (800, WarningLevel::Critical), // running total: 1000
        ];
        for &(added, expected) in &steps {
            guard.record_tests(added);
            assert_eq!(guard.warning_level(), expected);
        }
    }

    /// Bonferroni divides alpha by the number of recorded tests.
    #[test]
    fn test_bonferroni_correction() {
        let mut guard = MultipleTestingGuard::new(0.05);
        guard.record_tests(100);

        let deviation = (guard.adjusted_alpha() - 0.0005).abs();
        assert!(deviation < 1e-10);
    }

    /// Accepting the risk removes the warning message and is reported in
    /// the stats.
    #[test]
    fn test_accept_risk_suppresses_warning() {
        let mut guard = MultipleTestingGuard::new(0.05);
        guard.record_tests(500);
        guard.accept_risk();

        let MultipleTestingStats {
            warning_message,
            risk_accepted,
            ..
        } = guard.get_stats();
        assert!(warning_message.is_none());
        assert!(risk_accepted);
    }
}