1#![allow(
2 unused_variables,
3 unused_imports,
4 clippy::too_many_arguments,
5 clippy::needless_range_loop
6)]
7
8use pyo3::prelude::*;
9use rayon::prelude::*;
10
11#[derive(Debug, Clone, Copy, PartialEq)]
12#[pyclass]
13pub enum TimeVaryingTestType {
14 SlopeTest,
15 VarianceTest,
16 BreakpointTest,
17}
18
19#[pymethods]
20impl TimeVaryingTestType {
21 #[new]
22 fn new(name: &str) -> PyResult<Self> {
23 match name.to_lowercase().as_str() {
24 "slope" | "slopetest" => Ok(TimeVaryingTestType::SlopeTest),
25 "variance" | "variancetest" => Ok(TimeVaryingTestType::VarianceTest),
26 "breakpoint" | "breakpointtest" => Ok(TimeVaryingTestType::BreakpointTest),
27 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28 "Unknown test type. Use 'slope', 'variance', or 'breakpoint'",
29 )),
30 }
31 }
32}
33
34#[derive(Debug, Clone)]
35#[pyclass]
36pub struct TimeVaryingTestConfig {
37 #[pyo3(get, set)]
38 pub test_type: TimeVaryingTestType,
39 #[pyo3(get, set)]
40 pub n_windows: usize,
41 #[pyo3(get, set)]
42 pub min_window_size: usize,
43 #[pyo3(get, set)]
44 pub significance_level: f64,
45 #[pyo3(get, set)]
46 pub n_permutations: usize,
47}
48
49#[pymethods]
50impl TimeVaryingTestConfig {
51 #[new]
52 #[pyo3(signature = (
53 test_type=TimeVaryingTestType::SlopeTest,
54 n_windows=5,
55 min_window_size=10,
56 significance_level=0.05,
57 n_permutations=1000
58 ))]
59 pub fn new(
60 test_type: TimeVaryingTestType,
61 n_windows: usize,
62 min_window_size: usize,
63 significance_level: f64,
64 n_permutations: usize,
65 ) -> PyResult<Self> {
66 if n_windows == 0 {
67 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
68 "n_windows must be positive",
69 ));
70 }
71 if !(0.0..1.0).contains(&significance_level) {
72 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
73 "significance_level must be between 0 and 1",
74 ));
75 }
76
77 Ok(TimeVaryingTestConfig {
78 test_type,
79 n_windows,
80 min_window_size,
81 significance_level,
82 n_permutations,
83 })
84 }
85}
86
87#[derive(Debug, Clone)]
88#[pyclass]
89pub struct TimeVaryingTestResult {
90 #[pyo3(get)]
91 pub feature_idx: usize,
92 #[pyo3(get)]
93 pub is_time_varying: bool,
94 #[pyo3(get)]
95 pub test_statistic: f64,
96 #[pyo3(get)]
97 pub p_value: f64,
98 #[pyo3(get)]
99 pub slope: Option<f64>,
100 #[pyo3(get)]
101 pub slope_se: Option<f64>,
102 #[pyo3(get)]
103 pub window_means: Option<Vec<f64>>,
104 #[pyo3(get)]
105 pub window_variances: Option<Vec<f64>>,
106 #[pyo3(get)]
107 pub breakpoint_time: Option<f64>,
108 #[pyo3(get)]
109 pub effect_size: f64,
110}
111
112#[pymethods]
113impl TimeVaryingTestResult {
114 fn __repr__(&self) -> String {
115 format!(
116 "TimeVaryingTestResult(feature={}, time_varying={}, p={:.4})",
117 self.feature_idx, self.is_time_varying, self.p_value
118 )
119 }
120}
121
122#[derive(Debug, Clone)]
123#[pyclass]
124pub struct TimeVaryingAnalysis {
125 #[pyo3(get)]
126 pub results: Vec<TimeVaryingTestResult>,
127 #[pyo3(get)]
128 pub time_varying_features: Vec<usize>,
129 #[pyo3(get)]
130 pub stable_features: Vec<usize>,
131 #[pyo3(get)]
132 pub feature_rankings: Vec<(usize, f64)>,
133}
134
135#[pymethods]
136impl TimeVaryingAnalysis {
137 fn __repr__(&self) -> String {
138 format!(
139 "TimeVaryingAnalysis(n_time_varying={}, n_stable={})",
140 self.time_varying_features.len(),
141 self.stable_features.len()
142 )
143 }
144
145 fn get_feature_result(&self, feature_idx: usize) -> Option<TimeVaryingTestResult> {
146 self.results
147 .iter()
148 .find(|r| r.feature_idx == feature_idx)
149 .cloned()
150 }
151}
152
153fn compute_slope_test(
154 shap_values: &[f64],
155 time_points: &[f64],
156 n_times: usize,
157) -> (f64, f64, f64, f64) {
158 if n_times < 2 {
159 return (0.0, 0.0, f64::NAN, 1.0);
160 }
161
162 let mean_t: f64 = time_points.iter().sum::<f64>() / n_times as f64;
163 let mean_y: f64 = shap_values.iter().sum::<f64>() / n_times as f64;
164
165 let mut ss_tt = 0.0;
166 let mut ss_ty = 0.0;
167
168 for i in 0..n_times {
169 let t_diff = time_points[i] - mean_t;
170 let y_diff = shap_values[i] - mean_y;
171 ss_tt += t_diff * t_diff;
172 ss_ty += t_diff * y_diff;
173 }
174
175 if ss_tt.abs() < 1e-12 {
176 return (0.0, 0.0, f64::NAN, 1.0);
177 }
178
179 let slope = ss_ty / ss_tt;
180
181 let mut ss_res = 0.0;
182 for i in 0..n_times {
183 let predicted = mean_y + slope * (time_points[i] - mean_t);
184 let residual = shap_values[i] - predicted;
185 ss_res += residual * residual;
186 }
187
188 let mse = ss_res / (n_times - 2).max(1) as f64;
189 let slope_se = (mse / ss_tt).sqrt();
190
191 let t_stat = if slope_se > 1e-12 {
192 slope / slope_se
193 } else {
194 0.0
195 };
196
197 let df = (n_times - 2) as f64;
198 let p_value = 2.0 * (1.0 - t_distribution_cdf(t_stat.abs(), df));
199
200 (slope, slope_se, t_stat, p_value)
201}
202
203fn t_distribution_cdf(t: f64, df: f64) -> f64 {
204 if df <= 0.0 {
205 return 0.5;
206 }
207
208 let x = df / (df + t * t);
209 let a = df / 2.0;
210 let b = 0.5;
211
212 let beta_cdf = incomplete_beta(a, b, x);
213
214 if t >= 0.0 {
215 1.0 - 0.5 * beta_cdf
216 } else {
217 0.5 * beta_cdf
218 }
219}
220
221fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
222 if x <= 0.0 {
223 return 0.0;
224 }
225 if x >= 1.0 {
226 return 1.0;
227 }
228
229 let bt = if x == 0.0 || x == 1.0 {
230 0.0
231 } else {
232 (ln_gamma(a + b) - ln_gamma(a) - ln_gamma(b) + a * x.ln() + b * (1.0 - x).ln()).exp()
233 };
234
235 let symmetry_transform = x < (a + 1.0) / (a + b + 2.0);
236
237 if symmetry_transform {
238 bt * beta_cf(a, b, x) / a
239 } else {
240 1.0 - bt * beta_cf(b, a, 1.0 - x) / b
241 }
242}
243
/// Continued-fraction factor used by the incomplete beta function.
///
/// Evaluated with a modified Lentz iteration: each pass applies one even and
/// one odd term of the fraction. Intermediate values whose magnitude drops
/// below `1e-30` are replaced by `1e-30` to avoid division by zero. The loop
/// stops after 100 passes or once the last correction factor is within `1e-10`
/// of 1.
fn beta_cf(a: f64, b: f64, x: f64) -> f64 {
    const TINY: f64 = 1e-30;
    const TOLERANCE: f64 = 1e-10;
    const MAX_PASSES: u32 = 100;

    let clamp_tiny = |v: f64| if v.abs() < TINY { TINY } else { v };

    let qab = a + b;
    let qap = a + 1.0;
    let qam = a - 1.0;

    let mut c = 1.0;
    let mut d = 1.0 / clamp_tiny(1.0 - qab * x / qap);
    let mut h = d;

    for pass in 1..=MAX_PASSES {
        let m = f64::from(pass);
        let m2 = 2.0 * m;

        // Even term of the fraction.
        let even = m * (b - m) * x / ((qam + m2) * (a + m2));
        d = 1.0 / clamp_tiny(1.0 + even * d);
        c = clamp_tiny(1.0 + even / c);
        h *= d * c;

        // Odd term of the fraction.
        let odd = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
        d = 1.0 / clamp_tiny(1.0 + odd * d);
        c = clamp_tiny(1.0 + odd / c);
        let correction = d * c;
        h *= correction;

        if (correction - 1.0).abs() < TOLERANCE {
            break;
        }
    }

    h
}
293
/// Natural logarithm of the gamma function for `x > 0`.
///
/// Uses a six-term Lanczos-style series (the coefficients below) around the
/// shifted argument `x + 5.5`.
fn ln_gamma(x: f64) -> f64 {
    const COEFFICIENTS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    let shifted = x + 5.5;
    let tmp = shifted - (x + 0.5) * shifted.ln();

    let series = COEFFICIENTS
        .iter()
        .enumerate()
        .fold(1.000000000190015, |acc, (j, &c)| {
            acc + c / (x + 1.0 + j as f64)
        });

    -tmp + (2.5066282746310005 * series / x).ln()
}
315
316fn compute_variance_test(
317 shap_values: &[f64],
318 time_points: &[f64],
319 n_times: usize,
320 n_windows: usize,
321) -> (Vec<f64>, Vec<f64>, f64, f64) {
322 let window_size = (n_times / n_windows).max(1);
323 let mut window_means = Vec::with_capacity(n_windows);
324 let mut window_variances = Vec::with_capacity(n_windows);
325
326 for w in 0..n_windows {
327 let start = w * window_size;
328 let end = if w == n_windows - 1 {
329 n_times
330 } else {
331 (start + window_size).min(n_times)
332 };
333
334 if start >= n_times {
335 break;
336 }
337
338 let window_vals: Vec<f64> = shap_values[start..end].to_vec();
339 let n = window_vals.len();
340
341 if n == 0 {
342 continue;
343 }
344
345 let mean = window_vals.iter().sum::<f64>() / n as f64;
346 let var = if n > 1 {
347 window_vals.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / (n - 1) as f64
348 } else {
349 0.0
350 };
351
352 window_means.push(mean);
353 window_variances.push(var);
354 }
355
356 let k = window_variances.len();
357 if k < 2 {
358 return (window_means, window_variances, 0.0, 1.0);
359 }
360
361 let n_total = n_times as f64;
362 let pooled_var = window_variances.iter().sum::<f64>() / k as f64;
363
364 let mut bartlett_num = 0.0;
365 let mut bartlett_denom = 0.0;
366
367 for (w, &var) in window_variances.iter().enumerate() {
368 let n_w = window_size as f64;
369 if var > 1e-12 && pooled_var > 1e-12 {
370 bartlett_num += (n_w - 1.0) * (var / pooled_var).ln();
371 }
372 bartlett_denom += 1.0 / (n_w - 1.0);
373 }
374
375 let c = 1.0 + (1.0 / (3.0 * (k as f64 - 1.0))) * (bartlett_denom - 1.0 / (n_total - k as f64));
376
377 let chi2_stat = if c > 1e-12 { -bartlett_num / c } else { 0.0 };
378
379 let df = (k - 1) as f64;
380 let p_value = 1.0 - chi_squared_cdf(chi2_stat.abs(), df);
381
382 (window_means, window_variances, chi2_stat, p_value)
383}
384
385fn chi_squared_cdf(x: f64, df: f64) -> f64 {
386 if x <= 0.0 || df <= 0.0 {
387 return 0.0;
388 }
389 incomplete_gamma(df / 2.0, x / 2.0)
390}
391
392fn incomplete_gamma(a: f64, x: f64) -> f64 {
393 if x <= 0.0 {
394 return 0.0;
395 }
396 if x < a + 1.0 {
397 gamma_series(a, x)
398 } else {
399 1.0 - gamma_cf(a, x)
400 }
401}
402
403fn gamma_series(a: f64, x: f64) -> f64 {
404 let gln = ln_gamma(a);
405 let mut ap = a;
406 let mut sum = 1.0 / a;
407 let mut del = sum;
408
409 for _ in 0..100 {
410 ap += 1.0;
411 del *= x / ap;
412 sum += del;
413 if del.abs() < sum.abs() * 1e-10 {
414 break;
415 }
416 }
417
418 sum * (-x + a * x.ln() - gln).exp()
419}
420
421fn gamma_cf(a: f64, x: f64) -> f64 {
422 let gln = ln_gamma(a);
423 let mut b = x + 1.0 - a;
424 let mut c = 1.0 / 1e-30;
425 let mut d = 1.0 / b;
426 let mut h = d;
427
428 for i in 1..=100 {
429 let i = i as f64;
430 let an = -i * (i - a);
431 b += 2.0;
432 d = an * d + b;
433 if d.abs() < 1e-30 {
434 d = 1e-30;
435 }
436 c = b + an / c;
437 if c.abs() < 1e-30 {
438 c = 1e-30;
439 }
440 d = 1.0 / d;
441 let del = d * c;
442 h *= del;
443 if (del - 1.0).abs() < 1e-10 {
444 break;
445 }
446 }
447
448 (-x + a * x.ln() - gln).exp() * h
449}
450
451fn compute_breakpoint_test(
452 shap_values: &[f64],
453 time_points: &[f64],
454 n_times: usize,
455 min_segment: usize,
456) -> (Option<f64>, f64, f64) {
457 if n_times < 2 * min_segment {
458 return (None, 0.0, 1.0);
459 }
460
461 let total_mean: f64 = shap_values.iter().sum::<f64>() / n_times as f64;
462 let total_ss: f64 = shap_values.iter().map(|&v| (v - total_mean).powi(2)).sum();
463
464 let mut min_ss = total_ss;
465 let mut best_breakpoint = None;
466
467 for k in min_segment..(n_times - min_segment) {
468 let left = &shap_values[..k];
469 let right = &shap_values[k..];
470
471 let left_mean = left.iter().sum::<f64>() / k as f64;
472 let right_mean = right.iter().sum::<f64>() / (n_times - k) as f64;
473
474 let left_ss: f64 = left.iter().map(|&v| (v - left_mean).powi(2)).sum();
475 let right_ss: f64 = right.iter().map(|&v| (v - right_mean).powi(2)).sum();
476
477 let combined_ss = left_ss + right_ss;
478
479 if combined_ss < min_ss {
480 min_ss = combined_ss;
481 best_breakpoint = Some(time_points[k]);
482 }
483 }
484
485 let f_stat = if min_ss > 1e-12 && n_times > 3 {
486 ((total_ss - min_ss) / 1.0) / (min_ss / (n_times - 3) as f64)
487 } else {
488 0.0
489 };
490
491 let p_value = 1.0 - f_distribution_cdf(f_stat, 1.0, (n_times - 3) as f64);
492
493 (best_breakpoint, f_stat, p_value)
494}
495
496fn f_distribution_cdf(f: f64, df1: f64, df2: f64) -> f64 {
497 if f <= 0.0 {
498 return 0.0;
499 }
500 let x = df2 / (df2 + df1 * f);
501 incomplete_beta(df2 / 2.0, df1 / 2.0, x)
502}
503
504#[pyfunction]
505#[pyo3(signature = (
506 shap_values,
507 time_points,
508 n_samples,
509 n_features,
510 config
511))]
512pub fn detect_time_varying_features(
513 shap_values: Vec<Vec<Vec<f64>>>,
514 time_points: Vec<f64>,
515 n_samples: usize,
516 n_features: usize,
517 config: &TimeVaryingTestConfig,
518) -> PyResult<TimeVaryingAnalysis> {
519 let n_times = time_points.len();
520
521 if shap_values.len() != n_samples {
522 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
523 "shap_values first dimension must match n_samples",
524 ));
525 }
526
527 let results: Vec<TimeVaryingTestResult> = (0..n_features)
528 .into_par_iter()
529 .map(|f| {
530 let aggregated_shap: Vec<f64> = (0..n_times)
531 .map(|t| {
532 shap_values
533 .iter()
534 .map(|sample| sample[f][t].abs())
535 .sum::<f64>()
536 / n_samples as f64
537 })
538 .collect();
539
540 match config.test_type {
541 TimeVaryingTestType::SlopeTest => {
542 let (slope, slope_se, t_stat, p_value) =
543 compute_slope_test(&aggregated_shap, &time_points, n_times);
544
545 let effect_size = slope.abs()
546 * (time_points.last().unwrap_or(&1.0)
547 - time_points.first().unwrap_or(&0.0));
548
549 TimeVaryingTestResult {
550 feature_idx: f,
551 is_time_varying: p_value < config.significance_level,
552 test_statistic: t_stat,
553 p_value,
554 slope: Some(slope),
555 slope_se: Some(slope_se),
556 window_means: None,
557 window_variances: None,
558 breakpoint_time: None,
559 effect_size,
560 }
561 }
562 TimeVaryingTestType::VarianceTest => {
563 let (window_means, window_variances, chi2_stat, p_value) =
564 compute_variance_test(
565 &aggregated_shap,
566 &time_points,
567 n_times,
568 config.n_windows,
569 );
570
571 let max_var = window_variances.iter().fold(0.0f64, |a, &b| a.max(b));
572 let min_var = window_variances
573 .iter()
574 .fold(f64::INFINITY, |a, &b| a.min(b));
575 let effect_size = if min_var > 1e-12 {
576 (max_var / min_var).ln()
577 } else {
578 0.0
579 };
580
581 TimeVaryingTestResult {
582 feature_idx: f,
583 is_time_varying: p_value < config.significance_level,
584 test_statistic: chi2_stat,
585 p_value,
586 slope: None,
587 slope_se: None,
588 window_means: Some(window_means),
589 window_variances: Some(window_variances),
590 breakpoint_time: None,
591 effect_size,
592 }
593 }
594 TimeVaryingTestType::BreakpointTest => {
595 let (breakpoint, f_stat, p_value) = compute_breakpoint_test(
596 &aggregated_shap,
597 &time_points,
598 n_times,
599 config.min_window_size,
600 );
601
602 let effect_size = f_stat.sqrt();
603
604 TimeVaryingTestResult {
605 feature_idx: f,
606 is_time_varying: p_value < config.significance_level,
607 test_statistic: f_stat,
608 p_value,
609 slope: None,
610 slope_se: None,
611 window_means: None,
612 window_variances: None,
613 breakpoint_time: breakpoint,
614 effect_size,
615 }
616 }
617 }
618 })
619 .collect();
620
621 let time_varying_features: Vec<usize> = results
622 .iter()
623 .filter(|r| r.is_time_varying)
624 .map(|r| r.feature_idx)
625 .collect();
626
627 let stable_features: Vec<usize> = results
628 .iter()
629 .filter(|r| !r.is_time_varying)
630 .map(|r| r.feature_idx)
631 .collect();
632
633 let mut feature_rankings: Vec<(usize, f64)> = results
634 .iter()
635 .map(|r| (r.feature_idx, r.effect_size))
636 .collect();
637 feature_rankings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
638
639 Ok(TimeVaryingAnalysis {
640 results,
641 time_varying_features,
642 stable_features,
643 feature_rankings,
644 })
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor keeps the window count it was given.
    #[test]
    fn test_config() {
        let built = TimeVaryingTestConfig::new(TimeVaryingTestType::SlopeTest, 5, 10, 0.05, 1000);
        let config = built.unwrap();
        assert_eq!(config.n_windows, 5);
    }

    /// A noisy upward trend yields a slope near 0.1 with a positive standard error.
    #[test]
    fn test_slope_test() {
        let shap = [0.12, 0.18, 0.32, 0.38, 0.52];
        let time = [1.0, 2.0, 3.0, 4.0, 5.0];
        let (slope, se, t_stat, _p_value) = compute_slope_test(&shap, &time, shap.len());

        assert!((slope - 0.1).abs() < 0.05);
        assert!(se > 0.0);
        assert!(t_stat.abs() > 0.0);
    }

    /// Ten points split into two windows give two means and two variances.
    #[test]
    fn test_variance_test() {
        let shap = [0.1, 0.15, 0.12, 0.5, 0.6, 0.55, 0.1, 0.12, 0.11, 0.58];
        let time: Vec<f64> = (0..shap.len()).map(|i| i as f64).collect();
        let (means, vars, _stat, _p) = compute_variance_test(&shap, &time, shap.len(), 2);

        assert_eq!(means.len(), 2);
        assert_eq!(vars.len(), 2);
    }

    /// A step from 0.1 to 0.5 is located as a breakpoint.
    #[test]
    fn test_breakpoint_test() {
        let mut shap = vec![0.1; 10];
        shap.extend(vec![0.5; 10]);
        let time: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let (bp, _stat, _p) = compute_breakpoint_test(&shap, &time, 20, 3);

        assert!(bp.is_some());
    }

    /// The end-to-end entry point returns one result per feature.
    #[test]
    fn test_detect_time_varying() {
        let n_samples = 5;
        let n_features = 3;
        let n_times = 10;

        // Feature 0 rises linearly with time; the other features are constant.
        let rising: Vec<f64> = (0..n_times).map(|t| t as f64 * 0.1).collect();
        let flat = vec![0.5; n_times];
        let sample = vec![rising, flat.clone(), flat];
        let shap_values: Vec<Vec<Vec<f64>>> = vec![sample; n_samples];

        let time_points: Vec<f64> = (0..n_times).map(|t| t as f64).collect();

        let config =
            TimeVaryingTestConfig::new(TimeVaryingTestType::SlopeTest, 5, 2, 0.05, 100).unwrap();

        let analysis =
            detect_time_varying_features(shap_values, time_points, n_samples, n_features, &config)
                .unwrap();

        assert_eq!(analysis.results.len(), n_features);
    }
}