sklears_preprocessing/temporal/
changepoint.rs

//! Change point detection for time series.
//!
//! This module provides several methods for detecting structural changes,
//! or change points, in time series data.
5
6use scirs2_core::ndarray::{s, Array1};
7use sklears_core::{
8    error::Result,
9    traits::{Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::marker::PhantomData;
13
14/// Configuration for ChangePointDetector
15#[derive(Debug, Clone)]
16pub struct ChangePointDetectorConfig {
17    /// Detection method
18    pub method: ChangePointMethod,
19    /// Minimum segment length between change points
20    pub min_segment_length: usize,
21    /// Threshold for change point detection
22    pub threshold: Float,
23    /// Whether to return binary indicators or change point scores
24    pub binary_output: bool,
25}
26
/// Algorithm used to score each sample of a series for a possible change point.
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare configured methods with
/// `==` and use them as set/map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChangePointMethod {
    /// Two-sided CUSUM (cumulative sum) over the globally standardised series.
    CUSUM,
    /// Absolute log-ratio of the variances of two adjacent windows.
    VarianceChange,
    /// Absolute difference of the means of two adjacent windows, scaled by their pooled
    /// standard deviation.
    MeanChange,
    /// Absolute first difference, thresholded adaptively against the mean and standard
    /// deviation of all first differences.
    SimpleDifference,
}
39
40impl Default for ChangePointDetectorConfig {
41    fn default() -> Self {
42        Self {
43            method: ChangePointMethod::CUSUM,
44            min_segment_length: 5,
45            threshold: 2.0,
46            binary_output: true,
47        }
48    }
49}
50
51/// ChangePointDetector for detecting structural changes in time series
52#[derive(Debug, Clone)]
53pub struct ChangePointDetector<S> {
54    config: ChangePointDetectorConfig,
55    _phantom: PhantomData<S>,
56}
57
58impl ChangePointDetector<Untrained> {
59    /// Create a new ChangePointDetector
60    pub fn new() -> Self {
61        Self {
62            config: ChangePointDetectorConfig::default(),
63            _phantom: PhantomData,
64        }
65    }
66
67    /// Set the detection method
68    pub fn method(mut self, method: ChangePointMethod) -> Self {
69        self.config.method = method;
70        self
71    }
72
73    /// Set the minimum segment length
74    pub fn min_segment_length(mut self, min_segment_length: usize) -> Self {
75        self.config.min_segment_length = min_segment_length;
76        self
77    }
78
79    /// Set the detection threshold
80    pub fn threshold(mut self, threshold: Float) -> Self {
81        self.config.threshold = threshold;
82        self
83    }
84
85    /// Set whether to use binary output
86    pub fn binary_output(mut self, binary_output: bool) -> Self {
87        self.config.binary_output = binary_output;
88        self
89    }
90}
91
92impl ChangePointDetector<Trained> {
93    /// CUSUM-based change point detection
94    fn detect_cusum(&self, data: &Array1<Float>) -> Array1<Float> {
95        let n = data.len();
96        let mut scores = Array1::<Float>::zeros(n);
97
98        if n < 2 {
99            return scores;
100        }
101
102        let mean = data.mean().unwrap_or(0.0);
103        let std = data.std(0.0);
104
105        if std < 1e-10 {
106            return scores;
107        }
108
109        let mut cusum_pos = 0.0;
110        let mut cusum_neg = 0.0;
111
112        for i in 0..n {
113            let standardized = (data[i] - mean) / std;
114
115            cusum_pos = (cusum_pos + standardized - 0.5).max(0.0);
116            cusum_neg = (cusum_neg - standardized - 0.5).max(0.0);
117
118            scores[i] = cusum_pos.max(cusum_neg);
119        }
120
121        if self.config.binary_output {
122            scores.mapv(|x| if x > self.config.threshold { 1.0 } else { 0.0 })
123        } else {
124            scores
125        }
126    }
127
128    /// Variance-based change point detection
129    fn detect_variance_change(&self, data: &Array1<Float>) -> Array1<Float> {
130        let n = data.len();
131        let window = self.config.min_segment_length;
132        let mut scores = Array1::<Float>::zeros(n);
133
134        if n < 2 * window {
135            return scores;
136        }
137
138        for i in window..(n - window) {
139            let left_window = data.slice(s![(i - window)..i]);
140            let right_window = data.slice(s![i..(i + window)]);
141
142            let left_var = left_window.var(0.0);
143            let right_var = right_window.var(0.0);
144
145            let ratio = if right_var > 1e-10 {
146                left_var / right_var
147            } else if left_var > 1e-10 {
148                Float::INFINITY
149            } else {
150                1.0
151            };
152
153            scores[i] = (ratio.ln()).abs();
154        }
155
156        if self.config.binary_output {
157            scores.mapv(|x| if x > self.config.threshold { 1.0 } else { 0.0 })
158        } else {
159            scores
160        }
161    }
162
163    /// Mean-based change point detection
164    fn detect_mean_change(&self, data: &Array1<Float>) -> Array1<Float> {
165        let n = data.len();
166        let window = self.config.min_segment_length;
167        let mut scores = Array1::<Float>::zeros(n);
168
169        if n < 2 * window {
170            return scores;
171        }
172
173        for i in window..(n - window) {
174            let left_window = data.slice(s![(i - window)..i]);
175            let right_window = data.slice(s![i..(i + window)]);
176
177            let left_mean = left_window.mean().unwrap_or(0.0);
178            let right_mean = right_window.mean().unwrap_or(0.0);
179            let pooled_std = ((left_window.var(0.0) + right_window.var(0.0)) / 2.0).sqrt();
180
181            scores[i] = if pooled_std > 1e-10 {
182                (left_mean - right_mean).abs() / pooled_std
183            } else {
184                0.0
185            };
186        }
187
188        if self.config.binary_output {
189            scores.mapv(|x| if x > self.config.threshold { 1.0 } else { 0.0 })
190        } else {
191            scores
192        }
193    }
194
195    /// Simple difference-based change point detection
196    fn detect_simple_difference(&self, data: &Array1<Float>) -> Array1<Float> {
197        let n = data.len();
198        let mut scores = Array1::<Float>::zeros(n);
199
200        if n < 2 {
201            return scores;
202        }
203
204        // Calculate first differences
205        for i in 1..n {
206            scores[i] = (data[i] - data[i - 1]).abs();
207        }
208
209        let threshold = if self.config.binary_output {
210            scores.mean().unwrap_or(0.0) + self.config.threshold * scores.std(0.0)
211        } else {
212            0.0
213        };
214
215        if self.config.binary_output {
216            scores.mapv(|x| if x > threshold { 1.0 } else { 0.0 })
217        } else {
218            scores
219        }
220    }
221}
222
223impl Default for ChangePointDetector<Untrained> {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
229impl Fit<Array1<Float>, ()> for ChangePointDetector<Untrained> {
230    type Fitted = ChangePointDetector<Trained>;
231
232    fn fit(self, _x: &Array1<Float>, _y: &()) -> Result<Self::Fitted> {
233        Ok(ChangePointDetector {
234            config: self.config,
235            _phantom: PhantomData,
236        })
237    }
238}
239
240impl Transform<Array1<Float>, Array1<Float>> for ChangePointDetector<Trained> {
241    fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
242        let result = match self.config.method {
243            ChangePointMethod::CUSUM => self.detect_cusum(x),
244            ChangePointMethod::VarianceChange => self.detect_variance_change(x),
245            ChangePointMethod::MeanChange => self.detect_mean_change(x),
246            ChangePointMethod::SimpleDifference => self.detect_simple_difference(x),
247        };
248
249        Ok(result)
250    }
251}