// u_analytics/detection/pelt.rs

//! Pruned Exact Linear Time (PELT) algorithm for offline changepoint detection.
//!
//! # Algorithm
//!
//! PELT solves the penalized cost minimization problem:
//!
//! ```text
//! minimize  sum_{j=1}^{m+1} C(y_{tau_{j-1}+1 : tau_j}) + m * beta
//! ```
//!
//! where `C` is a segment cost function, `m` is the number of changepoints,
//! and `beta` is a penalty per changepoint.
//!
//! The dynamic programming recurrence is:
//!
//! ```text
//! F(t) = min_{tau in R_t} [ F(tau) + C(y_{tau+1:t}) + beta ]
//! ```
//!
//! with pruning rule: remove `tau` from candidates if `F(tau) + C(y_{tau+1:t}) >= F(t)`,
//! since such `tau` can never be optimal for any future `s > t` (this holds for
//! costs where splitting a segment never increases its total cost, as is the case
//! for both cost functions below). This implementation applies the more
//! conservative test `F(tau) + C(y_{tau+1:t}) >= F(t) + beta`, which only discards
//! candidates that the rule above would also discard.
//!
//! # Complexity
//!
//! Expected O(n) under mild conditions on the cost function.
//! Worst case O(n^2) (no pruning effective).
//!
//! # Cost Functions
//!
//! - [`CostFunction::L2`] — Detects changes in mean. Cost = sum of squared residuals
//!   from segment mean. One parameter per segment.
//! - [`CostFunction::Normal`] — Detects changes in mean and/or variance. Cost =
//!   n * ln(MLE variance). Two parameters per segment.
//!
//! # Penalty Selection
//!
//! - **BIC**: `beta = p * ln(n)` where `p` is the number of parameters
//!   per segment (1 for L2, 2 for Normal). Balances model complexity and fit.
//! - **Custom**: User-specified penalty value.
//!
//! # References
//!
//! - Killick, R., Fearnhead, P., & Eckley, I.A. (2012). "Optimal Detection of
//!   Changepoints with a Linear Computational Cost", *Journal of the American
//!   Statistical Association* 107(500), pp. 1590-1598.
//! - Schwarz, G. (1978). "Estimating the Dimension of a Model",
//!   *Annals of Statistics* 6(2), pp. 461-464.

/// Cost function for evaluating segment homogeneity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CostFunction {
    /// Detects changes in mean only.
    ///
    /// Segment cost: `sum((y_i - y_bar)^2)` where `y_bar` is the segment mean.
    /// This equals twice the negative log-likelihood of a normal model with
    /// unit (known) variance, up to an additive constant.
    /// One parameter per segment (mean).
    L2,
    /// Detects changes in mean and/or variance.
    ///
    /// Segment cost: `n * ln(sigma^2_MLE)` where `sigma^2_MLE = sum((y_i - y_bar)^2) / n`.
    /// This equals twice the negative log-likelihood of a normal distribution
    /// with both mean and variance estimated by MLE, up to terms linear in `n`
    /// (which sum to the same constant for every segmentation).
    /// Two parameters per segment (mean, variance).
    ///
    /// Segments with zero sample variance (e.g. runs of identical values) get
    /// the finite floor `n * ln(f64::MIN_POSITIVE)`. This strongly favours such
    /// segments, so data with many repeated values can be split into short
    /// constant segments.
    Normal,
}

67impl CostFunction {
68    /// Number of estimated parameters per segment.
69    fn params_per_segment(self) -> usize {
70        match self {
71            CostFunction::L2 => 1,
72            CostFunction::Normal => 2,
73        }
74    }
75}
76
/// Penalty selection for the PELT algorithm.
///
/// For multi-channel detection the resolved penalty is multiplied by the
/// number of channels.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Penalty {
    /// BIC penalty: `p * ln(n)` where `p` = parameters per segment and `n` is
    /// the number of observations.
    ///
    /// Reference: Schwarz (1978). Automatically scales with data length.
    Bic,
    /// User-specified penalty value (must be positive and finite; the `Pelt`
    /// constructors return `None` otherwise).
    Custom(f64),
}

88/// PELT changepoint detector.
89///
90/// Implements the Pruned Exact Linear Time algorithm for detecting
91/// multiple changepoints in a univariate time series.
92///
93/// # Examples
94///
95/// ```
96/// use u_analytics::detection::{Pelt, CostFunction, Penalty};
97///
98/// // Data with a mean shift at index 50
99/// let mut data: Vec<f64> = vec![0.0; 50];
100/// data.extend(vec![5.0; 50]);
101///
102/// let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).unwrap();
103/// let result = pelt.detect(&data);
104/// assert!(!result.changepoints.is_empty());
105/// // The detected changepoint should be near index 50
106/// assert!((result.changepoints[0] as i64 - 50).unsigned_abs() <= 2);
107/// ```
108///
109/// # References
110///
111/// Killick, R., Fearnhead, P., & Eckley, I.A. (2012). "Optimal Detection of
112/// Changepoints with a Linear Computational Cost", *JASA* 107(500), pp. 1590-1598.
113pub struct Pelt {
114    /// Cost function for segment evaluation.
115    cost: CostFunction,
116    /// Penalty per changepoint.
117    penalty: Penalty,
118    /// Minimum segment length (must be >= 2 for variance estimation).
119    min_segment_len: usize,
120}
121
/// Result of PELT changepoint detection.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PeltResult {
    /// Detected changepoint indices (0-based), in strictly increasing order.
    /// Each index marks the first observation of a new segment.
    ///
    /// For example, if `changepoints = [50, 100]`, the segments are
    /// `[0..50)`, `[50..100)`, `[100..n)`.
    pub changepoints: Vec<usize>,
}

/// Result of multivariate PELT changepoint detection.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MultiPeltResult {
    /// Detected changepoint indices (0-based, strictly increasing), shared
    /// across all channels. Each index marks the first observation of a new
    /// segment.
    pub changepoints: Vec<usize>,
}

140impl Pelt {
141    /// Creates a new PELT detector with the given cost function and penalty.
142    ///
143    /// Uses a default minimum segment length of 2.
144    ///
145    /// # Returns
146    ///
147    /// `None` if a custom penalty is not positive or not finite.
148    ///
149    /// # Reference
150    ///
151    /// Killick et al. (2012), §2.2: penalty must be positive to avoid
152    /// trivial solutions (changepoint at every observation).
153    pub fn new(cost: CostFunction, penalty: Penalty) -> Option<Self> {
154        Self::with_min_segment_len(cost, penalty, 2)
155    }
156
157    /// Creates a PELT detector with a custom minimum segment length.
158    ///
159    /// # Parameters
160    ///
161    /// - `cost`: Cost function for segment evaluation
162    /// - `penalty`: Penalty per changepoint
163    /// - `min_segment_len`: Minimum number of observations per segment.
164    ///   Must be >= 2 (needed for variance estimation in Normal cost).
165    ///
166    /// # Returns
167    ///
168    /// `None` if parameters are invalid.
169    pub fn with_min_segment_len(
170        cost: CostFunction,
171        penalty: Penalty,
172        min_segment_len: usize,
173    ) -> Option<Self> {
174        if let Penalty::Custom(p) = penalty {
175            if !p.is_finite() || p <= 0.0 {
176                return None;
177            }
178        }
179        if min_segment_len < 2 {
180            return None;
181        }
182        Some(Self {
183            cost,
184            penalty,
185            min_segment_len,
186        })
187    }
188
189    /// Detects changepoints in the given data.
190    ///
191    /// Returns a [`PeltResult`] containing the detected changepoint indices.
192    /// If the data is too short (fewer than `2 * min_segment_len` observations),
193    /// no changepoints can be detected and an empty result is returned.
194    ///
195    /// Non-finite values (NaN, Infinity) are **not** supported in the input.
196    /// The data should be pre-cleaned; non-finite values will lead to
197    /// incorrect cost computations.
198    ///
199    /// # Examples
200    ///
201    /// ```
202    /// use u_analytics::detection::{Pelt, CostFunction, Penalty};
203    ///
204    /// // Two changepoints: shift up at 30, shift down at 70
205    /// let mut data = vec![0.0; 30];
206    /// data.extend(vec![3.0; 40]);
207    /// data.extend(vec![0.0; 30]);
208    ///
209    /// let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).unwrap();
210    /// let result = pelt.detect(&data);
211    /// assert_eq!(result.changepoints.len(), 2);
212    /// ```
213    ///
214    /// # Complexity
215    ///
216    /// Expected O(n), worst case O(n^2).
217    pub fn detect(&self, data: &[f64]) -> PeltResult {
218        let n = data.len();
219
220        if n < 2 * self.min_segment_len {
221            return PeltResult {
222                changepoints: Vec::new(),
223            };
224        }
225
226        let penalty_value = self.resolve_penalty(n);
227
228        // Precompute cumulative sums for O(1) segment cost evaluation.
229        // cum_sum[i] = sum(data[0..i])
230        // cum_sum_sq[i] = sum(data[0..i]^2)
231        let mut cum_sum = vec![0.0_f64; n + 1];
232        let mut cum_sum_sq = vec![0.0_f64; n + 1];
233        for i in 0..n {
234            cum_sum[i + 1] = cum_sum[i] + data[i];
235            cum_sum_sq[i + 1] = cum_sum_sq[i] + data[i] * data[i];
236        }
237
238        // F[t] = optimal cost for data[0..t]
239        // F[0] = -penalty (so that the first segment cost + penalty = cost alone)
240        let mut f = vec![0.0_f64; n + 1];
241        f[0] = -penalty_value;
242
243        // last_change[t] = last changepoint index for optimal segmentation of data[0..t]
244        let mut last_change = vec![0_usize; n + 1];
245
246        // Candidate set R: indices tau such that tau could be the last changepoint before t
247        let mut candidates: Vec<usize> = vec![0];
248
249        for t in self.min_segment_len..=n {
250            // Find the optimal last changepoint for position t
251            let mut best_cost = f64::INFINITY;
252            let mut best_tau = 0;
253
254            for &tau in &candidates {
255                let seg_len = t - tau;
256                if seg_len < self.min_segment_len {
257                    continue;
258                }
259
260                let cost = self.segment_cost(&cum_sum, &cum_sum_sq, tau, t);
261                let total = f[tau] + cost + penalty_value;
262
263                if total < best_cost {
264                    best_cost = total;
265                    best_tau = tau;
266                }
267            }
268
269            f[t] = best_cost;
270            last_change[t] = best_tau;
271
272            // PELT pruning: remove candidates that can never be optimal
273            // Killick et al. (2012), Theorem 3.1:
274            // If F(tau) + C(y_{tau+1:t}) >= F(t), then tau can be pruned.
275            candidates.retain(|&tau| {
276                let seg_len = t - tau;
277                if seg_len < self.min_segment_len {
278                    return true; // Keep — not yet evaluable
279                }
280                let cost = self.segment_cost(&cum_sum, &cum_sum_sq, tau, t);
281                f[tau] + cost < f[t] + penalty_value
282            });
283
284            candidates.push(t);
285        }
286
287        // Backtrack to extract changepoints
288        let mut changepoints = Vec::new();
289        let mut t = n;
290        while t > 0 {
291            let tau = last_change[t];
292            if tau > 0 {
293                changepoints.push(tau);
294            }
295            t = tau;
296        }
297
298        changepoints.sort_unstable();
299
300        PeltResult { changepoints }
301    }
302
303    /// Detects changepoints in multivariate (multi-signal) data.
304    ///
305    /// Each inner slice represents one signal channel. All channels must
306    /// have the same length. The cost function is applied independently
307    /// to each channel and summed — a single set of changepoints is
308    /// returned that applies to all channels simultaneously.
309    ///
310    /// The penalty scales with the number of channels: `penalty * n_channels`.
311    ///
312    /// # Parameters
313    ///
314    /// - `signals`: slice of signal channels, each of length `n`
315    ///
316    /// # Returns
317    ///
318    /// `None` if channels have inconsistent lengths.
319    /// Otherwise returns `Some(MultiPeltResult)`.
320    ///
321    /// # Examples
322    ///
323    /// ```
324    /// use u_analytics::detection::{Pelt, CostFunction, Penalty};
325    ///
326    /// let signal_a: Vec<f64> = [vec![0.0; 50], vec![5.0; 50]].concat();
327    /// let signal_b: Vec<f64> = [vec![0.0; 50], vec![3.0; 50]].concat();
328    ///
329    /// let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).unwrap();
330    /// let result = pelt.detect_multi(&[&signal_a, &signal_b]).unwrap();
331    /// assert!(!result.changepoints.is_empty());
332    /// ```
333    ///
334    /// # Reference
335    ///
336    /// Killick, R. & Eckley, I.A. (2014). "changepoint: An R Package for
337    /// Changepoint Analysis", *Journal of Statistical Software* 58(3).
338    ///
339    /// # Complexity
340    ///
341    /// Expected O(n * k), worst case O(n^2 * k) where k = number of channels.
342    pub fn detect_multi(&self, signals: &[&[f64]]) -> Option<MultiPeltResult> {
343        if signals.is_empty() {
344            return Some(MultiPeltResult {
345                changepoints: Vec::new(),
346            });
347        }
348
349        let n = signals[0].len();
350        if signals.iter().any(|s| s.len() != n) {
351            return None;
352        }
353
354        if n < 2 * self.min_segment_len {
355            return Some(MultiPeltResult {
356                changepoints: Vec::new(),
357            });
358        }
359
360        let n_channels = signals.len();
361        let penalty_value = self.resolve_penalty(n) * n_channels as f64;
362
363        // Precompute cumulative sums per channel for O(1) segment cost.
364        let mut cum_sums: Vec<Vec<f64>> = Vec::with_capacity(n_channels);
365        let mut cum_sum_sqs: Vec<Vec<f64>> = Vec::with_capacity(n_channels);
366
367        for signal in signals {
368            let mut cs = vec![0.0_f64; n + 1];
369            let mut css = vec![0.0_f64; n + 1];
370            for i in 0..n {
371                cs[i + 1] = cs[i] + signal[i];
372                css[i + 1] = css[i] + signal[i] * signal[i];
373            }
374            cum_sums.push(cs);
375            cum_sum_sqs.push(css);
376        }
377
378        let mut f = vec![0.0_f64; n + 1];
379        f[0] = -penalty_value;
380        let mut last_change = vec![0_usize; n + 1];
381        let mut candidates: Vec<usize> = vec![0];
382
383        for t in self.min_segment_len..=n {
384            let mut best_cost = f64::INFINITY;
385            let mut best_tau = 0;
386
387            for &tau in &candidates {
388                let seg_len = t - tau;
389                if seg_len < self.min_segment_len {
390                    continue;
391                }
392
393                let cost: f64 = (0..n_channels)
394                    .map(|ch| self.segment_cost(&cum_sums[ch], &cum_sum_sqs[ch], tau, t))
395                    .sum();
396                let total = f[tau] + cost + penalty_value;
397
398                if total < best_cost {
399                    best_cost = total;
400                    best_tau = tau;
401                }
402            }
403
404            f[t] = best_cost;
405            last_change[t] = best_tau;
406
407            candidates.retain(|&tau| {
408                let seg_len = t - tau;
409                if seg_len < self.min_segment_len {
410                    return true;
411                }
412                let cost: f64 = (0..n_channels)
413                    .map(|ch| self.segment_cost(&cum_sums[ch], &cum_sum_sqs[ch], tau, t))
414                    .sum();
415                f[tau] + cost < f[t] + penalty_value
416            });
417
418            candidates.push(t);
419        }
420
421        let mut changepoints = Vec::new();
422        let mut t = n;
423        while t > 0 {
424            let tau = last_change[t];
425            if tau > 0 {
426                changepoints.push(tau);
427            }
428            t = tau;
429        }
430        changepoints.sort_unstable();
431
432        Some(MultiPeltResult { changepoints })
433    }
434
435    /// Resolves the penalty value for a dataset of length `n`.
436    fn resolve_penalty(&self, n: usize) -> f64 {
437        match self.penalty {
438            Penalty::Bic => {
439                let p = self.cost.params_per_segment() as f64;
440                p * (n as f64).ln()
441            }
442            Penalty::Custom(val) => val,
443        }
444    }
445
446    /// Computes the cost of the segment `data[start..end]` using cumulative sums.
447    ///
448    /// # Panics
449    ///
450    /// Panics if `end <= start`.
451    fn segment_cost(
452        &self,
453        cum_sum: &[f64],
454        cum_sum_sq: &[f64],
455        start: usize,
456        end: usize,
457    ) -> f64 {
458        let seg_len = (end - start) as f64;
459        let sum = cum_sum[end] - cum_sum[start];
460        let sum_sq = cum_sum_sq[end] - cum_sum_sq[start];
461        let mean = sum / seg_len;
462
463        match self.cost {
464            CostFunction::L2 => {
465                // sum((y_i - mean)^2) = sum(y_i^2) - n * mean^2
466                sum_sq - seg_len * mean * mean
467            }
468            CostFunction::Normal => {
469                // n * ln(variance) where variance = sum((y_i - mean)^2) / n
470                let variance = (sum_sq - seg_len * mean * mean) / seg_len;
471                if variance <= 0.0 {
472                    // Degenerate segment (constant values): assign a large negative
473                    // log-likelihood to make it favorable (perfect fit).
474                    seg_len * (f64::MIN_POSITIVE).ln()
475                } else {
476                    seg_len * variance.ln()
477                }
478            }
479        }
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    // --- Constructor validation ---

    #[test]
    fn test_pelt_valid_construction() {
        assert!(Pelt::new(CostFunction::L2, Penalty::Bic).is_some());
        assert!(Pelt::new(CostFunction::Normal, Penalty::Bic).is_some());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(10.0)).is_some());
    }

    #[test]
    fn test_pelt_invalid_custom_penalty() {
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(0.0)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(-1.0)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(f64::NAN)).is_none());
        assert!(Pelt::new(CostFunction::L2, Penalty::Custom(f64::INFINITY)).is_none());
    }

    #[test]
    fn test_pelt_invalid_min_segment_len() {
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 0).is_none());
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 1).is_none());
        assert!(Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 2).is_some());
    }

    // --- Empty and short data ---

    #[test]
    fn test_pelt_empty_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let result = pelt.detect(&[]);
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_too_short_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        // min_segment_len=2, need at least 4 points for a changepoint
        let result = pelt.detect(&[1.0, 2.0, 3.0]);
        assert!(result.changepoints.is_empty());
    }

    // --- No changepoint scenarios ---

    #[test]
    fn test_pelt_constant_data_no_changepoint() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let data = vec![5.0; 100];
        let result = pelt.detect(&data);
        assert!(
            result.changepoints.is_empty(),
            "constant data should have no changepoints, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_normal_cost_constant_data() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");
        let data = vec![5.0; 100];
        let result = pelt.detect(&data);
        assert!(
            result.changepoints.is_empty(),
            "constant data should have no changepoints with Normal cost, got {:?}",
            result.changepoints
        );
    }

    // --- Single changepoint detection ---

    #[test]
    fn test_pelt_single_mean_shift_l2() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_single_mean_shift_normal() {
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint with Normal cost, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    // --- Multiple changepoints ---

    #[test]
    fn test_pelt_two_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 40];
        data.extend(vec![5.0; 40]);
        data.extend(vec![0.0; 40]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            2,
            "expected 2 changepoints, got {:?}",
            result.changepoints
        );

        // Changepoints should be near 40 and 80
        assert!(
            (result.changepoints[0] as i64 - 40).unsigned_abs() <= 2,
            "first changepoint near 40, got {}",
            result.changepoints[0]
        );
        assert!(
            (result.changepoints[1] as i64 - 80).unsigned_abs() <= 2,
            "second changepoint near 80, got {}",
            result.changepoints[1]
        );
    }

    #[test]
    fn test_pelt_three_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 30];
        data.extend(vec![4.0; 30]);
        data.extend(vec![-2.0; 30]);
        data.extend(vec![3.0; 30]);

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            3,
            "expected 3 changepoints, got {:?}",
            result.changepoints
        );

        // Check ordering
        for i in 1..result.changepoints.len() {
            assert!(
                result.changepoints[i] > result.changepoints[i - 1],
                "changepoints should be strictly increasing"
            );
        }
    }

    // --- Variance change detection ---

    #[test]
    fn test_pelt_variance_change_normal_cost() {
        // Normal cost should detect a variance change even when mean is constant.
        let pelt = Pelt::new(CostFunction::Normal, Penalty::Bic).expect("valid");

        // Low variance segment then high variance segment
        let mut data = Vec::with_capacity(200);
        // Segment 1: mean=0, low spread (deterministic zigzag)
        for i in 0..100 {
            data.push(if i % 2 == 0 { 0.1 } else { -0.1 });
        }
        // Segment 2: mean=0, high spread
        for i in 0..100 {
            data.push(if i % 2 == 0 { 5.0 } else { -5.0 });
        }

        let result = pelt.detect(&data);
        assert!(
            !result.changepoints.is_empty(),
            "Normal cost should detect variance change"
        );
        // The changepoint should be near index 100
        let cp = result.changepoints[0];
        assert!(
            (cp as i64 - 100).unsigned_abs() <= 5,
            "variance changepoint should be near 100, got {}",
            cp
        );
    }

    // --- Penalty sensitivity ---

    #[test]
    fn test_pelt_higher_penalty_fewer_changepoints() {
        let mut data = vec![0.0; 30];
        data.extend(vec![2.0; 30]);
        data.extend(vec![0.0; 30]);

        let pelt_low = Pelt::new(CostFunction::L2, Penalty::Custom(1.0)).expect("valid");
        let pelt_high = Pelt::new(CostFunction::L2, Penalty::Custom(100.0)).expect("valid");

        let result_low = pelt_low.detect(&data);
        let result_high = pelt_high.detect(&data);

        assert!(
            result_low.changepoints.len() >= result_high.changepoints.len(),
            "higher penalty should produce fewer or equal changepoints: low={}, high={}",
            result_low.changepoints.len(),
            result_high.changepoints.len()
        );
    }

    // --- Custom minimum segment length ---

    #[test]
    fn test_pelt_custom_min_segment_len() {
        let mut data = vec![0.0; 50];
        data.extend(vec![10.0; 50]);

        let pelt = Pelt::with_min_segment_len(CostFunction::L2, Penalty::Bic, 10).expect("valid");
        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "should detect changepoint with min_segment_len=10"
        );

        // All segments should respect minimum length
        let mut boundaries = vec![0];
        boundaries.extend_from_slice(&result.changepoints);
        boundaries.push(data.len());
        for i in 1..boundaries.len() {
            let seg_len = boundaries[i] - boundaries[i - 1];
            assert!(
                seg_len >= 10,
                "segment length {} is less than min_segment_len=10",
                seg_len
            );
        }
    }

    // --- Exact numeric verification ---

    /// Verifies PELT on a simple 4-point example with known optimal solution.
    ///
    /// Data: [0, 0, 10, 10], penalty = 1 * ln(4) ≈ 1.39 (BIC with one
    /// parameter per segment for the L2 cost)
    ///
    /// Without changepoint: cost = sum((y_i - 5)^2) = 25+25+25+25 = 100
    /// With changepoint at 2: cost = 0 + 0 + ln(4) ≈ 1.39 (L2 cost of each segment = 0)
    ///
    /// PELT should find changepoint at index 2.
    #[test]
    fn test_pelt_exact_small_example() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let data = [0.0, 0.0, 10.0, 10.0];
        let result = pelt.detect(&data);

        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint in [0,0,10,10], got {:?}",
            result.changepoints
        );
        assert_eq!(
            result.changepoints[0], 2,
            "changepoint should be at index 2"
        );
    }

    // --- Changepoints are sorted ---

    #[test]
    fn test_pelt_changepoints_sorted() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 25];
        data.extend(vec![5.0; 25]);
        data.extend(vec![-3.0; 25]);
        data.extend(vec![7.0; 25]);

        let result = pelt.detect(&data);
        for i in 1..result.changepoints.len() {
            assert!(
                result.changepoints[i] > result.changepoints[i - 1],
                "changepoints must be strictly increasing: {:?}",
                result.changepoints
            );
        }
    }

    // --- BIC penalty scales with n ---

    #[test]
    fn test_pelt_bic_penalty_scales() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        // BIC for L2 is beta = ln(n), which grows only logarithmically, while the
        // cost saved by splitting at a true shift grows linearly with n.
        //
        // n = 100, shift of 0.5 at index 50:
        //   no split: SSR = 100 * 0.25^2 = 6.25; one split: 0 + ln(100) ≈ 4.61
        //   -> the changepoint is worth its penalty and is detected.
        let mut long = vec![0.0; 50];
        long.extend(vec![0.5; 50]);
        assert_eq!(pelt.detect(&long).changepoints, vec![50]);

        // n = 10, same shift at index 5:
        //   no split: SSR = 10 * 0.25^2 = 0.625; one split: 0 + ln(10) ≈ 2.30
        //   -> the penalty suppresses the changepoint.
        let mut short = vec![0.0; 5];
        short.extend(vec![0.5; 5]);
        assert!(
            pelt.detect(&short).changepoints.is_empty(),
            "BIC should suppress a 0.5 shift in only 10 observations"
        );
    }

    // --- Property: segments cover entire data ---

    #[test]
    fn test_pelt_segments_cover_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![0.0; 30];
        data.extend(vec![5.0; 30]);
        data.extend(vec![0.0; 30]);

        let result = pelt.detect(&data);

        // Verify segments cover [0, n) without gaps
        let mut boundaries = vec![0];
        boundaries.extend_from_slice(&result.changepoints);
        boundaries.push(data.len());

        for i in 1..boundaries.len() {
            assert!(
                boundaries[i] > boundaries[i - 1],
                "segments must not have zero length"
            );
        }
        assert_eq!(
            *boundaries.last().expect("non-empty boundaries"),
            data.len(),
            "segments must cover entire data"
        );
    }

    // --- Downward shift ---

    #[test]
    fn test_pelt_downward_shift() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");

        let mut data = vec![10.0; 50];
        data.extend(vec![2.0; 50]); // Downward shift

        let result = pelt.detect(&data);
        assert_eq!(
            result.changepoints.len(),
            1,
            "should detect downward shift"
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint should be near index 50, got {}",
            result.changepoints[0]
        );
    }

    // --- Cost function enum properties ---

    #[test]
    fn test_cost_function_params() {
        assert_eq!(CostFunction::L2.params_per_segment(), 1);
        assert_eq!(CostFunction::Normal.params_per_segment(), 2);
    }

    // --- Multi-signal tests ---

    #[test]
    fn test_pelt_multi_single_channel_matches_univariate() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Custom(5.0)).expect("valid");
        let mut data = vec![0.0; 50];
        data.extend(vec![5.0; 50]);

        let uni = pelt.detect(&data);
        let multi = pelt.detect_multi(&[&data]).expect("valid");
        assert_eq!(uni.changepoints, multi.changepoints);
    }

    #[test]
    fn test_pelt_multi_two_channels() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a: Vec<f64> = [vec![0.0; 50], vec![5.0; 50]].concat();
        let b: Vec<f64> = [vec![0.0; 50], vec![3.0; 50]].concat();

        let result = pelt.detect_multi(&[&a, &b]).expect("valid");
        assert_eq!(
            result.changepoints.len(),
            1,
            "expected 1 changepoint, got {:?}",
            result.changepoints
        );
        assert!(
            (result.changepoints[0] as i64 - 50).unsigned_abs() <= 2,
            "changepoint near 50, got {}",
            result.changepoints[0]
        );
    }

    #[test]
    fn test_pelt_multi_inconsistent_lengths() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a = vec![0.0; 50];
        let b = vec![0.0; 30];
        assert!(pelt.detect_multi(&[&a[..], &b[..]]).is_none());
    }

    #[test]
    fn test_pelt_multi_empty_signals() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let result = pelt.detect_multi(&[]).expect("valid");
        assert!(result.changepoints.is_empty());
    }

    #[test]
    fn test_pelt_multi_three_channels_two_changepoints() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a: Vec<f64> = [vec![0.0; 40], vec![5.0; 40], vec![0.0; 40]].concat();
        let b: Vec<f64> = [vec![0.0; 40], vec![3.0; 40], vec![0.0; 40]].concat();
        let c: Vec<f64> = [vec![0.0; 40], vec![4.0; 40], vec![0.0; 40]].concat();

        let result = pelt.detect_multi(&[&a, &b, &c]).expect("valid");
        assert_eq!(
            result.changepoints.len(),
            2,
            "expected 2 changepoints, got {:?}",
            result.changepoints
        );
    }

    #[test]
    fn test_pelt_multi_short_data() {
        let pelt = Pelt::new(CostFunction::L2, Penalty::Bic).expect("valid");
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        let result = pelt.detect_multi(&[&a, &b]).expect("valid");
        assert!(result.changepoints.is_empty());
    }
}