1#![allow(
2 unused_variables,
3 unused_imports,
4 clippy::too_many_arguments,
5 clippy::needless_range_loop
6)]
7
8use pyo3::prelude::*;
9use rayon::prelude::*;
10
11#[derive(Debug, Clone, Copy, PartialEq)]
12#[pyclass]
13pub enum ChangepointMethod {
14 PELT,
15 BinarySegment,
16 BottomUp,
17}
18
19#[pymethods]
20impl ChangepointMethod {
21 #[new]
22 fn new(name: &str) -> PyResult<Self> {
23 match name.to_lowercase().as_str() {
24 "pelt" => Ok(ChangepointMethod::PELT),
25 "binary" | "binarysegment" | "binary_segment" => Ok(ChangepointMethod::BinarySegment),
26 "bottomup" | "bottom_up" => Ok(ChangepointMethod::BottomUp),
27 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28 "Unknown method. Use 'pelt', 'binary_segment', or 'bottom_up'",
29 )),
30 }
31 }
32}
33
34#[derive(Debug, Clone, Copy, PartialEq)]
35#[pyclass]
36pub enum CostFunction {
37 L2,
38 L1,
39 Normal,
40 Poisson,
41}
42
43#[pymethods]
44impl CostFunction {
45 #[new]
46 fn new(name: &str) -> PyResult<Self> {
47 match name.to_lowercase().as_str() {
48 "l2" | "quadratic" | "normal_mean" => Ok(CostFunction::L2),
49 "l1" | "absolute" => Ok(CostFunction::L1),
50 "normal" | "normal_meanvar" => Ok(CostFunction::Normal),
51 "poisson" => Ok(CostFunction::Poisson),
52 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
53 "Unknown cost function",
54 )),
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
60#[pyclass]
61pub struct ChangepointConfig {
62 #[pyo3(get, set)]
63 pub method: ChangepointMethod,
64 #[pyo3(get, set)]
65 pub cost: CostFunction,
66 #[pyo3(get, set)]
67 pub penalty: f64,
68 #[pyo3(get, set)]
69 pub min_size: usize,
70 #[pyo3(get, set)]
71 pub max_changepoints: Option<usize>,
72}
73
74#[pymethods]
75impl ChangepointConfig {
76 #[new]
77 #[pyo3(signature = (
78 method=ChangepointMethod::PELT,
79 cost=CostFunction::L2,
80 penalty=1.0,
81 min_size=2,
82 max_changepoints=None
83 ))]
84 pub fn new(
85 method: ChangepointMethod,
86 cost: CostFunction,
87 penalty: f64,
88 min_size: usize,
89 max_changepoints: Option<usize>,
90 ) -> PyResult<Self> {
91 if penalty < 0.0 {
92 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
93 "penalty must be non-negative",
94 ));
95 }
96 if min_size == 0 {
97 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
98 "min_size must be positive",
99 ));
100 }
101
102 Ok(ChangepointConfig {
103 method,
104 cost,
105 penalty,
106 min_size,
107 max_changepoints,
108 })
109 }
110}
111
112#[derive(Debug, Clone)]
113#[pyclass]
114pub struct Changepoint {
115 #[pyo3(get)]
116 pub index: usize,
117 #[pyo3(get)]
118 pub time: f64,
119 #[pyo3(get)]
120 pub cost_improvement: f64,
121 #[pyo3(get)]
122 pub mean_before: f64,
123 #[pyo3(get)]
124 pub mean_after: f64,
125}
126
127#[pymethods]
128impl Changepoint {
129 fn __repr__(&self) -> String {
130 format!(
131 "Changepoint(idx={}, time={:.2}, delta={:.4})",
132 self.index,
133 self.time,
134 self.mean_after - self.mean_before
135 )
136 }
137}
138
139#[derive(Debug, Clone)]
140#[pyclass]
141pub struct ChangepointResult {
142 #[pyo3(get)]
143 pub feature_idx: usize,
144 #[pyo3(get)]
145 pub changepoints: Vec<Changepoint>,
146 #[pyo3(get)]
147 pub segments: Vec<(usize, usize)>,
148 #[pyo3(get)]
149 pub segment_means: Vec<f64>,
150 #[pyo3(get)]
151 pub total_cost: f64,
152 #[pyo3(get)]
153 pub n_changepoints: usize,
154}
155
156#[pymethods]
157impl ChangepointResult {
158 fn __repr__(&self) -> String {
159 format!(
160 "ChangepointResult(feature={}, n_changepoints={})",
161 self.feature_idx, self.n_changepoints
162 )
163 }
164
165 fn get_segment_at(&self, time_idx: usize) -> usize {
166 for (seg_idx, &(start, end)) in self.segments.iter().enumerate() {
167 if time_idx >= start && time_idx < end {
168 return seg_idx;
169 }
170 }
171 self.segments.len().saturating_sub(1)
172 }
173}
174
175#[derive(Debug, Clone)]
176#[pyclass]
177pub struct AllChangepointsResult {
178 #[pyo3(get)]
179 pub results: Vec<ChangepointResult>,
180 #[pyo3(get)]
181 pub features_with_changes: Vec<usize>,
182 #[pyo3(get)]
183 pub most_unstable_features: Vec<(usize, usize)>,
184}
185
186#[pymethods]
187impl AllChangepointsResult {
188 fn __repr__(&self) -> String {
189 format!(
190 "AllChangepointsResult(n_features={}, with_changes={})",
191 self.results.len(),
192 self.features_with_changes.len()
193 )
194 }
195}
196
197fn compute_segment_cost(data: &[f64], start: usize, end: usize, cost: CostFunction) -> f64 {
198 if end <= start {
199 return 0.0;
200 }
201
202 let segment = &data[start..end];
203 let n = segment.len() as f64;
204
205 match cost {
206 CostFunction::L2 => {
207 let mean = segment.iter().sum::<f64>() / n;
208 segment.iter().map(|&x| (x - mean).powi(2)).sum()
209 }
210 CostFunction::L1 => {
211 let mut sorted: Vec<f64> = segment.to_vec();
212 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
213 let median = if sorted.len().is_multiple_of(2) {
214 (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
215 } else {
216 sorted[sorted.len() / 2]
217 };
218 segment.iter().map(|&x| (x - median).abs()).sum()
219 }
220 CostFunction::Normal => {
221 let mean = segment.iter().sum::<f64>() / n;
222 let var = segment.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
223 if var > 1e-12 { n * (1.0 + var.ln()) } else { n }
224 }
225 CostFunction::Poisson => {
226 let mean = segment.iter().sum::<f64>() / n;
227 if mean > 1e-12 {
228 2.0 * segment
229 .iter()
230 .map(|&x| {
231 let x = x.max(1e-12);
232 x * (x / mean).ln() - x + mean
233 })
234 .sum::<f64>()
235 } else {
236 0.0
237 }
238 }
239 }
240}
241
242fn pelt(data: &[f64], penalty: f64, min_size: usize, cost: CostFunction) -> Vec<usize> {
243 let n = data.len();
244 if n < 2 * min_size {
245 return vec![];
246 }
247
248 let mut f = vec![f64::INFINITY; n + 1];
249 let mut cp = vec![0usize; n + 1];
250 let mut r: Vec<usize> = vec![0];
251
252 f[0] = -penalty;
253
254 for t in min_size..=n {
255 let mut new_r = Vec::new();
256
257 for &s in &r {
258 if t - s >= min_size {
259 let cost_val = compute_segment_cost(data, s, t, cost);
260 let candidate = f[s] + cost_val + penalty;
261
262 if candidate < f[t] {
263 f[t] = candidate;
264 cp[t] = s;
265 }
266
267 if f[s] + cost_val + penalty <= f[t] + penalty {
268 new_r.push(s);
269 }
270 }
271 }
272
273 new_r.push(t);
274 r = new_r;
275 }
276
277 let mut changepoints = Vec::new();
278 let mut idx = n;
279 while cp[idx] > 0 {
280 changepoints.push(cp[idx]);
281 idx = cp[idx];
282 }
283
284 changepoints.reverse();
285 changepoints
286}
287
288fn binary_segmentation(
289 data: &[f64],
290 penalty: f64,
291 min_size: usize,
292 cost: CostFunction,
293 max_cp: Option<usize>,
294) -> Vec<usize> {
295 let n = data.len();
296 let max_changepoints = max_cp.unwrap_or(n / (2 * min_size));
297
298 let mut changepoints = Vec::new();
299 let mut segments: Vec<(usize, usize)> = vec![(0, n)];
300
301 while changepoints.len() < max_changepoints && !segments.is_empty() {
302 let mut best_gain = 0.0;
303 let mut best_cp = None;
304 let mut best_seg_idx = 0;
305
306 for (seg_idx, &(start, end)) in segments.iter().enumerate() {
307 if end - start < 2 * min_size {
308 continue;
309 }
310
311 let full_cost = compute_segment_cost(data, start, end, cost);
312
313 for cp in (start + min_size)..(end - min_size + 1) {
314 let left_cost = compute_segment_cost(data, start, cp, cost);
315 let right_cost = compute_segment_cost(data, cp, end, cost);
316 let gain = full_cost - left_cost - right_cost - penalty;
317
318 if gain > best_gain {
319 best_gain = gain;
320 best_cp = Some(cp);
321 best_seg_idx = seg_idx;
322 }
323 }
324 }
325
326 if let Some(cp) = best_cp {
327 let (start, end) = segments.remove(best_seg_idx);
328 segments.push((start, cp));
329 segments.push((cp, end));
330 changepoints.push(cp);
331 } else {
332 break;
333 }
334 }
335
336 changepoints.sort();
337 changepoints
338}
339
340fn bottom_up(
341 data: &[f64],
342 penalty: f64,
343 min_size: usize,
344 cost: CostFunction,
345 max_cp: Option<usize>,
346) -> Vec<usize> {
347 let n = data.len();
348 let max_changepoints = max_cp.unwrap_or(n / min_size);
349
350 let mut changepoints: Vec<usize> = (min_size..n).step_by(min_size).collect();
351
352 if changepoints.is_empty() {
353 return vec![];
354 }
355
356 while changepoints.len() > max_changepoints {
357 let mut min_cost_increase = f64::INFINITY;
358 let mut merge_idx = 0;
359
360 for i in 0..changepoints.len() {
361 let start = if i == 0 { 0 } else { changepoints[i - 1] };
362 let mid = changepoints[i];
363 let end = if i + 1 < changepoints.len() {
364 changepoints[i + 1]
365 } else {
366 n
367 };
368
369 let left_cost = compute_segment_cost(data, start, mid, cost);
370 let right_cost = compute_segment_cost(data, mid, end, cost);
371 let merged_cost = compute_segment_cost(data, start, end, cost);
372
373 let cost_increase = merged_cost - left_cost - right_cost + penalty;
374
375 if cost_increase < min_cost_increase {
376 min_cost_increase = cost_increase;
377 merge_idx = i;
378 }
379 }
380
381 if min_cost_increase > penalty {
382 break;
383 }
384
385 changepoints.remove(merge_idx);
386 }
387
388 changepoints
389}
390
391fn detect_changepoints_single(
392 shap_values: &[f64],
393 time_points: &[f64],
394 feature_idx: usize,
395 config: &ChangepointConfig,
396) -> ChangepointResult {
397 let n = shap_values.len();
398
399 let cp_indices = match config.method {
400 ChangepointMethod::PELT => pelt(shap_values, config.penalty, config.min_size, config.cost),
401 ChangepointMethod::BinarySegment => binary_segmentation(
402 shap_values,
403 config.penalty,
404 config.min_size,
405 config.cost,
406 config.max_changepoints,
407 ),
408 ChangepointMethod::BottomUp => bottom_up(
409 shap_values,
410 config.penalty,
411 config.min_size,
412 config.cost,
413 config.max_changepoints,
414 ),
415 };
416
417 let mut segments: Vec<(usize, usize)> = Vec::new();
418 let mut prev = 0;
419 for &cp in &cp_indices {
420 segments.push((prev, cp));
421 prev = cp;
422 }
423 segments.push((prev, n));
424
425 let segment_means: Vec<f64> = segments
426 .iter()
427 .map(|&(start, end)| {
428 if end > start {
429 shap_values[start..end].iter().sum::<f64>() / (end - start) as f64
430 } else {
431 0.0
432 }
433 })
434 .collect();
435
436 let total_cost: f64 = segments
437 .iter()
438 .map(|&(start, end)| compute_segment_cost(shap_values, start, end, config.cost))
439 .sum();
440
441 let changepoints: Vec<Changepoint> = cp_indices
442 .iter()
443 .enumerate()
444 .map(|(i, &idx)| {
445 let mean_before = segment_means[i];
446 let mean_after = segment_means[i + 1];
447
448 let start = if i == 0 { 0 } else { cp_indices[i - 1] };
449 let end = if i + 1 < cp_indices.len() {
450 cp_indices[i + 1]
451 } else {
452 n
453 };
454
455 let cost_without =
456 compute_segment_cost(shap_values, start, end, config.cost) + config.penalty;
457 let cost_with = compute_segment_cost(shap_values, start, idx, config.cost)
458 + compute_segment_cost(shap_values, idx, end, config.cost);
459
460 Changepoint {
461 index: idx,
462 time: time_points.get(idx).copied().unwrap_or(idx as f64),
463 cost_improvement: cost_without - cost_with,
464 mean_before,
465 mean_after,
466 }
467 })
468 .collect();
469
470 ChangepointResult {
471 feature_idx,
472 changepoints,
473 segments,
474 segment_means,
475 total_cost,
476 n_changepoints: cp_indices.len(),
477 }
478}
479
480#[pyfunction]
481#[pyo3(signature = (shap_values, time_points, n_samples, n_features, config))]
482pub fn detect_changepoints(
483 shap_values: Vec<Vec<Vec<f64>>>,
484 time_points: Vec<f64>,
485 n_samples: usize,
486 n_features: usize,
487 config: &ChangepointConfig,
488) -> PyResult<AllChangepointsResult> {
489 let n_times = time_points.len();
490
491 if shap_values.len() != n_samples {
492 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
493 "shap_values first dimension must match n_samples",
494 ));
495 }
496
497 let results: Vec<ChangepointResult> = (0..n_features)
498 .into_par_iter()
499 .map(|f| {
500 let aggregated: Vec<f64> = (0..n_times)
501 .map(|t| {
502 shap_values
503 .iter()
504 .map(|sample| sample[f][t].abs())
505 .sum::<f64>()
506 / n_samples as f64
507 })
508 .collect();
509
510 detect_changepoints_single(&aggregated, &time_points, f, config)
511 })
512 .collect();
513
514 let features_with_changes: Vec<usize> = results
515 .iter()
516 .filter(|r| r.n_changepoints > 0)
517 .map(|r| r.feature_idx)
518 .collect();
519
520 let mut most_unstable_features: Vec<(usize, usize)> = results
521 .iter()
522 .map(|r| (r.feature_idx, r.n_changepoints))
523 .collect();
524 most_unstable_features.sort_by(|a, b| b.1.cmp(&a.1));
525
526 Ok(AllChangepointsResult {
527 results,
528 features_with_changes,
529 most_unstable_features,
530 })
531}
532
533#[pyfunction]
534#[pyo3(signature = (data, time_points, config))]
535pub fn detect_changepoints_single_series(
536 data: Vec<f64>,
537 time_points: Vec<f64>,
538 config: &ChangepointConfig,
539) -> PyResult<ChangepointResult> {
540 if data.len() != time_points.len() {
541 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
542 "data and time_points must have equal length",
543 ));
544 }
545
546 Ok(detect_changepoints_single(&data, &time_points, 0, config))
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// `low_len` ones followed by `high_len` fives.
    fn step_series(low_len: usize, high_len: usize) -> Vec<f64> {
        let mut data = vec![1.0; low_len];
        data.extend(vec![5.0; high_len]);
        data
    }

    #[test]
    fn test_config() {
        let config =
            ChangepointConfig::new(ChangepointMethod::PELT, CostFunction::L2, 1.0, 2, None)
                .unwrap();
        assert_eq!(config.min_size, 2);
    }

    #[test]
    fn test_segment_cost_l2() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let cost = compute_segment_cost(&data, 0, 5, CostFunction::L2);
        // Mean 3 -> squared deviations 4 + 1 + 0 + 1 + 4.
        assert!((cost - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_pelt_clear_changepoint() {
        // This test previously called `binary_segmentation`, leaving PELT uncovered.
        let data = step_series(20, 20);
        let cp = pelt(&data, 5.0, 1, CostFunction::L2);
        assert_eq!(cp, vec![20]);
    }

    #[test]
    fn test_binary_segmentation() {
        let data = step_series(15, 15);
        let cp = binary_segmentation(&data, 5.0, 5, CostFunction::L2, Some(3));
        assert!(!cp.is_empty());
    }

    #[test]
    fn test_bottom_up() {
        let data = step_series(20, 20);
        let cp = bottom_up(&data, 10.0, 5, CostFunction::L2, Some(5));
        assert!(!cp.is_empty());
    }

    #[test]
    fn test_detect_single_series() {
        let data: Vec<f64> = (0..30).map(|i| if i < 15 { 1.0 } else { 5.0 }).collect();
        let time: Vec<f64> = (0..30).map(|i| i as f64).collect();

        let config = ChangepointConfig::new(
            ChangepointMethod::BinarySegment,
            CostFunction::L2,
            5.0,
            5,
            None,
        )
        .unwrap();

        let result = detect_changepoints_single_series(data, time, &config).unwrap();
        assert!(result.n_changepoints >= 1);
        assert_eq!(result.changepoints[0].index, 15);
        assert_eq!(result.segments, vec![(0, 15), (15, 30)]);
    }

    #[test]
    fn test_get_segment_at() {
        let result = ChangepointResult {
            feature_idx: 0,
            changepoints: vec![],
            segments: vec![(0, 5), (5, 10)],
            segment_means: vec![0.0, 1.0],
            total_cost: 0.0,
            n_changepoints: 1,
        };
        assert_eq!(result.get_segment_at(0), 0);
        assert_eq!(result.get_segment_at(7), 1);
        // Out-of-range indices fall back to the last segment.
        assert_eq!(result.get_segment_at(99), 1);
    }
}