/// Segment cost used to score a candidate segmentation.
///
/// Lower cost means a better fit; PELT minimises the summed segment costs plus
/// a penalty for each changepoint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CostFunction {
    /// Sum of squared deviations from the segment mean (detects mean shifts).
    L2,
    /// Gaussian cost `len * ln(variance)` with a per-segment mean and variance
    /// (detects changes in mean and/or variance).
    Normal,
}
66
67impl CostFunction {
68 fn params_per_segment(self) -> usize {
70 match self {
71 CostFunction::L2 => 1,
72 CostFunction::Normal => 2,
73 }
74 }
75}
76
/// Penalty charged for every changepoint, trading goodness of fit against the
/// number of segments.
#[derive(Clone, Copy, Debug)]
pub enum Penalty {
    /// Bayesian information criterion: the cost function's parameter count per
    /// segment multiplied by `ln(n)`, where `n` is the series length.
    Bic,
    /// A fixed penalty value; the detector constructors reject values that are
    /// not finite or not strictly positive.
    Custom(f64),
}
87
88pub struct Pelt {
114 cost: CostFunction,
116 penalty: Penalty,
118 min_segment_len: usize,
120}
121
/// Output of [`Pelt::detect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PeltResult {
    /// Strictly increasing indices at which a new segment starts.
    pub changepoints: Vec<usize>,
}
132
/// Output of [`Pelt::detect_multi`]: changepoints shared by all channels.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiPeltResult {
    /// Strictly increasing indices at which a new segment starts.
    pub changepoints: Vec<usize>,
}
139
140impl Pelt {
141 pub fn new(cost: CostFunction, penalty: Penalty) -> Option<Self> {
154 Self::with_min_segment_len(cost, penalty, 2)
155 }
156
157 pub fn with_min_segment_len(
170 cost: CostFunction,
171 penalty: Penalty,
172 min_segment_len: usize,
173 ) -> Option<Self> {
174 if let Penalty::Custom(p) = penalty {
175 if !p.is_finite() || p <= 0.0 {
176 return None;
177 }
178 }
179 if min_segment_len < 2 {
180 return None;
181 }
182 Some(Self {
183 cost,
184 penalty,
185 min_segment_len,
186 })
187 }
188
189 pub fn detect(&self, data: &[f64]) -> PeltResult {
218 let n = data.len();
219
220 if n < 2 * self.min_segment_len {
221 return PeltResult {
222 changepoints: Vec::new(),
223 };
224 }
225
226 let penalty_value = self.resolve_penalty(n);
227
228 let mut cum_sum = vec![0.0_f64; n + 1];
232 let mut cum_sum_sq = vec![0.0_f64; n + 1];
233 for i in 0..n {
234 cum_sum[i + 1] = cum_sum[i] + data[i];
235 cum_sum_sq[i + 1] = cum_sum_sq[i] + data[i] * data[i];
236 }
237
238 let mut f = vec![0.0_f64; n + 1];
241 f[0] = -penalty_value;
242
243 let mut last_change = vec![0_usize; n + 1];
245
246 let mut candidates: Vec<usize> = vec![0];
248
249 for t in self.min_segment_len..=n {
250 let mut best_cost = f64::INFINITY;
252 let mut best_tau = 0;
253
254 for &tau in &candidates {
255 let seg_len = t - tau;
256 if seg_len < self.min_segment_len {
257 continue;
258 }
259
260 let cost = self.segment_cost(&cum_sum, &cum_sum_sq, tau, t);
261 let total = f[tau] + cost + penalty_value;
262
263 if total < best_cost {
264 best_cost = total;
265 best_tau = tau;
266 }
267 }
268
269 f[t] = best_cost;
270 last_change[t] = best_tau;
271
272 candidates.retain(|&tau| {
276 let seg_len = t - tau;
277 if seg_len < self.min_segment_len {
278 return true; }
280 let cost = self.segment_cost(&cum_sum, &cum_sum_sq, tau, t);
281 f[tau] + cost < f[t] + penalty_value
282 });
283
284 candidates.push(t);
285 }
286
287 let mut changepoints = Vec::new();
289 let mut t = n;
290 while t > 0 {
291 let tau = last_change[t];
292 if tau > 0 {
293 changepoints.push(tau);
294 }
295 t = tau;
296 }
297
298 changepoints.sort_unstable();
299
300 PeltResult { changepoints }
301 }
302
303 pub fn detect_multi(&self, signals: &[&[f64]]) -> Option<MultiPeltResult> {
343 if signals.is_empty() {
344 return Some(MultiPeltResult {
345 changepoints: Vec::new(),
346 });
347 }
348
349 let n = signals[0].len();
350 if signals.iter().any(|s| s.len() != n) {
351 return None;
352 }
353
354 if n < 2 * self.min_segment_len {
355 return Some(MultiPeltResult {
356 changepoints: Vec::new(),
357 });
358 }
359
360 let n_channels = signals.len();
361 let penalty_value = self.resolve_penalty(n) * n_channels as f64;
362
363 let mut cum_sums: Vec<Vec<f64>> = Vec::with_capacity(n_channels);
365 let mut cum_sum_sqs: Vec<Vec<f64>> = Vec::with_capacity(n_channels);
366
367 for signal in signals {
368 let mut cs = vec![0.0_f64; n + 1];
369 let mut css = vec![0.0_f64; n + 1];
370 for i in 0..n {
371 cs[i + 1] = cs[i] + signal[i];
372 css[i + 1] = css[i] + signal[i] * signal[i];
373 }
374 cum_sums.push(cs);
375 cum_sum_sqs.push(css);
376 }
377
378 let mut f = vec![0.0_f64; n + 1];
379 f[0] = -penalty_value;
380 let mut last_change = vec![0_usize; n + 1];
381 let mut candidates: Vec<usize> = vec![0];
382
383 for t in self.min_segment_len..=n {
384 let mut best_cost = f64::INFINITY;
385 let mut best_tau = 0;
386
387 for &tau in &candidates {
388 let seg_len = t - tau;
389 if seg_len < self.min_segment_len {
390 continue;
391 }
392
393 let cost: f64 = (0..n_channels)
394 .map(|ch| self.segment_cost(&cum_sums[ch], &cum_sum_sqs[ch], tau, t))
395 .sum();
396 let total = f[tau] + cost + penalty_value;
397
398 if total < best_cost {
399 best_cost = total;
400 best_tau = tau;
401 }
402 }
403
404 f[t] = best_cost;
405 last_change[t] = best_tau;
406
407 candidates.retain(|&tau| {
408 let seg_len = t - tau;
409 if seg_len < self.min_segment_len {
410 return true;
411 }
412 let cost: f64 = (0..n_channels)
413 .map(|ch| self.segment_cost(&cum_sums[ch], &cum_sum_sqs[ch], tau, t))
414 .sum();
415 f[tau] + cost < f[t] + penalty_value
416 });
417
418 candidates.push(t);
419 }
420
421 let mut changepoints = Vec::new();
422 let mut t = n;
423 while t > 0 {
424 let tau = last_change[t];
425 if tau > 0 {
426 changepoints.push(tau);
427 }
428 t = tau;
429 }
430 changepoints.sort_unstable();
431
432 Some(MultiPeltResult { changepoints })
433 }
434
435 fn resolve_penalty(&self, n: usize) -> f64 {
437 match self.penalty {
438 Penalty::Bic => {
439 let p = self.cost.params_per_segment() as f64;
440 p * (n as f64).ln()
441 }
442 Penalty::Custom(val) => val,
443 }
444 }
445
446 fn segment_cost(
452 &self,
453 cum_sum: &[f64],
454 cum_sum_sq: &[f64],
455 start: usize,
456 end: usize,
457 ) -> f64 {
458 let seg_len = (end - start) as f64;
459 let sum = cum_sum[end] - cum_sum[start];
460 let sum_sq = cum_sum_sq[end] - cum_sum_sq[start];
461 let mean = sum / seg_len;
462
463 match self.cost {
464 CostFunction::L2 => {
465 sum_sq - seg_len * mean * mean
467 }
468 CostFunction::Normal => {
469 let variance = (sum_sq - seg_len * mean * mean) / seg_len;
471 if variance <= 0.0 {
472 seg_len * (f64::MIN_POSITIVE).ln()
475 } else {
476 seg_len * variance.ln()
477 }
478 }
479 }
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pelt_valid_construction() {
        assert!(Pelt::new(CostFunction::L2, Penalty::Bic).is_some());
        assert!(Pelt::new(CostFunction::Normal, Penalty::Bic).is_some());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(10.0)).is_some());
    }

    #[test]
    fn test_pelt_invalid_custom_penalty() {
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(0.0)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(-1.0)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(f64::NAN)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(f64::INFINITY)).is_none());
    }

    #[test]
    fn test_pelt_invalid_min_segment_len() {
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 0).is_none());
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 1).is_none());
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 2).is_some());
    }

    #[test]
    fn test_pelt_empty_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let result = pelt.detect(&[]);
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_too_short_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        // 3 samples cannot hold two segments of the default minimum length 2.
        let result = pelt.detect(&[1.0, 2.0, 3.0]);
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_constant_data_no_changepoint() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let data = vec![5.0; 100];
        let result = pelt.detect(&data);
        assert!(
            result.changepoints.is_empty(),
            "constant data should have no changepoints, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_normal_cost_constant_data() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");
        let data = vec![5.0; 100];
        let result = pelt.detect(&data);
        assert!(
            result.changepoints.is_empty(),
            "constant data should have no changepoints with Normal cost, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_single_mean_shift_l2() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_single_mean_shift_normal() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint with Normal cost, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_two_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 40];
        data.extend(vec![5.0; 40]);
        data.extend(vec![0.0; 40]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            2,
            "expected 2 changepoints, got {:?}",
            result.changepoints
        );

        assert!(
            (result.changepoints[0] as i64 - 40).unsigned_abs() <= 2,
            "first changepoint near 40, got {}",
            result.changepoints[0]
        );
        assert!(
            (result.changepoints[1] as i64 - 80).unsigned_abs() <= 2,
            "second changepoint near 80, got {}",
            result.changepoints[1]
        );
    }

    #[test]
    fn test_pelt_three_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 30];
        data.extend(vec![4.0; 30]);
        data.extend(vec![-2.0; 30]);
        data.extend(vec![3.0; 30]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            3,
            "expected 3 changepoints, got {:?}",
            result.changepoints
        );

        for i in 1..result.changepoints.len() {
            assert!(
                result.changepoints[i] > result.changepoints[i - 1],
                "changepoints should be strictly increasing"
            );
        }
    }

    #[test]
    fn test_pelt_variance_change_normal_cost() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        // Zero-mean alternating signal whose amplitude (and so variance) jumps at 100.
        let mut data = Vec::with_capacity(200);
        for i in 0..100 {
            data.push(if i % 2 == 0 { 0.1 } else { -0.1 });
        }
        for i in 0..100 {
            data.push(if i % 2 == 0 { 5.0 } else { -5.0 });
        }

        let result = pelt.detect(&data);
        assert!(
            !result.changepoints.is_empty(),
            "Normal cost should detect variance change"
        );
        let cp = result.changepoints[0];
        assert!(
            (cp as i64 - 100).unsigned_abs() <= 5,
            "variance changepoint should be near 100, got {}",
            cp
        );
    }

    #[test]
    fn test_pelt_higher_penalty_fewer_changepoints() {
        let mut data = vec![0.0; 30];
        data.extend(vec![2.0; 30]);
        data.extend(vec![0.0; 30]);

        let pelt_low = Pelt::new(CostFunction::L2, Penalty::Custom(1.0)).expect("valid");
        let pelt_high = Pelt::new(CostFunction::L2, Penalty::Custom(100.0)).expect("valid");

        let result_low = pelt_low.detect(&data);
        let result_high = pelt_high.detect(&data);

        assert!(
            result_low.changepoints.len() >= result_high.changepoints.len(),
            "higher penalty should produce fewer or equal changepoints: low={}, high={}",
            result_low.changepoints.len(),
            result_high.changepoints.len()
        );
    }

    #[test]
    fn test_pelt_custom_min_segment_len() {
        let mut data = vec![0.0; 50];
        data.extend(vec![10.0; 50]);

        let pelt = Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 10).expect("valid");
        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "should detect changepoint with min_segment_len=10"
        );

        // Every segment, including the first and last, must respect the minimum length.
        let mut boundaries = vec![0];
        boundaries.extend_from_slice(&result.changepoints);
        boundaries.push(data.len());
        for i in 1..boundaries.len() {
            let seg_len = boundaries[i] - boundaries[i - 1];
            assert!(
                seg_len >= 10,
                "segment length {} is less than min_segment_len=10",
                seg_len
            );
        }
    }

    #[test]
    fn test_pelt_exact_small_example() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let data = [0.0, 0.0, 10.0, 10.0];
        let result = pelt.detect(&data);

        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint in [0,0,10,10], got {:?}",
            result.changepoints
        );
        assert_eq!(
            result.changepoints[0], 2,
            "changepoint should be at index 2"
        );
    }

    #[test]
    fn test_pelt_changepoints_sorted() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 25];
        data.extend(vec![5.0; 25]);
        data.extend(vec![-3.0; 25]);
        data.extend(vec![7.0; 25]);

        let result = pelt.detect(&data);
        for i in 1..result.changepoints.len() {
            assert!(
                result.changepoints[i] > result.changepoints[i - 1],
                "changepoints must be strictly increasing: {:?}",
                result.changepoints
            );
        }
    }

    #[test]
    fn test_pelt_bic_penalty_scales() {
        let mut data = vec![0.0; 50];
        data.extend(vec![0.5; 50]);
        let n = data.len() as f64;

        // For L2 (one parameter per segment) BIC resolves to exactly ln(n).
        let bic = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let explicit = Pelt::new(CostFunction::L2, Penalty::Custom(n.ln())).expect("valid");
        let bic_result = bic.detect(&data);
        assert_eq!(
            bic_result.changepoints,
            explicit.detect(&data).changepoints,
            "BIC should behave like a custom penalty of ln(n)"
        );

        // Splitting at 50 removes a cost of 100 * 0.25^2 = 6.25, which exceeds
        // ln(100) ~= 4.6, so the small shift is detected...
        assert_eq!(
            bic_result.changepoints,
            vec![50],
            "BIC should detect the 0.5 shift at index 50"
        );

        // ...but not a penalty of 10, which suppresses it.
        let strict = Pelt::new(CostFunction::L2, Penalty::Custom(10.0)).expect("valid");
        assert!(
            strict.detect(&data).changepoints.is_empty(),
            "a penalty larger than the cost reduction should suppress the changepoint"
        );
    }

    #[test]
    fn test_pelt_segments_cover_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 30];
        data.extend(vec![5.0; 30]);
        data.extend(vec![0.0; 30]);

        let result = pelt.detect(&data);

        let mut boundaries = vec![0];
        boundaries.extend_from_slice(&result.changepoints);
        boundaries.push(data.len());

        for i in 1..boundaries.len() {
            assert!(
                boundaries[i] > boundaries[i - 1],
                "segments must not have zero length"
            );
        }
        assert_eq!(
            *boundaries.last().expect("non-empty boundaries"),
            data.len(),
            "segments must cover entire data"
        );
    }

    #[test]
    fn test_pelt_downward_shift() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![10.0; 50];
        data.extend(vec![2.0; 50]);
        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "should detect downward shift"
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_cost_function_params() {
        assert_eq!(CostFunction::L2.params_per_segment(), 1);
        assert_eq!(CostFunction::Normal.params_per_segment(), 2);
    }

    #[test]
    fn test_pelt_multi_single_channel_matches_univariate() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Custom(5.0)).expect("valid");
        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let uni = pelt.detect(&data);
        let multi = pelt.detect_multi(&[&data]).expect("valid");
        assert_eq!(uni.changepoints, multi.changepoints);
    }

    #[test]
    fn test_pelt_multi_two_channels() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a: Vec<f64> = [vec![0.0; 50], vec![5.0; 50]].concat();
        let b: Vec<f64> = [vec![0.0; 50], vec![3.0; 50]].concat();

        let result = pelt.detect_multi(&[&a, &b]).expect("valid");
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint near 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_multi_inconsistent_lengths() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a = vec![0.0; 50];
        let b = vec![0.0; 30];
        assert!(pelt.detect_multi(&[&a[..], &b[..]]).is_none());
    }

    #[test]
    fn test_pelt_multi_empty_signals() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let result = pelt.detect_multi(&[]).expect("valid");
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_multi_three_channels_two_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a: Vec<f64> = [vec![0.0; 40], vec![5.0; 40], vec![0.0; 40]].concat();
        let b: Vec<f64> = [vec![0.0; 40], vec![3.0; 40], vec![0.0; 40]].concat();
        let c: Vec<f64> = [vec![0.0; 40], vec![4.0; 40], vec![0.0; 40]].concat();

        let result = pelt.detect_multi(&[&a, &b, &c]).expect("valid");
        assert_eq!(
            result.changepoints.len(),
            2,
            "expected 2 changepoints, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_multi_short_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        let result = pelt.detect_multi(&[&a, &b]).expect("valid");
        assert!(result.changepoints.is_empty());
    }
}