Skip to main content

survival/interpretability/
changepoints.rs

1#![allow(
2    unused_variables,
3    unused_imports,
4    clippy::too_many_arguments,
5    clippy::needless_range_loop
6)]
7
8use pyo3::prelude::*;
9use rayon::prelude::*;
10
11#[derive(Debug, Clone, Copy, PartialEq)]
12#[pyclass]
13pub enum ChangepointMethod {
14    PELT,
15    BinarySegment,
16    BottomUp,
17}
18
19#[pymethods]
20impl ChangepointMethod {
21    #[new]
22    fn new(name: &str) -> PyResult<Self> {
23        match name.to_lowercase().as_str() {
24            "pelt" => Ok(ChangepointMethod::PELT),
25            "binary" | "binarysegment" | "binary_segment" => Ok(ChangepointMethod::BinarySegment),
26            "bottomup" | "bottom_up" => Ok(ChangepointMethod::BottomUp),
27            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28                "Unknown method. Use 'pelt', 'binary_segment', or 'bottom_up'",
29            )),
30        }
31    }
32}
33
34#[derive(Debug, Clone, Copy, PartialEq)]
35#[pyclass]
36pub enum CostFunction {
37    L2,
38    L1,
39    Normal,
40    Poisson,
41}
42
43#[pymethods]
44impl CostFunction {
45    #[new]
46    fn new(name: &str) -> PyResult<Self> {
47        match name.to_lowercase().as_str() {
48            "l2" | "quadratic" | "normal_mean" => Ok(CostFunction::L2),
49            "l1" | "absolute" => Ok(CostFunction::L1),
50            "normal" | "normal_meanvar" => Ok(CostFunction::Normal),
51            "poisson" => Ok(CostFunction::Poisson),
52            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
53                "Unknown cost function",
54            )),
55        }
56    }
57}
58
59#[derive(Debug, Clone)]
60#[pyclass]
61pub struct ChangepointConfig {
62    #[pyo3(get, set)]
63    pub method: ChangepointMethod,
64    #[pyo3(get, set)]
65    pub cost: CostFunction,
66    #[pyo3(get, set)]
67    pub penalty: f64,
68    #[pyo3(get, set)]
69    pub min_size: usize,
70    #[pyo3(get, set)]
71    pub max_changepoints: Option<usize>,
72}
73
74#[pymethods]
75impl ChangepointConfig {
76    #[new]
77    #[pyo3(signature = (
78        method=ChangepointMethod::PELT,
79        cost=CostFunction::L2,
80        penalty=1.0,
81        min_size=2,
82        max_changepoints=None
83    ))]
84    pub fn new(
85        method: ChangepointMethod,
86        cost: CostFunction,
87        penalty: f64,
88        min_size: usize,
89        max_changepoints: Option<usize>,
90    ) -> PyResult<Self> {
91        if penalty < 0.0 {
92            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
93                "penalty must be non-negative",
94            ));
95        }
96        if min_size == 0 {
97            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
98                "min_size must be positive",
99            ));
100        }
101
102        Ok(ChangepointConfig {
103            method,
104            cost,
105            penalty,
106            min_size,
107            max_changepoints,
108        })
109    }
110}
111
112#[derive(Debug, Clone)]
113#[pyclass]
114pub struct Changepoint {
115    #[pyo3(get)]
116    pub index: usize,
117    #[pyo3(get)]
118    pub time: f64,
119    #[pyo3(get)]
120    pub cost_improvement: f64,
121    #[pyo3(get)]
122    pub mean_before: f64,
123    #[pyo3(get)]
124    pub mean_after: f64,
125}
126
127#[pymethods]
128impl Changepoint {
129    fn __repr__(&self) -> String {
130        format!(
131            "Changepoint(idx={}, time={:.2}, delta={:.4})",
132            self.index,
133            self.time,
134            self.mean_after - self.mean_before
135        )
136    }
137}
138
139#[derive(Debug, Clone)]
140#[pyclass]
141pub struct ChangepointResult {
142    #[pyo3(get)]
143    pub feature_idx: usize,
144    #[pyo3(get)]
145    pub changepoints: Vec<Changepoint>,
146    #[pyo3(get)]
147    pub segments: Vec<(usize, usize)>,
148    #[pyo3(get)]
149    pub segment_means: Vec<f64>,
150    #[pyo3(get)]
151    pub total_cost: f64,
152    #[pyo3(get)]
153    pub n_changepoints: usize,
154}
155
156#[pymethods]
157impl ChangepointResult {
158    fn __repr__(&self) -> String {
159        format!(
160            "ChangepointResult(feature={}, n_changepoints={})",
161            self.feature_idx, self.n_changepoints
162        )
163    }
164
165    fn get_segment_at(&self, time_idx: usize) -> usize {
166        for (seg_idx, &(start, end)) in self.segments.iter().enumerate() {
167            if time_idx >= start && time_idx < end {
168                return seg_idx;
169            }
170        }
171        self.segments.len().saturating_sub(1)
172    }
173}
174
175#[derive(Debug, Clone)]
176#[pyclass]
177pub struct AllChangepointsResult {
178    #[pyo3(get)]
179    pub results: Vec<ChangepointResult>,
180    #[pyo3(get)]
181    pub features_with_changes: Vec<usize>,
182    #[pyo3(get)]
183    pub most_unstable_features: Vec<(usize, usize)>,
184}
185
186#[pymethods]
187impl AllChangepointsResult {
188    fn __repr__(&self) -> String {
189        format!(
190            "AllChangepointsResult(n_features={}, with_changes={})",
191            self.results.len(),
192            self.features_with_changes.len()
193        )
194    }
195}
196
197fn compute_segment_cost(data: &[f64], start: usize, end: usize, cost: CostFunction) -> f64 {
198    if end <= start {
199        return 0.0;
200    }
201
202    let segment = &data[start..end];
203    let n = segment.len() as f64;
204
205    match cost {
206        CostFunction::L2 => {
207            let mean = segment.iter().sum::<f64>() / n;
208            segment.iter().map(|&x| (x - mean).powi(2)).sum()
209        }
210        CostFunction::L1 => {
211            let mut sorted: Vec<f64> = segment.to_vec();
212            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
213            let median = if sorted.len().is_multiple_of(2) {
214                (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
215            } else {
216                sorted[sorted.len() / 2]
217            };
218            segment.iter().map(|&x| (x - median).abs()).sum()
219        }
220        CostFunction::Normal => {
221            let mean = segment.iter().sum::<f64>() / n;
222            let var = segment.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
223            if var > 1e-12 { n * (1.0 + var.ln()) } else { n }
224        }
225        CostFunction::Poisson => {
226            let mean = segment.iter().sum::<f64>() / n;
227            if mean > 1e-12 {
228                2.0 * segment
229                    .iter()
230                    .map(|&x| {
231                        let x = x.max(1e-12);
232                        x * (x / mean).ln() - x + mean
233                    })
234                    .sum::<f64>()
235            } else {
236                0.0
237            }
238        }
239    }
240}
241
242fn pelt(data: &[f64], penalty: f64, min_size: usize, cost: CostFunction) -> Vec<usize> {
243    let n = data.len();
244    if n < 2 * min_size {
245        return vec![];
246    }
247
248    let mut f = vec![f64::INFINITY; n + 1];
249    let mut cp = vec![0usize; n + 1];
250    let mut r: Vec<usize> = vec![0];
251
252    f[0] = -penalty;
253
254    for t in min_size..=n {
255        let mut new_r = Vec::new();
256
257        for &s in &r {
258            if t - s >= min_size {
259                let cost_val = compute_segment_cost(data, s, t, cost);
260                let candidate = f[s] + cost_val + penalty;
261
262                if candidate < f[t] {
263                    f[t] = candidate;
264                    cp[t] = s;
265                }
266
267                if f[s] + cost_val + penalty <= f[t] + penalty {
268                    new_r.push(s);
269                }
270            }
271        }
272
273        new_r.push(t);
274        r = new_r;
275    }
276
277    let mut changepoints = Vec::new();
278    let mut idx = n;
279    while cp[idx] > 0 {
280        changepoints.push(cp[idx]);
281        idx = cp[idx];
282    }
283
284    changepoints.reverse();
285    changepoints
286}
287
288fn binary_segmentation(
289    data: &[f64],
290    penalty: f64,
291    min_size: usize,
292    cost: CostFunction,
293    max_cp: Option<usize>,
294) -> Vec<usize> {
295    let n = data.len();
296    let max_changepoints = max_cp.unwrap_or(n / (2 * min_size));
297
298    let mut changepoints = Vec::new();
299    let mut segments: Vec<(usize, usize)> = vec![(0, n)];
300
301    while changepoints.len() < max_changepoints && !segments.is_empty() {
302        let mut best_gain = 0.0;
303        let mut best_cp = None;
304        let mut best_seg_idx = 0;
305
306        for (seg_idx, &(start, end)) in segments.iter().enumerate() {
307            if end - start < 2 * min_size {
308                continue;
309            }
310
311            let full_cost = compute_segment_cost(data, start, end, cost);
312
313            for cp in (start + min_size)..(end - min_size + 1) {
314                let left_cost = compute_segment_cost(data, start, cp, cost);
315                let right_cost = compute_segment_cost(data, cp, end, cost);
316                let gain = full_cost - left_cost - right_cost - penalty;
317
318                if gain > best_gain {
319                    best_gain = gain;
320                    best_cp = Some(cp);
321                    best_seg_idx = seg_idx;
322                }
323            }
324        }
325
326        if let Some(cp) = best_cp {
327            let (start, end) = segments.remove(best_seg_idx);
328            segments.push((start, cp));
329            segments.push((cp, end));
330            changepoints.push(cp);
331        } else {
332            break;
333        }
334    }
335
336    changepoints.sort();
337    changepoints
338}
339
340fn bottom_up(
341    data: &[f64],
342    penalty: f64,
343    min_size: usize,
344    cost: CostFunction,
345    max_cp: Option<usize>,
346) -> Vec<usize> {
347    let n = data.len();
348    let max_changepoints = max_cp.unwrap_or(n / min_size);
349
350    let mut changepoints: Vec<usize> = (min_size..n).step_by(min_size).collect();
351
352    if changepoints.is_empty() {
353        return vec![];
354    }
355
356    while changepoints.len() > max_changepoints {
357        let mut min_cost_increase = f64::INFINITY;
358        let mut merge_idx = 0;
359
360        for i in 0..changepoints.len() {
361            let start = if i == 0 { 0 } else { changepoints[i - 1] };
362            let mid = changepoints[i];
363            let end = if i + 1 < changepoints.len() {
364                changepoints[i + 1]
365            } else {
366                n
367            };
368
369            let left_cost = compute_segment_cost(data, start, mid, cost);
370            let right_cost = compute_segment_cost(data, mid, end, cost);
371            let merged_cost = compute_segment_cost(data, start, end, cost);
372
373            let cost_increase = merged_cost - left_cost - right_cost + penalty;
374
375            if cost_increase < min_cost_increase {
376                min_cost_increase = cost_increase;
377                merge_idx = i;
378            }
379        }
380
381        if min_cost_increase > penalty {
382            break;
383        }
384
385        changepoints.remove(merge_idx);
386    }
387
388    changepoints
389}
390
391fn detect_changepoints_single(
392    shap_values: &[f64],
393    time_points: &[f64],
394    feature_idx: usize,
395    config: &ChangepointConfig,
396) -> ChangepointResult {
397    let n = shap_values.len();
398
399    let cp_indices = match config.method {
400        ChangepointMethod::PELT => pelt(shap_values, config.penalty, config.min_size, config.cost),
401        ChangepointMethod::BinarySegment => binary_segmentation(
402            shap_values,
403            config.penalty,
404            config.min_size,
405            config.cost,
406            config.max_changepoints,
407        ),
408        ChangepointMethod::BottomUp => bottom_up(
409            shap_values,
410            config.penalty,
411            config.min_size,
412            config.cost,
413            config.max_changepoints,
414        ),
415    };
416
417    let mut segments: Vec<(usize, usize)> = Vec::new();
418    let mut prev = 0;
419    for &cp in &cp_indices {
420        segments.push((prev, cp));
421        prev = cp;
422    }
423    segments.push((prev, n));
424
425    let segment_means: Vec<f64> = segments
426        .iter()
427        .map(|&(start, end)| {
428            if end > start {
429                shap_values[start..end].iter().sum::<f64>() / (end - start) as f64
430            } else {
431                0.0
432            }
433        })
434        .collect();
435
436    let total_cost: f64 = segments
437        .iter()
438        .map(|&(start, end)| compute_segment_cost(shap_values, start, end, config.cost))
439        .sum();
440
441    let changepoints: Vec<Changepoint> = cp_indices
442        .iter()
443        .enumerate()
444        .map(|(i, &idx)| {
445            let mean_before = segment_means[i];
446            let mean_after = segment_means[i + 1];
447
448            let start = if i == 0 { 0 } else { cp_indices[i - 1] };
449            let end = if i + 1 < cp_indices.len() {
450                cp_indices[i + 1]
451            } else {
452                n
453            };
454
455            let cost_without =
456                compute_segment_cost(shap_values, start, end, config.cost) + config.penalty;
457            let cost_with = compute_segment_cost(shap_values, start, idx, config.cost)
458                + compute_segment_cost(shap_values, idx, end, config.cost);
459
460            Changepoint {
461                index: idx,
462                time: time_points.get(idx).copied().unwrap_or(idx as f64),
463                cost_improvement: cost_without - cost_with,
464                mean_before,
465                mean_after,
466            }
467        })
468        .collect();
469
470    ChangepointResult {
471        feature_idx,
472        changepoints,
473        segments,
474        segment_means,
475        total_cost,
476        n_changepoints: cp_indices.len(),
477    }
478}
479
480#[pyfunction]
481#[pyo3(signature = (shap_values, time_points, n_samples, n_features, config))]
482pub fn detect_changepoints(
483    shap_values: Vec<Vec<Vec<f64>>>,
484    time_points: Vec<f64>,
485    n_samples: usize,
486    n_features: usize,
487    config: &ChangepointConfig,
488) -> PyResult<AllChangepointsResult> {
489    let n_times = time_points.len();
490
491    if shap_values.len() != n_samples {
492        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
493            "shap_values first dimension must match n_samples",
494        ));
495    }
496
497    let results: Vec<ChangepointResult> = (0..n_features)
498        .into_par_iter()
499        .map(|f| {
500            let aggregated: Vec<f64> = (0..n_times)
501                .map(|t| {
502                    shap_values
503                        .iter()
504                        .map(|sample| sample[f][t].abs())
505                        .sum::<f64>()
506                        / n_samples as f64
507                })
508                .collect();
509
510            detect_changepoints_single(&aggregated, &time_points, f, config)
511        })
512        .collect();
513
514    let features_with_changes: Vec<usize> = results
515        .iter()
516        .filter(|r| r.n_changepoints > 0)
517        .map(|r| r.feature_idx)
518        .collect();
519
520    let mut most_unstable_features: Vec<(usize, usize)> = results
521        .iter()
522        .map(|r| (r.feature_idx, r.n_changepoints))
523        .collect();
524    most_unstable_features.sort_by(|a, b| b.1.cmp(&a.1));
525
526    Ok(AllChangepointsResult {
527        results,
528        features_with_changes,
529        most_unstable_features,
530    })
531}
532
533#[pyfunction]
534#[pyo3(signature = (data, time_points, config))]
535pub fn detect_changepoints_single_series(
536    data: Vec<f64>,
537    time_points: Vec<f64>,
538    config: &ChangepointConfig,
539) -> PyResult<ChangepointResult> {
540    if data.len() != time_points.len() {
541        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
542            "data and time_points must have equal length",
543        ));
544    }
545
546    Ok(detect_changepoints_single(&data, &time_points, 0, config))
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `len` copies of 1.0 followed by `len` copies of 5.0.
    fn step_series(len: usize) -> Vec<f64> {
        let mut data: Vec<f64> = vec![1.0; len];
        data.extend(vec![5.0; len]);
        data
    }

    #[test]
    fn test_config() {
        let config =
            ChangepointConfig::new(ChangepointMethod::PELT, CostFunction::L2, 1.0, 2, None)
                .unwrap();
        assert_eq!(config.min_size, 2);
    }

    #[test]
    fn test_segment_cost_l2() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let cost = compute_segment_cost(&data, 0, 5, CostFunction::L2);
        assert!(cost > 0.0);
    }

    #[test]
    fn test_pelt_clear_changepoint() {
        let data = step_series(20);

        // This test previously called `binary_segmentation`, so `pelt` had
        // no coverage at all.
        let cp = pelt(&data, 5.0, 1, CostFunction::L2);
        assert_eq!(cp, vec![20]);
    }

    #[test]
    fn test_binary_segmentation() {
        let data = step_series(15);

        let cp = binary_segmentation(&data, 5.0, 5, CostFunction::L2, Some(3));
        assert!(!cp.is_empty());
    }

    #[test]
    fn test_bottom_up() {
        let data = step_series(20);

        let cp = bottom_up(&data, 10.0, 5, CostFunction::L2, Some(5));
        assert!(!cp.is_empty());
    }

    #[test]
    fn test_detect_single_series() {
        let data: Vec<f64> = (0..30).map(|i| if i < 15 { 1.0 } else { 5.0 }).collect();
        let time: Vec<f64> = (0..30).map(|i| i as f64).collect();

        let config = ChangepointConfig::new(
            ChangepointMethod::BinarySegment,
            CostFunction::L2,
            5.0,
            5,
            None,
        )
        .unwrap();

        let result = detect_changepoints_single_series(data, time, &config).unwrap();
        assert!(result.n_changepoints >= 1);
    }
}