Skip to main content

u_analytics/detection/
pelt.rs

1//! Pruned Exact Linear Time (PELT) algorithm for offline changepoint detection.
2//!
3//! # Algorithm
4//!
5//! PELT solves the penalized cost minimization problem:
6//!
7//! ```text
8//! minimize  sum_{j=1}^{m+1} C(y_{tau_{j-1}+1 : tau_j}) + m * beta
9//! ```
10//!
11//! where `C` is a segment cost function, `m` is the number of changepoints,
12//! and `beta` is a penalty per changepoint.
13//!
14//! The dynamic programming recurrence is:
15//!
16//! ```text
17//! F(t) = min_{tau in R_t} [ F(tau) + C(y_{tau+1:t}) + beta ]
18//! ```
19//!
20//! with pruning rule: remove `tau` from candidates if `F(tau) + C(y_{tau+1:t}) >= F(t)`,
21//! since such `tau` can never be optimal for any future `s > t`.
22//!
23//! # Complexity
24//!
25//! Expected O(n) under mild conditions on the cost function.
26//! Worst case O(n^2) (no pruning effective).
27//!
28//! # Cost Functions
29//!
30//! - [`CostFunction::L2`] — Detects changes in mean. Cost = sum of squared residuals
31//!   from segment mean. One parameter per segment.
32//! - [`CostFunction::Normal`] — Detects changes in mean and/or variance. Cost =
33//!   n * ln(MLE variance). Two parameters per segment.
34//!
35//! # Penalty Selection
36//!
37//! - **BIC** (default): `beta = p * ln(n)` where `p` is the number of parameters
38//!   per segment (1 for L2, 2 for Normal). Balances model complexity and fit.
39//! - **Custom**: User-specified penalty value.
40//!
41//! # References
42//!
43//! - Killick, R., Fearnhead, P., & Eckley, I.A. (2012). "Optimal Detection of
44//!   Changepoints with a Linear Computational Cost", *Journal of the American
45//!   Statistical Association* 107(500), pp. 1590-1598.
46//! - Schwarz, G. (1978). "Estimating the Dimension of a Model",
47//!   *Annals of Statistics* 6(2), pp. 461-464.
48
/// Segment cost model used to score how homogeneous a candidate segment is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CostFunction {
    /// Mean-shift model.
    ///
    /// The cost of a segment is its residual sum of squares,
    /// `sum((y_i - y_bar)^2)`, with `y_bar` the segment mean. This matches the
    /// negative log-likelihood of a normal model whose variance is known.
    /// Each segment contributes one estimated parameter (the mean).
    L2,
    /// Mean-and/or-variance model.
    ///
    /// The cost of a segment is `n * ln(sigma^2_MLE)`, where
    /// `sigma^2_MLE = sum((y_i - y_bar)^2) / n`. It follows from the negative
    /// log-likelihood of a normal distribution whose mean and variance are both
    /// fitted by maximum likelihood.
    /// Each segment contributes two estimated parameters (mean, variance).
    Normal,
}
66
67impl CostFunction {
68    /// Number of estimated parameters per segment.
69    fn params_per_segment(self) -> usize {
70        match self {
71            CostFunction::L2 => 1,
72            CostFunction::Normal => 2,
73        }
74    }
75}
76
/// How the per-changepoint penalty of the PELT objective is chosen.
#[derive(Clone, Copy, Debug)]
pub enum Penalty {
    /// Bayesian information criterion: `p * ln(n)`, with `p` the number of
    /// parameters estimated per segment and `n` the data length.
    ///
    /// Grows automatically with the amount of data. Reference: Schwarz (1978).
    Bic,
    /// A fixed penalty supplied by the caller (must be positive and finite).
    Custom(f64),
}
87
88/// PELT changepoint detector.
89///
90/// Implements the Pruned Exact Linear Time algorithm for detecting
91/// multiple changepoints in a univariate time series.
92///
93/// # Examples
94///
95/// ```
96/// use u_analytics::detection::{Pelt, CostFunction, Penalty};
97///
98/// // Data with a mean shift at index 50
99/// let mut data: Vec<f64> = vec![0.0; 50];
100/// data.extend(vec![5.0; 50]);
101///
102/// let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).unwrap();
103/// let result = pelt.detect(&data);
104/// assert!(!result.changepoints.is_empty());
105/// // The detected changepoint should be near index 50
106/// assert!((result.changepoints[0] as i64 - 50).unsigned_abs() <= 2);
107/// ```
108///
109/// # References
110///
111/// Killick, R., Fearnhead, P., & Eckley, I.A. (2012). "Optimal Detection of
112/// Changepoints with a Linear Computational Cost", *JASA* 107(500), pp. 1590-1598.
113pub struct Pelt {
114    /// Cost function for segment evaluation.
115    cost: CostFunction,
116    /// Penalty per changepoint.
117    penalty: Penalty,
118    /// Minimum segment length (must be >= 2 for variance estimation).
119    min_segment_len: usize,
120}
121
/// Outcome of a univariate PELT run.
#[derive(Clone, Debug)]
pub struct PeltResult {
    /// 0-based indices of the detected changepoints. Every index is the
    /// position of the first observation belonging to a new segment.
    ///
    /// With `changepoints = [50, 100]`, for instance, the data is split into
    /// `[0..50)`, `[50..100)` and `[100..n)`.
    pub changepoints: Vec<usize>,
}
132
/// Outcome of a multivariate (multi-channel) PELT run.
#[derive(Clone, Debug)]
pub struct MultiPeltResult {
    /// 0-based changepoint indices; one common set that applies to every channel.
    pub changepoints: Vec<usize>,
}
139
140impl Pelt {
141    /// Creates a new PELT detector with the given cost function and penalty.
142    ///
143    /// Uses a default minimum segment length of 2.
144    ///
145    /// # Returns
146    ///
147    /// `None` if a custom penalty is not positive or not finite.
148    ///
149    /// # Reference
150    ///
151    /// Killick et al. (2012), §2.2: penalty must be positive to avoid
152    /// trivial solutions (changepoint at every observation).
153    pub fn new(cost: CostFunction, penalty: Penalty) -> Option<Self> {
154        Self::with_min_segment_len(cost, penalty, 2)
155    }
156
157    /// Creates a PELT detector with a custom minimum segment length.
158    ///
159    /// # Parameters
160    ///
161    /// - `cost`: Cost function for segment evaluation
162    /// - `penalty`: Penalty per changepoint
163    /// - `min_segment_len`: Minimum number of observations per segment.
164    ///   Must be >= 2 (needed for variance estimation in Normal cost).
165    ///
166    /// # Returns
167    ///
168    /// `None` if parameters are invalid.
169    pub fn with_min_segment_len(
170        cost: CostFunction,
171        penalty: Penalty,
172        min_segment_len: usize,
173    ) -> Option<Self> {
174        if let Penalty::Custom(p) = penalty {
175            if !p.is_finite() || p <= 0.0 {
176                return None;
177            }
178        }
179        if min_segment_len < 2 {
180            return None;
181        }
182        Some(Self {
183            cost,
184            penalty,
185            min_segment_len,
186        })
187    }
188
189    /// Detects changepoints in the given data.
190    ///
191    /// Returns a [`PeltResult`] containing the detected changepoint indices.
192    /// If the data is too short (fewer than `2 * min_segment_len` observations),
193    /// no changepoints can be detected and an empty result is returned.
194    ///
195    /// Non-finite values (NaN, Infinity) are **not** supported in the input.
196    /// The data should be pre-cleaned; non-finite values will lead to
197    /// incorrect cost computations.
198    ///
199    /// # Examples
200    ///
201    /// ```
202    /// use u_analytics::detection::{Pelt, CostFunction, Penalty};
203    ///
204    /// // Two changepoints: shift up at 30, shift down at 70
205    /// let mut data = vec![0.0; 30];
206    /// data.extend(vec![3.0; 40]);
207    /// data.extend(vec![0.0; 30]);
208    ///
209    /// let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).unwrap();
210    /// let result = pelt.detect(&data);
211    /// assert_eq!(result.changepoints.len(), 2);
212    /// ```
213    ///
214    /// # Complexity
215    ///
216    /// Expected O(n), worst case O(n^2).
217    pub fn detect(&self, data: &[f64]) -> PeltResult {
218        let n = data.len();
219
220        if n < 2 * self.min_segment_len {
221            return PeltResult {
222                changepoints: Vec::new(),
223            };
224        }
225
226        let penalty_value = self.resolve_penalty(n);
227
228        // Precompute cumulative sums for O(1) segment cost evaluation.
229        // cum_sum[i] = sum(data[0..i])
230        // cum_sum_sq[i] = sum(data[0..i]^2)
231        let mut cum_sum = vec![0.0_f64; n + 1];
232        let mut cum_sum_sq = vec![0.0_f64; n + 1];
233        for i in 0..n {
234            cum_sum[i + 1] = cum_sum[i] + data[i];
235            cum_sum_sq[i + 1] = cum_sum_sq[i] + data[i] * data[i];
236        }
237
238        // F[t] = optimal cost for data[0..t]
239        // F[0] = -penalty (so that the first segment cost + penalty = cost alone)
240        let mut f = vec![0.0_f64; n + 1];
241        f[0] = -penalty_value;
242
243        // last_change[t] = last changepoint index for optimal segmentation of data[0..t]
244        let mut last_change = vec![0_usize; n + 1];
245
246        // Candidate set R: indices tau such that tau could be the last changepoint before t
247        let mut candidates: Vec<usize> = vec![0];
248
249        for t in self.min_segment_len..=n {
250            // Find the optimal last changepoint for position t
251            let mut best_cost = f64::INFINITY;
252            let mut best_tau = 0;
253
254            for &tau in &candidates {
255                let seg_len = t - tau;
256                if seg_len < self.min_segment_len {
257                    continue;
258                }
259
260                let cost = self.segment_cost(&cum_sum, &cum_sum_sq, tau, t);
261                let total = f[tau] + cost + penalty_value;
262
263                if total < best_cost {
264                    best_cost = total;
265                    best_tau = tau;
266                }
267            }
268
269            f[t] = best_cost;
270            last_change[t] = best_tau;
271
272            // PELT pruning: remove candidates that can never be optimal
273            // Killick et al. (2012), Theorem 3.1:
274            // If F(tau) + C(y_{tau+1:t}) >= F(t), then tau can be pruned.
275            candidates.retain(|&tau| {
276                let seg_len = t - tau;
277                if seg_len < self.min_segment_len {
278                    return true; // Keep — not yet evaluable
279                }
280                let cost = self.segment_cost(&cum_sum, &cum_sum_sq, tau, t);
281                f[tau] + cost < f[t] + penalty_value
282            });
283
284            candidates.push(t);
285        }
286
287        // Backtrack to extract changepoints
288        let mut changepoints = Vec::new();
289        let mut t = n;
290        while t > 0 {
291            let tau = last_change[t];
292            if tau > 0 {
293                changepoints.push(tau);
294            }
295            t = tau;
296        }
297
298        changepoints.sort_unstable();
299
300        PeltResult { changepoints }
301    }
302
303    /// Detects changepoints in multivariate (multi-signal) data.
304    ///
305    /// Each inner slice represents one signal channel. All channels must
306    /// have the same length. The cost function is applied independently
307    /// to each channel and summed — a single set of changepoints is
308    /// returned that applies to all channels simultaneously.
309    ///
310    /// The penalty scales with the number of channels: `penalty * n_channels`.
311    ///
312    /// # Parameters
313    ///
314    /// - `signals`: slice of signal channels, each of length `n`
315    ///
316    /// # Returns
317    ///
318    /// `None` if channels have inconsistent lengths.
319    /// Otherwise returns `Some(MultiPeltResult)`.
320    ///
321    /// # Examples
322    ///
323    /// ```
324    /// use u_analytics::detection::{Pelt, CostFunction, Penalty};
325    ///
326    /// let signal_a: Vec<f64> = [vec![0.0; 50], vec![5.0; 50]].concat();
327    /// let signal_b: Vec<f64> = [vec![0.0; 50], vec![3.0; 50]].concat();
328    ///
329    /// let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).unwrap();
330    /// let result = pelt.detect_multi(&[&signal_a, &signal_b]).unwrap();
331    /// assert!(!result.changepoints.is_empty());
332    /// ```
333    ///
334    /// # Reference
335    ///
336    /// Killick, R. & Eckley, I.A. (2014). "changepoint: An R Package for
337    /// Changepoint Analysis", *Journal of Statistical Software* 58(3).
338    ///
339    /// # Complexity
340    ///
341    /// Expected O(n * k), worst case O(n^2 * k) where k = number of channels.
342    pub fn detect_multi(&self, signals: &[&[f64]]) -> Option<MultiPeltResult> {
343        if signals.is_empty() {
344            return Some(MultiPeltResult {
345                changepoints: Vec::new(),
346            });
347        }
348
349        let n = signals[0].len();
350        if signals.iter().any(|s| s.len() != n) {
351            return None;
352        }
353
354        if n < 2 * self.min_segment_len {
355            return Some(MultiPeltResult {
356                changepoints: Vec::new(),
357            });
358        }
359
360        let n_channels = signals.len();
361        let penalty_value = self.resolve_penalty(n) * n_channels as f64;
362
363        // Precompute cumulative sums per channel for O(1) segment cost.
364        let mut cum_sums: Vec<Vec<f64>> = Vec::with_capacity(n_channels);
365        let mut cum_sum_sqs: Vec<Vec<f64>> = Vec::with_capacity(n_channels);
366
367        for signal in signals {
368            let mut cs = vec![0.0_f64; n + 1];
369            let mut css = vec![0.0_f64; n + 1];
370            for i in 0..n {
371                cs[i + 1] = cs[i] + signal[i];
372                css[i + 1] = css[i] + signal[i] * signal[i];
373            }
374            cum_sums.push(cs);
375            cum_sum_sqs.push(css);
376        }
377
378        let mut f = vec![0.0_f64; n + 1];
379        f[0] = -penalty_value;
380        let mut last_change = vec![0_usize; n + 1];
381        let mut candidates: Vec<usize> = vec![0];
382
383        for t in self.min_segment_len..=n {
384            let mut best_cost = f64::INFINITY;
385            let mut best_tau = 0;
386
387            for &tau in &candidates {
388                let seg_len = t - tau;
389                if seg_len < self.min_segment_len {
390                    continue;
391                }
392
393                let cost: f64 = (0..n_channels)
394                    .map(|ch| self.segment_cost(&cum_sums[ch], &cum_sum_sqs[ch], tau, t))
395                    .sum();
396                let total = f[tau] + cost + penalty_value;
397
398                if total < best_cost {
399                    best_cost = total;
400                    best_tau = tau;
401                }
402            }
403
404            f[t] = best_cost;
405            last_change[t] = best_tau;
406
407            candidates.retain(|&tau| {
408                let seg_len = t - tau;
409                if seg_len < self.min_segment_len {
410                    return true;
411                }
412                let cost: f64 = (0..n_channels)
413                    .map(|ch| self.segment_cost(&cum_sums[ch], &cum_sum_sqs[ch], tau, t))
414                    .sum();
415                f[tau] + cost < f[t] + penalty_value
416            });
417
418            candidates.push(t);
419        }
420
421        let mut changepoints = Vec::new();
422        let mut t = n;
423        while t > 0 {
424            let tau = last_change[t];
425            if tau > 0 {
426                changepoints.push(tau);
427            }
428            t = tau;
429        }
430        changepoints.sort_unstable();
431
432        Some(MultiPeltResult { changepoints })
433    }
434
435    /// Resolves the penalty value for a dataset of length `n`.
436    fn resolve_penalty(&self, n: usize) -> f64 {
437        match self.penalty {
438            Penalty::Bic => {
439                let p = self.cost.params_per_segment() as f64;
440                p * (n as f64).ln()
441            }
442            Penalty::Custom(val) => val,
443        }
444    }
445
446    /// Computes the cost of the segment `data[start..end]` using cumulative sums.
447    ///
448    /// # Panics
449    ///
450    /// Panics if `end <= start`.
451    fn segment_cost(&self, cum_sum: &[f64], cum_sum_sq: &[f64], start: usize, end: usize) -> f64 {
452        let seg_len = (end - start) as f64;
453        let sum = cum_sum[end] - cum_sum[start];
454        let sum_sq = cum_sum_sq[end] - cum_sum_sq[start];
455        let mean = sum / seg_len;
456
457        match self.cost {
458            CostFunction::L2 => {
459                // sum((y_i - mean)^2) = sum(y_i^2) - n * mean^2
460                sum_sq - seg_len * mean * mean
461            }
462            CostFunction::Normal => {
463                // n * ln(variance) where variance = sum((y_i - mean)^2) / n
464                let variance = (sum_sq - seg_len * mean * mean) / seg_len;
465                if variance <= 0.0 {
466                    // Degenerate segment (constant values): assign a large negative
467                    // log-likelihood to make it favorable (perfect fit).
468                    seg_len * (f64::MIN_POSITIVE).ln()
469                } else {
470                    seg_len * variance.ln()
471                }
472            }
473        }
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    // --- Constructor validation ---

    #[test]
    fn test_pelt_valid_construction() {
        assert!(Pelt::new(CostFunction::L2, Penalty::Bic).is_some());
        assert!(Pelt::new(CostFunction::Normal, Penalty::Bic).is_some());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(10.0)).is_some());
    }

    #[test]
    fn test_pelt_invalid_custom_penalty() {
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(0.0)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(-1.0)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(f64::NAN)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(f64::INFINITY)).is_none());
    }

    #[test]
    fn test_pelt_invalid_min_segment_len() {
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 0).is_none());
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 1).is_none());
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 2).is_some());
    }

    // --- Empty and short data ---

    #[test]
    fn test_pelt_empty_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let result = pelt.detect(&[]);
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_too_short_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        // min_segment_len=2, need at least 4 points for a changepoint
        let result = pelt.detect(&[1.0, 2.0, 3.0]);
        assert!(result.changepoints.is_empty());
    }

    // --- No changepoint scenarios ---

    #[test]
    fn test_pelt_constant_data_no_changepoint() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let data = vec![5.0; 100];
        let result = pelt.detect(&data);
        assert!(
            result.changepoints.is_empty(),
            "constant data should have no changepoints, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_normal_cost_constant_data() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");
        let data = vec![5.0; 100];
        let result = pelt.detect(&data);
        assert!(
            result.changepoints.is_empty(),
            "constant data should have no changepoints with Normal cost, got {:?}",
            result.changepoints
        );
    }

    // --- Single changepoint detection ---

    #[test]
    fn test_pelt_single_mean_shift_l2() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_single_mean_shift_normal() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint with Normal cost, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    // --- Multiple changepoints ---

    #[test]
    fn test_pelt_two_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 40];
        data.extend(vec![5.0; 40]);
        data.extend(vec![0.0; 40]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            2,
            "expected 2 changepoints, got {:?}",
            result.changepoints
        );

        // Changepoints should be near 40 and 80
        assert!(
            (result.changepoints[0] as i64 - 40).unsigned_abs() <= 2,
            "first changepoint near 40, got {}",
            result.changepoints[0]
        );
        assert!(
            (result.changepoints[1] as i64 - 80).unsigned_abs() <= 2,
            "second changepoint near 80, got {}",
            result.changepoints[1]
        );
    }

    #[test]
    fn test_pelt_three_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 30];
        data.extend(vec![4.0; 30]);
        data.extend(vec![-2.0; 30]);
        data.extend(vec![3.0; 30]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            3,
            "expected 3 changepoints, got {:?}",
            result.changepoints
        );

        // Check ordering
        for i in 1..result.changepoints.len() {
            assert!(
                result.changepoints[i] > result.changepoints[i - 1],
                "changepoints should be strictly increasing"
            );
        }
    }

    // --- Variance change detection ---

    #[test]
    fn test_pelt_variance_change_normal_cost() {
        // Normal cost should detect a variance change even when mean is constant.
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        // Low variance segment then high variance segment
        let mut data = Vec::with_capacity(200);
        // Segment 1: mean=0, low spread (deterministic zigzag)
        for i in 0..100 {
            data.push(if i % 2 == 0 { 0.1 } else { -0.1 });
        }
        // Segment 2: mean=0, high spread
        for i in 0..100 {
            data.push(if i % 2 == 0 { 5.0 } else { -5.0 });
        }

        let result = pelt.detect(&data);
        assert!(
            !result.changepoints.is_empty(),
            "Normal cost should detect variance change"
        );
        // The changepoint should be near index 100
        let cp = result.changepoints[0];
        assert!(
            (cp as i64 - 100).unsigned_abs() <= 5,
            "variance changepoint should be near 100, got {}",
            cp
        );
    }

    // --- Penalty sensitivity ---

    #[test]
    fn test_pelt_higher_penalty_fewer_changepoints() {
        let mut data = vec![0.0; 30];
        data.extend(vec![2.0; 30]);
        data.extend(vec![0.0; 30]);

        let pelt_low = Pelt::new(CostFunction::L2, Penalty::Custom(1.0)).expect("valid");
        let pelt_high = Pelt::new(CostFunction::L2, Penalty::Custom(100.0)).expect("valid");

        let result_low = pelt_low.detect(&data);
        let result_high = pelt_high.detect(&data);

        assert!(
            result_low.changepoints.len() >= result_high.changepoints.len(),
            "higher penalty should produce fewer or equal changepoints: low={}, high={}",
            result_low.changepoints.len(),
            result_high.changepoints.len()
        );
    }

    // --- Custom minimum segment length ---

    #[test]
    fn test_pelt_custom_min_segment_len() {
        let mut data = vec![0.0; 50];
        data.extend(vec![10.0; 50]);

        let pelt = Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 10).expect("valid");
        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "should detect changepoint with min_segment_len=10"
        );

        // All segments should respect minimum length
        let mut boundaries = vec![0];
        boundaries.extend_from_slice(&result.changepoints);
        boundaries.push(data.len());
        for i in 1..boundaries.len() {
            let seg_len = boundaries[i] - boundaries[i - 1];
            assert!(
                seg_len >= 10,
                "segment length {} is less than min_segment_len=10",
                seg_len
            );
        }
    }

    // --- Exact numeric verification ---

    /// Verifies PELT on a simple 4-point example with known optimal solution.
    ///
    /// Data: [0, 0, 10, 10]. L2 has one parameter per segment, so the BIC
    /// penalty is 1 * ln(4) ≈ 1.39.
    ///
    /// Without changepoint: cost = sum((y_i - 5)^2) = 25+25+25+25 = 100
    /// With changepoint at 2: cost = 0 + 0 + ln(4) ≈ 1.39 (L2 cost of each segment = 0)
    ///
    /// PELT should find changepoint at index 2.
    #[test]
    fn test_pelt_exact_small_example() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let data = [0.0, 0.0, 10.0, 10.0];
        let result = pelt.detect(&data);

        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint in [0,0,10,10], got {:?}",
            result.changepoints
        );
        assert_eq!(
            result.changepoints[0], 2,
            "changepoint should be at index 2"
        );
    }

    // --- Changepoints are sorted ---

    #[test]
    fn test_pelt_changepoints_sorted() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 25];
        data.extend(vec![5.0; 25]);
        data.extend(vec![-3.0; 25]);
        data.extend(vec![7.0; 25]);

        let result = pelt.detect(&data);
        for i in 1..result.changepoints.len() {
            assert!(
                result.changepoints[i] > result.changepoints[i - 1],
                "changepoints must be strictly increasing: {:?}",
                result.changepoints
            );
        }
    }

    // --- BIC penalty scales with n ---

    /// The same local feature is detected in a short series but suppressed in
    /// a long one, because the BIC penalty `ln(n)` grows with the data length.
    ///
    /// The feature is a burst of four 1.5s centred in a series of zeros.
    /// Fitting everything with one segment costs `9 * (1 - 4/n)`; isolating the
    /// burst costs 0 plus two penalties, `2 * ln(n)`.
    ///
    /// - n = 20:  7.2  > 2*ln(20)  ≈ 5.99  -> burst isolated at [8, 12)
    /// - n = 400: 8.91 < 2*ln(400) ≈ 11.98 -> no changepoints
    #[test]
    fn test_pelt_bic_penalty_scales() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let burst = |n: usize| -> Vec<f64> {
            let mut data = vec![0.0; n];
            let mid = n / 2;
            for y in &mut data[mid - 2..mid + 2] {
                *y = 1.5;
            }
            data
        };

        let short = pelt.detect(&burst(20));
        assert_eq!(
            short.changepoints,
            vec![8, 12],
            "burst should be detected when the BIC penalty is small"
        );

        let long = pelt.detect(&burst(400));
        assert!(
            long.changepoints.is_empty(),
            "burst should be suppressed when the BIC penalty is large, got {:?}",
            long.changepoints
        );
    }

    // --- Property: segments cover entire data ---

    #[test]
    fn test_pelt_segments_cover_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 30];
        data.extend(vec![5.0; 30]);
        data.extend(vec![0.0; 30]);

        let result = pelt.detect(&data);

        // Verify segments cover [0, n) without gaps
        let mut boundaries = vec![0];
        boundaries.extend_from_slice(&result.changepoints);
        boundaries.push(data.len());

        for i in 1..boundaries.len() {
            assert!(
                boundaries[i] > boundaries[i - 1],
                "segments must not have zero length"
            );
        }
        assert_eq!(
            *boundaries.last().expect("non-empty boundaries"),
            data.len(),
            "segments must cover entire data"
        );
    }

    // --- Downward shift ---

    #[test]
    fn test_pelt_downward_shift() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![10.0; 50];
        data.extend(vec![2.0; 50]); // Downward shift

        let result = pelt.detect(&data);
        assert_eq!(result.changepoints.len(), 1, "should detect downward shift");
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    // --- Cost function enum properties ---

    #[test]
    fn test_cost_function_params() {
        assert_eq!(CostFunction::L2.params_per_segment(), 1);
        assert_eq!(CostFunction::Normal.params_per_segment(), 2);
    }

    // --- Multi-signal tests ---

    #[test]
    fn test_pelt_multi_single_channel_matches_univariate() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Custom(5.0)).expect("valid");
        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let uni = pelt.detect(&data);
        let multi = pelt.detect_multi(&[&data]).expect("valid");
        assert_eq!(uni.changepoints, multi.changepoints);
    }

    #[test]
    fn test_pelt_multi_two_channels() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a: Vec<f64> = [vec![0.0; 50], vec![5.0; 50]].concat();
        let b: Vec<f64> = [vec![0.0; 50], vec![3.0; 50]].concat();

        let result = pelt.detect_multi(&[&a, &b]).expect("valid");
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint near 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_multi_inconsistent_lengths() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a = vec![0.0; 50];
        let b = vec![0.0; 30];
        assert!(pelt.detect_multi(&[&a[..], &b[..]]).is_none());
    }

    #[test]
    fn test_pelt_multi_empty_signals() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let result = pelt.detect_multi(&[]).expect("valid");
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_multi_three_channels_two_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a: Vec<f64> = [vec![0.0; 40], vec![5.0; 40], vec![0.0; 40]].concat();
        let b: Vec<f64> = [vec![0.0; 40], vec![3.0; 40], vec![0.0; 40]].concat();
        let c: Vec<f64> = [vec![0.0; 40], vec![4.0; 40], vec![0.0; 40]].concat();

        let result = pelt.detect_multi(&[&a, &b, &c]).expect("valid");
        assert_eq!(
            result.changepoints.len(),
            2,
            "expected 2 changepoints, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_multi_short_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        let result = pelt.detect_multi(&[&a, &b]).expect("valid");
        assert!(result.changepoints.is_empty());
    }
}