Skip to main content

survival/interpretability/
time_varying.rs

1#![allow(
2    unused_variables,
3    unused_imports,
4    clippy::too_many_arguments,
5    clippy::needless_range_loop
6)]
7
8use pyo3::prelude::*;
9use rayon::prelude::*;
10
11#[derive(Debug, Clone, Copy, PartialEq)]
12#[pyclass]
13pub enum TimeVaryingTestType {
14    SlopeTest,
15    VarianceTest,
16    BreakpointTest,
17}
18
19#[pymethods]
20impl TimeVaryingTestType {
21    #[new]
22    fn new(name: &str) -> PyResult<Self> {
23        match name.to_lowercase().as_str() {
24            "slope" | "slopetest" => Ok(TimeVaryingTestType::SlopeTest),
25            "variance" | "variancetest" => Ok(TimeVaryingTestType::VarianceTest),
26            "breakpoint" | "breakpointtest" => Ok(TimeVaryingTestType::BreakpointTest),
27            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28                "Unknown test type. Use 'slope', 'variance', or 'breakpoint'",
29            )),
30        }
31    }
32}
33
34#[derive(Debug, Clone)]
35#[pyclass]
36pub struct TimeVaryingTestConfig {
37    #[pyo3(get, set)]
38    pub test_type: TimeVaryingTestType,
39    #[pyo3(get, set)]
40    pub n_windows: usize,
41    #[pyo3(get, set)]
42    pub min_window_size: usize,
43    #[pyo3(get, set)]
44    pub significance_level: f64,
45    #[pyo3(get, set)]
46    pub n_permutations: usize,
47}
48
49#[pymethods]
50impl TimeVaryingTestConfig {
51    #[new]
52    #[pyo3(signature = (
53        test_type=TimeVaryingTestType::SlopeTest,
54        n_windows=5,
55        min_window_size=10,
56        significance_level=0.05,
57        n_permutations=1000
58    ))]
59    pub fn new(
60        test_type: TimeVaryingTestType,
61        n_windows: usize,
62        min_window_size: usize,
63        significance_level: f64,
64        n_permutations: usize,
65    ) -> PyResult<Self> {
66        if n_windows == 0 {
67            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
68                "n_windows must be positive",
69            ));
70        }
71        if !(0.0..1.0).contains(&significance_level) {
72            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
73                "significance_level must be between 0 and 1",
74            ));
75        }
76
77        Ok(TimeVaryingTestConfig {
78            test_type,
79            n_windows,
80            min_window_size,
81            significance_level,
82            n_permutations,
83        })
84    }
85}
86
87#[derive(Debug, Clone)]
88#[pyclass]
89pub struct TimeVaryingTestResult {
90    #[pyo3(get)]
91    pub feature_idx: usize,
92    #[pyo3(get)]
93    pub is_time_varying: bool,
94    #[pyo3(get)]
95    pub test_statistic: f64,
96    #[pyo3(get)]
97    pub p_value: f64,
98    #[pyo3(get)]
99    pub slope: Option<f64>,
100    #[pyo3(get)]
101    pub slope_se: Option<f64>,
102    #[pyo3(get)]
103    pub window_means: Option<Vec<f64>>,
104    #[pyo3(get)]
105    pub window_variances: Option<Vec<f64>>,
106    #[pyo3(get)]
107    pub breakpoint_time: Option<f64>,
108    #[pyo3(get)]
109    pub effect_size: f64,
110}
111
112#[pymethods]
113impl TimeVaryingTestResult {
114    fn __repr__(&self) -> String {
115        format!(
116            "TimeVaryingTestResult(feature={}, time_varying={}, p={:.4})",
117            self.feature_idx, self.is_time_varying, self.p_value
118        )
119    }
120}
121
122#[derive(Debug, Clone)]
123#[pyclass]
124pub struct TimeVaryingAnalysis {
125    #[pyo3(get)]
126    pub results: Vec<TimeVaryingTestResult>,
127    #[pyo3(get)]
128    pub time_varying_features: Vec<usize>,
129    #[pyo3(get)]
130    pub stable_features: Vec<usize>,
131    #[pyo3(get)]
132    pub feature_rankings: Vec<(usize, f64)>,
133}
134
135#[pymethods]
136impl TimeVaryingAnalysis {
137    fn __repr__(&self) -> String {
138        format!(
139            "TimeVaryingAnalysis(n_time_varying={}, n_stable={})",
140            self.time_varying_features.len(),
141            self.stable_features.len()
142        )
143    }
144
145    fn get_feature_result(&self, feature_idx: usize) -> Option<TimeVaryingTestResult> {
146        self.results
147            .iter()
148            .find(|r| r.feature_idx == feature_idx)
149            .cloned()
150    }
151}
152
153fn compute_slope_test(
154    shap_values: &[f64],
155    time_points: &[f64],
156    n_times: usize,
157) -> (f64, f64, f64, f64) {
158    if n_times < 2 {
159        return (0.0, 0.0, f64::NAN, 1.0);
160    }
161
162    let mean_t: f64 = time_points.iter().sum::<f64>() / n_times as f64;
163    let mean_y: f64 = shap_values.iter().sum::<f64>() / n_times as f64;
164
165    let mut ss_tt = 0.0;
166    let mut ss_ty = 0.0;
167
168    for i in 0..n_times {
169        let t_diff = time_points[i] - mean_t;
170        let y_diff = shap_values[i] - mean_y;
171        ss_tt += t_diff * t_diff;
172        ss_ty += t_diff * y_diff;
173    }
174
175    if ss_tt.abs() < 1e-12 {
176        return (0.0, 0.0, f64::NAN, 1.0);
177    }
178
179    let slope = ss_ty / ss_tt;
180
181    let mut ss_res = 0.0;
182    for i in 0..n_times {
183        let predicted = mean_y + slope * (time_points[i] - mean_t);
184        let residual = shap_values[i] - predicted;
185        ss_res += residual * residual;
186    }
187
188    let mse = ss_res / (n_times - 2).max(1) as f64;
189    let slope_se = (mse / ss_tt).sqrt();
190
191    let t_stat = if slope_se > 1e-12 {
192        slope / slope_se
193    } else {
194        0.0
195    };
196
197    let df = (n_times - 2) as f64;
198    let p_value = 2.0 * (1.0 - t_distribution_cdf(t_stat.abs(), df));
199
200    (slope, slope_se, t_stat, p_value)
201}
202
203fn t_distribution_cdf(t: f64, df: f64) -> f64 {
204    if df <= 0.0 {
205        return 0.5;
206    }
207
208    let x = df / (df + t * t);
209    let a = df / 2.0;
210    let b = 0.5;
211
212    let beta_cdf = incomplete_beta(a, b, x);
213
214    if t >= 0.0 {
215        1.0 - 0.5 * beta_cdf
216    } else {
217        0.5 * beta_cdf
218    }
219}
220
221fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
222    if x <= 0.0 {
223        return 0.0;
224    }
225    if x >= 1.0 {
226        return 1.0;
227    }
228
229    let bt = if x == 0.0 || x == 1.0 {
230        0.0
231    } else {
232        (ln_gamma(a + b) - ln_gamma(a) - ln_gamma(b) + a * x.ln() + b * (1.0 - x).ln()).exp()
233    };
234
235    let symmetry_transform = x < (a + 1.0) / (a + b + 2.0);
236
237    if symmetry_transform {
238        bt * beta_cf(a, b, x) / a
239    } else {
240        1.0 - bt * beta_cf(b, a, 1.0 - x) / b
241    }
242}
243
/// Continued-fraction factor used by `incomplete_beta`.
///
/// Runs at most 100 iterations, each applying an even and an odd update,
/// and stops early once the latest multiplicative update differs from 1 by
/// less than `1e-10`. Intermediate values whose magnitude drops below `1e-30`
/// are replaced by `1e-30` to avoid dividing by zero.
fn beta_cf(a: f64, b: f64, x: f64) -> f64 {
    const TINY: f64 = 1e-30;
    const TOLERANCE: f64 = 1e-10;
    const MAX_STEPS: u32 = 100;

    // Keeps a value away from zero without touching anything else.
    let keep_nonzero = |value: f64| if value.abs() < TINY { TINY } else { value };

    let a_plus_b = a + b;
    let a_plus_one = a + 1.0;
    let a_minus_one = a - 1.0;

    let mut c = 1.0;
    let mut d = 1.0 / keep_nonzero(1.0 - a_plus_b * x / a_plus_one);
    let mut h = d;

    for step in 1..=MAX_STEPS {
        let m = f64::from(step);
        let two_m = 2.0 * m;

        // Even step of the recurrence.
        let even = m * (b - m) * x / ((a_minus_one + two_m) * (a + two_m));
        d = 1.0 / keep_nonzero(1.0 + even * d);
        c = keep_nonzero(1.0 + even / c);
        h *= d * c;

        // Odd step of the recurrence.
        let odd = -(a + m) * (a_plus_b + m) * x / ((a + two_m) * (a_plus_one + two_m));
        d = 1.0 / keep_nonzero(1.0 + odd * d);
        c = keep_nonzero(1.0 + odd / c);
        let delta = d * c;
        h *= delta;

        if (delta - 1.0).abs() < TOLERANCE {
            break;
        }
    }

    h
}
293
/// Natural logarithm of the gamma function, evaluated with a fixed
/// six-coefficient series approximation.
fn ln_gamma(x: f64) -> f64 {
    const COEFFICIENTS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    let shifted = x + 5.5;
    let tmp = shifted - (x + 0.5) * shifted.ln();

    // The j-th coefficient is divided by x + 1 + j.
    let series = COEFFICIENTS
        .iter()
        .enumerate()
        .fold(1.000000000190015, |acc, (j, &coefficient)| {
            acc + coefficient / (x + 1.0 + j as f64)
        });

    -tmp + (2.5066282746310005 * series / x).ln()
}
315
316fn compute_variance_test(
317    shap_values: &[f64],
318    time_points: &[f64],
319    n_times: usize,
320    n_windows: usize,
321) -> (Vec<f64>, Vec<f64>, f64, f64) {
322    let window_size = (n_times / n_windows).max(1);
323    let mut window_means = Vec::with_capacity(n_windows);
324    let mut window_variances = Vec::with_capacity(n_windows);
325
326    for w in 0..n_windows {
327        let start = w * window_size;
328        let end = if w == n_windows - 1 {
329            n_times
330        } else {
331            (start + window_size).min(n_times)
332        };
333
334        if start >= n_times {
335            break;
336        }
337
338        let window_vals: Vec<f64> = shap_values[start..end].to_vec();
339        let n = window_vals.len();
340
341        if n == 0 {
342            continue;
343        }
344
345        let mean = window_vals.iter().sum::<f64>() / n as f64;
346        let var = if n > 1 {
347            window_vals.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / (n - 1) as f64
348        } else {
349            0.0
350        };
351
352        window_means.push(mean);
353        window_variances.push(var);
354    }
355
356    let k = window_variances.len();
357    if k < 2 {
358        return (window_means, window_variances, 0.0, 1.0);
359    }
360
361    let n_total = n_times as f64;
362    let pooled_var = window_variances.iter().sum::<f64>() / k as f64;
363
364    let mut bartlett_num = 0.0;
365    let mut bartlett_denom = 0.0;
366
367    for (w, &var) in window_variances.iter().enumerate() {
368        let n_w = window_size as f64;
369        if var > 1e-12 && pooled_var > 1e-12 {
370            bartlett_num += (n_w - 1.0) * (var / pooled_var).ln();
371        }
372        bartlett_denom += 1.0 / (n_w - 1.0);
373    }
374
375    let c = 1.0 + (1.0 / (3.0 * (k as f64 - 1.0))) * (bartlett_denom - 1.0 / (n_total - k as f64));
376
377    let chi2_stat = if c > 1e-12 { -bartlett_num / c } else { 0.0 };
378
379    let df = (k - 1) as f64;
380    let p_value = 1.0 - chi_squared_cdf(chi2_stat.abs(), df);
381
382    (window_means, window_variances, chi2_stat, p_value)
383}
384
385fn chi_squared_cdf(x: f64, df: f64) -> f64 {
386    if x <= 0.0 || df <= 0.0 {
387        return 0.0;
388    }
389    incomplete_gamma(df / 2.0, x / 2.0)
390}
391
392fn incomplete_gamma(a: f64, x: f64) -> f64 {
393    if x <= 0.0 {
394        return 0.0;
395    }
396    if x < a + 1.0 {
397        gamma_series(a, x)
398    } else {
399        1.0 - gamma_cf(a, x)
400    }
401}
402
403fn gamma_series(a: f64, x: f64) -> f64 {
404    let gln = ln_gamma(a);
405    let mut ap = a;
406    let mut sum = 1.0 / a;
407    let mut del = sum;
408
409    for _ in 0..100 {
410        ap += 1.0;
411        del *= x / ap;
412        sum += del;
413        if del.abs() < sum.abs() * 1e-10 {
414            break;
415        }
416    }
417
418    sum * (-x + a * x.ln() - gln).exp()
419}
420
421fn gamma_cf(a: f64, x: f64) -> f64 {
422    let gln = ln_gamma(a);
423    let mut b = x + 1.0 - a;
424    let mut c = 1.0 / 1e-30;
425    let mut d = 1.0 / b;
426    let mut h = d;
427
428    for i in 1..=100 {
429        let i = i as f64;
430        let an = -i * (i - a);
431        b += 2.0;
432        d = an * d + b;
433        if d.abs() < 1e-30 {
434            d = 1e-30;
435        }
436        c = b + an / c;
437        if c.abs() < 1e-30 {
438            c = 1e-30;
439        }
440        d = 1.0 / d;
441        let del = d * c;
442        h *= del;
443        if (del - 1.0).abs() < 1e-10 {
444            break;
445        }
446    }
447
448    (-x + a * x.ln() - gln).exp() * h
449}
450
451fn compute_breakpoint_test(
452    shap_values: &[f64],
453    time_points: &[f64],
454    n_times: usize,
455    min_segment: usize,
456) -> (Option<f64>, f64, f64) {
457    if n_times < 2 * min_segment {
458        return (None, 0.0, 1.0);
459    }
460
461    let total_mean: f64 = shap_values.iter().sum::<f64>() / n_times as f64;
462    let total_ss: f64 = shap_values.iter().map(|&v| (v - total_mean).powi(2)).sum();
463
464    let mut min_ss = total_ss;
465    let mut best_breakpoint = None;
466
467    for k in min_segment..(n_times - min_segment) {
468        let left = &shap_values[..k];
469        let right = &shap_values[k..];
470
471        let left_mean = left.iter().sum::<f64>() / k as f64;
472        let right_mean = right.iter().sum::<f64>() / (n_times - k) as f64;
473
474        let left_ss: f64 = left.iter().map(|&v| (v - left_mean).powi(2)).sum();
475        let right_ss: f64 = right.iter().map(|&v| (v - right_mean).powi(2)).sum();
476
477        let combined_ss = left_ss + right_ss;
478
479        if combined_ss < min_ss {
480            min_ss = combined_ss;
481            best_breakpoint = Some(time_points[k]);
482        }
483    }
484
485    let f_stat = if min_ss > 1e-12 && n_times > 3 {
486        ((total_ss - min_ss) / 1.0) / (min_ss / (n_times - 3) as f64)
487    } else {
488        0.0
489    };
490
491    let p_value = 1.0 - f_distribution_cdf(f_stat, 1.0, (n_times - 3) as f64);
492
493    (best_breakpoint, f_stat, p_value)
494}
495
496fn f_distribution_cdf(f: f64, df1: f64, df2: f64) -> f64 {
497    if f <= 0.0 {
498        return 0.0;
499    }
500    let x = df2 / (df2 + df1 * f);
501    incomplete_beta(df2 / 2.0, df1 / 2.0, x)
502}
503
504#[pyfunction]
505#[pyo3(signature = (
506    shap_values,
507    time_points,
508    n_samples,
509    n_features,
510    config
511))]
512pub fn detect_time_varying_features(
513    shap_values: Vec<Vec<Vec<f64>>>,
514    time_points: Vec<f64>,
515    n_samples: usize,
516    n_features: usize,
517    config: &TimeVaryingTestConfig,
518) -> PyResult<TimeVaryingAnalysis> {
519    let n_times = time_points.len();
520
521    if shap_values.len() != n_samples {
522        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
523            "shap_values first dimension must match n_samples",
524        ));
525    }
526
527    let results: Vec<TimeVaryingTestResult> = (0..n_features)
528        .into_par_iter()
529        .map(|f| {
530            let aggregated_shap: Vec<f64> = (0..n_times)
531                .map(|t| {
532                    shap_values
533                        .iter()
534                        .map(|sample| sample[f][t].abs())
535                        .sum::<f64>()
536                        / n_samples as f64
537                })
538                .collect();
539
540            match config.test_type {
541                TimeVaryingTestType::SlopeTest => {
542                    let (slope, slope_se, t_stat, p_value) =
543                        compute_slope_test(&aggregated_shap, &time_points, n_times);
544
545                    let effect_size = slope.abs()
546                        * (time_points.last().unwrap_or(&1.0)
547                            - time_points.first().unwrap_or(&0.0));
548
549                    TimeVaryingTestResult {
550                        feature_idx: f,
551                        is_time_varying: p_value < config.significance_level,
552                        test_statistic: t_stat,
553                        p_value,
554                        slope: Some(slope),
555                        slope_se: Some(slope_se),
556                        window_means: None,
557                        window_variances: None,
558                        breakpoint_time: None,
559                        effect_size,
560                    }
561                }
562                TimeVaryingTestType::VarianceTest => {
563                    let (window_means, window_variances, chi2_stat, p_value) =
564                        compute_variance_test(
565                            &aggregated_shap,
566                            &time_points,
567                            n_times,
568                            config.n_windows,
569                        );
570
571                    let max_var = window_variances.iter().fold(0.0f64, |a, &b| a.max(b));
572                    let min_var = window_variances
573                        .iter()
574                        .fold(f64::INFINITY, |a, &b| a.min(b));
575                    let effect_size = if min_var > 1e-12 {
576                        (max_var / min_var).ln()
577                    } else {
578                        0.0
579                    };
580
581                    TimeVaryingTestResult {
582                        feature_idx: f,
583                        is_time_varying: p_value < config.significance_level,
584                        test_statistic: chi2_stat,
585                        p_value,
586                        slope: None,
587                        slope_se: None,
588                        window_means: Some(window_means),
589                        window_variances: Some(window_variances),
590                        breakpoint_time: None,
591                        effect_size,
592                    }
593                }
594                TimeVaryingTestType::BreakpointTest => {
595                    let (breakpoint, f_stat, p_value) = compute_breakpoint_test(
596                        &aggregated_shap,
597                        &time_points,
598                        n_times,
599                        config.min_window_size,
600                    );
601
602                    let effect_size = f_stat.sqrt();
603
604                    TimeVaryingTestResult {
605                        feature_idx: f,
606                        is_time_varying: p_value < config.significance_level,
607                        test_statistic: f_stat,
608                        p_value,
609                        slope: None,
610                        slope_se: None,
611                        window_means: None,
612                        window_variances: None,
613                        breakpoint_time: breakpoint,
614                        effect_size,
615                    }
616                }
617            }
618        })
619        .collect();
620
621    let time_varying_features: Vec<usize> = results
622        .iter()
623        .filter(|r| r.is_time_varying)
624        .map(|r| r.feature_idx)
625        .collect();
626
627    let stable_features: Vec<usize> = results
628        .iter()
629        .filter(|r| !r.is_time_varying)
630        .map(|r| r.feature_idx)
631        .collect();
632
633    let mut feature_rankings: Vec<(usize, f64)> = results
634        .iter()
635        .map(|r| (r.feature_idx, r.effect_size))
636        .collect();
637    feature_rankings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
638
639    Ok(TimeVaryingAnalysis {
640        results,
641        time_varying_features,
642        stable_features,
643        feature_rankings,
644    })
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor keeps the values it was given.
    #[test]
    fn test_config() {
        let built = TimeVaryingTestConfig::new(TimeVaryingTestType::SlopeTest, 5, 10, 0.05, 1000);
        let config = built.unwrap();
        assert_eq!(config.n_windows, 5);
    }

    /// A roughly linear series yields a slope close to 0.1 per time unit.
    #[test]
    fn test_slope_test() {
        let shap = [0.12, 0.18, 0.32, 0.38, 0.52];
        let time = [1.0, 2.0, 3.0, 4.0, 5.0];
        let (slope, se, t_stat, _p_value) = compute_slope_test(&shap, &time, shap.len());

        assert!((slope - 0.1).abs() < 0.05);
        assert!(se > 0.0);
        assert!(t_stat.abs() > 0.0);
    }

    /// Ten values split into two windows produce two means and two variances.
    #[test]
    fn test_variance_test() {
        let shap = [0.1, 0.15, 0.12, 0.5, 0.6, 0.55, 0.1, 0.12, 0.11, 0.58];
        let time: Vec<f64> = (0..10).map(f64::from).collect();
        let (means, vars, _stat, _p) = compute_variance_test(&shap, &time, 10, 2);

        assert_eq!(means.len(), 2);
        assert_eq!(vars.len(), 2);
    }

    /// A step from 0.1 to 0.5 half-way through the series yields a breakpoint.
    #[test]
    fn test_breakpoint_test() {
        let mut shap = vec![0.1; 10];
        shap.extend(vec![0.5; 10]);
        let time: Vec<f64> = (0..20).map(f64::from).collect();
        let (bp, _stat, _p) = compute_breakpoint_test(&shap, &time, 20, 3);

        assert!(bp.is_some());
    }

    /// The end-to-end entry point returns one result per feature.
    #[test]
    fn test_detect_time_varying() {
        let n_samples = 5;
        let n_features = 3;
        let n_times = 10;

        // Feature 0 grows linearly with time; the others stay at 0.5.
        let one_sample: Vec<Vec<f64>> = (0..n_features)
            .map(|f| {
                (0..n_times)
                    .map(|t| if f == 0 { t as f64 * 0.1 } else { 0.5 })
                    .collect()
            })
            .collect();
        let shap_values: Vec<Vec<Vec<f64>>> = vec![one_sample; n_samples];

        let time_points: Vec<f64> = (0..n_times).map(|t| t as f64).collect();

        let config =
            TimeVaryingTestConfig::new(TimeVaryingTestType::SlopeTest, 5, 2, 0.05, 100).unwrap();

        let analysis =
            detect_time_varying_features(shap_values, time_points, n_samples, n_features, &config)
                .unwrap();

        assert_eq!(analysis.results.len(), n_features);
    }
}
717}