/// Per-segment cost model used when scoring a candidate segmentation.
///
/// The variant selects both the formula applied to each segment and the
/// number of parameters charged per segment by the BIC penalty.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CostFunction {
    /// Sum of squared deviations from the segment mean (one parameter per
    /// segment: the mean).
    L2,
    /// Segment length times the log of the segment variance (two parameters
    /// per segment: mean and variance).
    Normal,
}
66
67impl CostFunction {
68 fn params_per_segment(self) -> usize {
70 match self {
71 CostFunction::L2 => 1,
72 CostFunction::Normal => 2,
73 }
74 }
75}
76
/// Cost charged for every changepoint added to a segmentation.
#[derive(Clone, Copy, Debug)]
pub enum Penalty {
    /// Data-dependent penalty: parameters-per-segment of the cost function
    /// multiplied by `ln(n)`, where `n` is the number of samples.
    Bic,
    /// Fixed, caller-supplied penalty. The `Pelt` constructors reject values
    /// that are not finite or not strictly positive.
    Custom(f64),
}
87
88pub struct Pelt {
114 cost: CostFunction,
116 penalty: Penalty,
118 min_segment_len: usize,
120}
121
/// Outcome of single-signal changepoint detection.
#[derive(Clone, Debug)]
pub struct PeltResult {
    /// Detected changepoints in ascending order. Each entry is the index of
    /// the first sample of a new segment; index `0` is never reported.
    pub changepoints: Vec<usize>,
}
132
/// Outcome of multi-channel changepoint detection.
#[derive(Clone, Debug)]
pub struct MultiPeltResult {
    /// Changepoints shared by all channels, in ascending order. Each entry is
    /// the index of the first sample of a new segment; index `0` is never
    /// reported.
    pub changepoints: Vec<usize>,
}
139
140impl Pelt {
141 pub fn new(cost: CostFunction, penalty: Penalty) -> Option<Self> {
154 Self::with_min_segment_len(cost, penalty, 2)
155 }
156
157 pub fn with_min_segment_len(
170 cost: CostFunction,
171 penalty: Penalty,
172 min_segment_len: usize,
173 ) -> Option<Self> {
174 if let Penalty::Custom(p) = penalty {
175 if !p.is_finite() || p <= 0.0 {
176 return None;
177 }
178 }
179 if min_segment_len < 2 {
180 return None;
181 }
182 Some(Self {
183 cost,
184 penalty,
185 min_segment_len,
186 })
187 }
188
189 pub fn detect(&self, data: &[f64]) -> PeltResult {
218 let n = data.len();
219
220 if n < 2 * self.min_segment_len {
221 return PeltResult {
222 changepoints: Vec::new(),
223 };
224 }
225
226 let penalty_value = self.resolve_penalty(n);
227
228 let mut cum_sum = vec![0.0_f64; n + 1];
232 let mut cum_sum_sq = vec![0.0_f64; n + 1];
233 for i in 0..n {
234 cum_sum[i + 1] = cum_sum[i] + data[i];
235 cum_sum_sq[i + 1] = cum_sum_sq[i] + data[i] * data[i];
236 }
237
238 let mut f = vec![0.0_f64; n + 1];
241 f[0] = -penalty_value;
242
243 let mut last_change = vec![0_usize; n + 1];
245
246 let mut candidates: Vec<usize> = vec![0];
248
249 for t in self.min_segment_len..=n {
250 let mut best_cost = f64::INFINITY;
252 let mut best_tau = 0;
253
254 for &tau in &candidates {
255 let seg_len = t - tau;
256 if seg_len < self.min_segment_len {
257 continue;
258 }
259
260 let cost = self.segment_cost(&cum_sum, &cum_sum_sq, tau, t);
261 let total = f[tau] + cost + penalty_value;
262
263 if total < best_cost {
264 best_cost = total;
265 best_tau = tau;
266 }
267 }
268
269 f[t] = best_cost;
270 last_change[t] = best_tau;
271
272 candidates.retain(|&tau| {
276 let seg_len = t - tau;
277 if seg_len < self.min_segment_len {
278 return true; }
280 let cost = self.segment_cost(&cum_sum, &cum_sum_sq, tau, t);
281 f[tau] + cost < f[t] + penalty_value
282 });
283
284 candidates.push(t);
285 }
286
287 let mut changepoints = Vec::new();
289 let mut t = n;
290 while t > 0 {
291 let tau = last_change[t];
292 if tau > 0 {
293 changepoints.push(tau);
294 }
295 t = tau;
296 }
297
298 changepoints.sort_unstable();
299
300 PeltResult { changepoints }
301 }
302
303 pub fn detect_multi(&self, signals: &[&[f64]]) -> Option<MultiPeltResult> {
343 if signals.is_empty() {
344 return Some(MultiPeltResult {
345 changepoints: Vec::new(),
346 });
347 }
348
349 let n = signals[0].len();
350 if signals.iter().any(|s| s.len() != n) {
351 return None;
352 }
353
354 if n < 2 * self.min_segment_len {
355 return Some(MultiPeltResult {
356 changepoints: Vec::new(),
357 });
358 }
359
360 let n_channels = signals.len();
361 let penalty_value = self.resolve_penalty(n) * n_channels as f64;
362
363 let mut cum_sums: Vec<Vec<f64>> = Vec::with_capacity(n_channels);
365 let mut cum_sum_sqs: Vec<Vec<f64>> = Vec::with_capacity(n_channels);
366
367 for signal in signals {
368 let mut cs = vec![0.0_f64; n + 1];
369 let mut css = vec![0.0_f64; n + 1];
370 for i in 0..n {
371 cs[i + 1] = cs[i] + signal[i];
372 css[i + 1] = css[i] + signal[i] * signal[i];
373 }
374 cum_sums.push(cs);
375 cum_sum_sqs.push(css);
376 }
377
378 let mut f = vec![0.0_f64; n + 1];
379 f[0] = -penalty_value;
380 let mut last_change = vec![0_usize; n + 1];
381 let mut candidates: Vec<usize> = vec![0];
382
383 for t in self.min_segment_len..=n {
384 let mut best_cost = f64::INFINITY;
385 let mut best_tau = 0;
386
387 for &tau in &candidates {
388 let seg_len = t - tau;
389 if seg_len < self.min_segment_len {
390 continue;
391 }
392
393 let cost: f64 = (0..n_channels)
394 .map(|ch| self.segment_cost(&cum_sums[ch], &cum_sum_sqs[ch], tau, t))
395 .sum();
396 let total = f[tau] + cost + penalty_value;
397
398 if total < best_cost {
399 best_cost = total;
400 best_tau = tau;
401 }
402 }
403
404 f[t] = best_cost;
405 last_change[t] = best_tau;
406
407 candidates.retain(|&tau| {
408 let seg_len = t - tau;
409 if seg_len < self.min_segment_len {
410 return true;
411 }
412 let cost: f64 = (0..n_channels)
413 .map(|ch| self.segment_cost(&cum_sums[ch], &cum_sum_sqs[ch], tau, t))
414 .sum();
415 f[tau] + cost < f[t] + penalty_value
416 });
417
418 candidates.push(t);
419 }
420
421 let mut changepoints = Vec::new();
422 let mut t = n;
423 while t > 0 {
424 let tau = last_change[t];
425 if tau > 0 {
426 changepoints.push(tau);
427 }
428 t = tau;
429 }
430 changepoints.sort_unstable();
431
432 Some(MultiPeltResult { changepoints })
433 }
434
435 fn resolve_penalty(&self, n: usize) -> f64 {
437 match self.penalty {
438 Penalty::Bic => {
439 let p = self.cost.params_per_segment() as f64;
440 p * (n as f64).ln()
441 }
442 Penalty::Custom(val) => val,
443 }
444 }
445
446 fn segment_cost(&self, cum_sum: &[f64], cum_sum_sq: &[f64], start: usize, end: usize) -> f64 {
452 let seg_len = (end - start) as f64;
453 let sum = cum_sum[end] - cum_sum[start];
454 let sum_sq = cum_sum_sq[end] - cum_sum_sq[start];
455 let mean = sum / seg_len;
456
457 match self.cost {
458 CostFunction::L2 => {
459 sum_sq - seg_len * mean * mean
461 }
462 CostFunction::Normal => {
463 let variance = (sum_sq - seg_len * mean * mean) / seg_len;
465 if variance <= 0.0 {
466 seg_len * (f64::MIN_POSITIVE).ln()
469 } else {
470 seg_len * variance.ln()
471 }
472 }
473 }
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    // ---- construction -----------------------------------------------------

    #[test]
    fn test_pelt_valid_construction() {
        assert!(Pelt::new(CostFunction::L2, Penalty::Bic).is_some());
        assert!(Pelt::new(CostFunction::Normal, Penalty::Bic).is_some());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(10.0)).is_some());
    }

    #[test]
    fn test_pelt_invalid_custom_penalty() {
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(0.0)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(-1.0)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(f64::NAN)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(f64::INFINITY)).is_none());
    }

    #[test]
    fn test_pelt_invalid_min_segment_len() {
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 0).is_none());
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 1).is_none());
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 2).is_some());
    }

    // ---- degenerate inputs ------------------------------------------------

    #[test]
    fn test_pelt_empty_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let result = pelt.detect(&[]);
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_too_short_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        // Three samples cannot hold two segments of length >= 2.
        let result = pelt.detect(&[1.0, 2.0, 3.0]);
        assert!(result.changepoints.is_empty());
    }

    // ---- constant signals -------------------------------------------------

    #[test]
    fn test_pelt_constant_data_no_changepoint() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let data = vec![5.0; 100];
        let result = pelt.detect(&data);
        assert!(
            result.changepoints.is_empty(),
            "constant data should have no changepoints, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_normal_cost_constant_data() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");
        let data = vec![5.0; 100];
        let result = pelt.detect(&data);
        assert!(
            result.changepoints.is_empty(),
            "constant data should have no changepoints with Normal cost, got {:?}",
            result.changepoints
        );
    }

    // ---- single mean shift ------------------------------------------------

    #[test]
    fn test_pelt_single_mean_shift_l2() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_single_mean_shift_normal() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint with Normal cost, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    // ---- multiple changepoints --------------------------------------------

    #[test]
    fn test_pelt_two_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 40];
        data.extend(vec![5.0; 40]);
        data.extend(vec![0.0; 40]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            2,
            "expected 2 changepoints, got {:?}",
            result.changepoints
        );

        assert!(
            (result.changepoints[0] as i64 - 40).unsigned_abs() <= 2,
            "first changepoint near 40, got {}",
            result.changepoints[0]
        );
        assert!(
            (result.changepoints[1] as i64 - 80).unsigned_abs() <= 2,
            "second changepoint near 80, got {}",
            result.changepoints[1]
        );
    }

    #[test]
    fn test_pelt_three_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 30];
        data.extend(vec![4.0; 30]);
        data.extend(vec![-2.0; 30]);
        data.extend(vec![3.0; 30]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            3,
            "expected 3 changepoints, got {:?}",
            result.changepoints
        );

        for i in 1..result.changepoints.len() {
            assert!(
                result.changepoints[i] > result.changepoints[i - 1],
                "changepoints should be strictly increasing"
            );
        }
    }

    // ---- variance change --------------------------------------------------

    #[test]
    fn test_pelt_variance_change_normal_cost() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        // Zero-mean alternating signal whose amplitude jumps at index 100.
        let mut data = Vec::with_capacity(200);
        for i in 0..100 {
            data.push(if i % 2 == 0 { 0.1 } else { -0.1 });
        }
        for i in 0..100 {
            data.push(if i % 2 == 0 { 5.0 } else { -5.0 });
        }

        let result = pelt.detect(&data);
        assert!(
            !result.changepoints.is_empty(),
            "Normal cost should detect variance change"
        );
        let cp = result.changepoints[0];
        assert!(
            (cp as i64 - 100).unsigned_abs() <= 5,
            "variance changepoint should be near 100, got {}",
            cp
        );
    }

    // ---- penalty behaviour ------------------------------------------------

    #[test]
    fn test_pelt_higher_penalty_fewer_changepoints() {
        let mut data = vec![0.0; 30];
        data.extend(vec![2.0; 30]);
        data.extend(vec![0.0; 30]);

        let pelt_low = Pelt::new(CostFunction::L2, Penalty::Custom(1.0)).expect("valid");
        let pelt_high = Pelt::new(CostFunction::L2, Penalty::Custom(100.0)).expect("valid");

        let result_low = pelt_low.detect(&data);
        let result_high = pelt_high.detect(&data);

        assert!(
            result_low.changepoints.len() >= result_high.changepoints.len(),
            "higher penalty should produce fewer or equal changepoints: low={}, high={}",
            result_low.changepoints.len(),
            result_high.changepoints.len()
        );
    }

    // ---- minimum segment length -------------------------------------------

    #[test]
    fn test_pelt_custom_min_segment_len() {
        let mut data = vec![0.0; 50];
        data.extend(vec![10.0; 50]);

        let pelt = Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 10).expect("valid");
        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "should detect changepoint with min_segment_len=10"
        );

        // Every segment, including the first and last, must respect the minimum.
        let mut boundaries = vec![0];
        boundaries.extend_from_slice(&result.changepoints);
        boundaries.push(data.len());
        for i in 1..boundaries.len() {
            let seg_len = boundaries[i] - boundaries[i - 1];
            assert!(
                seg_len >= 10,
                "segment length {} is less than min_segment_len=10",
                seg_len
            );
        }
    }

    // ---- exact small example ----------------------------------------------

    #[test]
    fn test_pelt_exact_small_example() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let data = [0.0, 0.0, 10.0, 10.0];
        let result = pelt.detect(&data);

        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint in [0,0,10,10], got {:?}",
            result.changepoints
        );
        assert_eq!(
            result.changepoints[0], 2,
            "changepoint should be at index 2"
        );
    }

    // ---- output invariants ------------------------------------------------

    #[test]
    fn test_pelt_changepoints_sorted() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 25];
        data.extend(vec![5.0; 25]);
        data.extend(vec![-3.0; 25]);
        data.extend(vec![7.0; 25]);

        let result = pelt.detect(&data);
        for i in 1..result.changepoints.len() {
            assert!(
                result.changepoints[i] > result.changepoints[i - 1],
                "changepoints must be strictly increasing: {:?}",
                result.changepoints
            );
        }
    }

    #[test]
    fn test_pelt_bic_penalty_scales() {
        let l2 = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let normal = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        // BIC penalty is params_per_segment * ln(n).
        let small = l2.resolve_penalty(100);
        let large = l2.resolve_penalty(10_000);
        assert!(
            (small - 100.0_f64.ln()).abs() < 1e-12,
            "L2 BIC penalty for n=100 should be ln(100), got {}",
            small
        );
        assert!(
            large > small,
            "BIC penalty must grow with n: n=100 -> {}, n=10000 -> {}",
            small,
            large
        );
        assert!(
            (normal.resolve_penalty(100) - 2.0 * small).abs() < 1e-12,
            "Normal cost has two parameters, so its BIC penalty should be twice the L2 penalty"
        );

        // A 0.5 shift over 100 samples: the unsplit cost is 6.25, while one
        // split costs only the penalty ln(100) ~= 4.61, so the shift is found.
        let mut data = vec![0.0; 50];
        data.extend(vec![0.5; 50]);
        let result = l2.detect(&data);
        assert_eq!(
            result.changepoints,
            vec![50],
            "a 0.5 shift at index 50 should beat the BIC penalty for n=100"
        );
    }

    #[test]
    fn test_pelt_segments_cover_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 30];
        data.extend(vec![5.0; 30]);
        data.extend(vec![0.0; 30]);

        let result = pelt.detect(&data);

        let mut boundaries = vec![0];
        boundaries.extend_from_slice(&result.changepoints);
        boundaries.push(data.len());

        for i in 1..boundaries.len() {
            assert!(
                boundaries[i] > boundaries[i - 1],
                "segments must not have zero length"
            );
        }
        assert_eq!(
            *boundaries.last().expect("non-empty boundaries"),
            data.len(),
            "segments must cover entire data"
        );
    }

    #[test]
    fn test_pelt_downward_shift() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![10.0; 50];
        data.extend(vec![2.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(result.changepoints.len(), 1, "should detect downward shift");
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    // ---- cost function ----------------------------------------------------

    #[test]
    fn test_cost_function_params() {
        assert_eq!(CostFunction::L2.params_per_segment(), 1);
        assert_eq!(CostFunction::Normal.params_per_segment(), 2);
    }

    // ---- multi-channel detection ------------------------------------------

    #[test]
    fn test_pelt_multi_single_channel_matches_univariate() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Custom(5.0)).expect("valid");
        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let uni = pelt.detect(&data);
        let multi = pelt.detect_multi(&[&data]).expect("valid");
        assert_eq!(uni.changepoints, multi.changepoints);
    }

    #[test]
    fn test_pelt_multi_two_channels() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a: Vec<f64> = [vec![0.0; 50], vec![5.0; 50]].concat();
        let b: Vec<f64> = [vec![0.0; 50], vec![3.0; 50]].concat();

        let result = pelt.detect_multi(&[&a, &b]).expect("valid");
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint near 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_multi_inconsistent_lengths() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a = vec![0.0; 50];
        let b = vec![0.0; 30];
        assert!(pelt.detect_multi(&[&a[..], &b[..]]).is_none());
    }

    #[test]
    fn test_pelt_multi_empty_signals() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let result = pelt.detect_multi(&[]).expect("valid");
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_multi_three_channels_two_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a: Vec<f64> = [vec![0.0; 40], vec![5.0; 40], vec![0.0; 40]].concat();
        let b: Vec<f64> = [vec![0.0; 40], vec![3.0; 40], vec![0.0; 40]].concat();
        let c: Vec<f64> = [vec![0.0; 40], vec![4.0; 40], vec![0.0; 40]].concat();

        let result = pelt.detect_multi(&[&a, &b, &c]).expect("valid");
        assert_eq!(
            result.changepoints.len(),
            2,
            "expected 2 changepoints, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_multi_short_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        let result = pelt.detect_multi(&[&a, &b]).expect("valid");
        assert!(result.changepoints.is_empty());
    }
}