sklears_preprocessing/temporal/
changepoint.rs1use scirs2_core::ndarray::{s, Array1};
7use sklears_core::{
8 error::Result,
9 traits::{Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14#[derive(Debug, Clone)]
16pub struct ChangePointDetectorConfig {
17 pub method: ChangePointMethod,
19 pub min_segment_length: usize,
21 pub threshold: Float,
23 pub binary_output: bool,
25}
26
/// Algorithm used to score each sample of a series for a change point.
///
/// The enum is a plain tag, so it is cheap to copy and can be compared with
/// `==` or used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChangePointMethod {
    /// Two-sided cumulative sum over the standardized series.
    CUSUM,
    /// Absolute log-ratio of the variances of the windows to the left and
    /// right of each candidate index.
    VarianceChange,
    /// Absolute difference of the left and right window means, divided by
    /// the pooled standard deviation of the two windows.
    MeanChange,
    /// Absolute difference between each sample and its predecessor.
    SimpleDifference,
}
39
40impl Default for ChangePointDetectorConfig {
41 fn default() -> Self {
42 Self {
43 method: ChangePointMethod::CUSUM,
44 min_segment_length: 5,
45 threshold: 2.0,
46 binary_output: true,
47 }
48 }
49}
50
51#[derive(Debug, Clone)]
53pub struct ChangePointDetector<S> {
54 config: ChangePointDetectorConfig,
55 _phantom: PhantomData<S>,
56}
57
58impl ChangePointDetector<Untrained> {
59 pub fn new() -> Self {
61 Self {
62 config: ChangePointDetectorConfig::default(),
63 _phantom: PhantomData,
64 }
65 }
66
67 pub fn method(mut self, method: ChangePointMethod) -> Self {
69 self.config.method = method;
70 self
71 }
72
73 pub fn min_segment_length(mut self, min_segment_length: usize) -> Self {
75 self.config.min_segment_length = min_segment_length;
76 self
77 }
78
79 pub fn threshold(mut self, threshold: Float) -> Self {
81 self.config.threshold = threshold;
82 self
83 }
84
85 pub fn binary_output(mut self, binary_output: bool) -> Self {
87 self.config.binary_output = binary_output;
88 self
89 }
90}
91
92impl ChangePointDetector<Trained> {
93 fn detect_cusum(&self, data: &Array1<Float>) -> Array1<Float> {
95 let n = data.len();
96 let mut scores = Array1::<Float>::zeros(n);
97
98 if n < 2 {
99 return scores;
100 }
101
102 let mean = data.mean().unwrap_or(0.0);
103 let std = data.std(0.0);
104
105 if std < 1e-10 {
106 return scores;
107 }
108
109 let mut cusum_pos = 0.0;
110 let mut cusum_neg = 0.0;
111
112 for i in 0..n {
113 let standardized = (data[i] - mean) / std;
114
115 cusum_pos = (cusum_pos + standardized - 0.5).max(0.0);
116 cusum_neg = (cusum_neg - standardized - 0.5).max(0.0);
117
118 scores[i] = cusum_pos.max(cusum_neg);
119 }
120
121 if self.config.binary_output {
122 scores.mapv(|x| if x > self.config.threshold { 1.0 } else { 0.0 })
123 } else {
124 scores
125 }
126 }
127
128 fn detect_variance_change(&self, data: &Array1<Float>) -> Array1<Float> {
130 let n = data.len();
131 let window = self.config.min_segment_length;
132 let mut scores = Array1::<Float>::zeros(n);
133
134 if n < 2 * window {
135 return scores;
136 }
137
138 for i in window..(n - window) {
139 let left_window = data.slice(s![(i - window)..i]);
140 let right_window = data.slice(s![i..(i + window)]);
141
142 let left_var = left_window.var(0.0);
143 let right_var = right_window.var(0.0);
144
145 let ratio = if right_var > 1e-10 {
146 left_var / right_var
147 } else if left_var > 1e-10 {
148 Float::INFINITY
149 } else {
150 1.0
151 };
152
153 scores[i] = (ratio.ln()).abs();
154 }
155
156 if self.config.binary_output {
157 scores.mapv(|x| if x > self.config.threshold { 1.0 } else { 0.0 })
158 } else {
159 scores
160 }
161 }
162
163 fn detect_mean_change(&self, data: &Array1<Float>) -> Array1<Float> {
165 let n = data.len();
166 let window = self.config.min_segment_length;
167 let mut scores = Array1::<Float>::zeros(n);
168
169 if n < 2 * window {
170 return scores;
171 }
172
173 for i in window..(n - window) {
174 let left_window = data.slice(s![(i - window)..i]);
175 let right_window = data.slice(s![i..(i + window)]);
176
177 let left_mean = left_window.mean().unwrap_or(0.0);
178 let right_mean = right_window.mean().unwrap_or(0.0);
179 let pooled_std = ((left_window.var(0.0) + right_window.var(0.0)) / 2.0).sqrt();
180
181 scores[i] = if pooled_std > 1e-10 {
182 (left_mean - right_mean).abs() / pooled_std
183 } else {
184 0.0
185 };
186 }
187
188 if self.config.binary_output {
189 scores.mapv(|x| if x > self.config.threshold { 1.0 } else { 0.0 })
190 } else {
191 scores
192 }
193 }
194
195 fn detect_simple_difference(&self, data: &Array1<Float>) -> Array1<Float> {
197 let n = data.len();
198 let mut scores = Array1::<Float>::zeros(n);
199
200 if n < 2 {
201 return scores;
202 }
203
204 for i in 1..n {
206 scores[i] = (data[i] - data[i - 1]).abs();
207 }
208
209 let threshold = if self.config.binary_output {
210 scores.mean().unwrap_or(0.0) + self.config.threshold * scores.std(0.0)
211 } else {
212 0.0
213 };
214
215 if self.config.binary_output {
216 scores.mapv(|x| if x > threshold { 1.0 } else { 0.0 })
217 } else {
218 scores
219 }
220 }
221}
222
223impl Default for ChangePointDetector<Untrained> {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
229impl Fit<Array1<Float>, ()> for ChangePointDetector<Untrained> {
230 type Fitted = ChangePointDetector<Trained>;
231
232 fn fit(self, _x: &Array1<Float>, _y: &()) -> Result<Self::Fitted> {
233 Ok(ChangePointDetector {
234 config: self.config,
235 _phantom: PhantomData,
236 })
237 }
238}
239
240impl Transform<Array1<Float>, Array1<Float>> for ChangePointDetector<Trained> {
241 fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
242 let result = match self.config.method {
243 ChangePointMethod::CUSUM => self.detect_cusum(x),
244 ChangePointMethod::VarianceChange => self.detect_variance_change(x),
245 ChangePointMethod::MeanChange => self.detect_mean_change(x),
246 ChangePointMethod::SimpleDifference => self.detect_simple_difference(x),
247 };
248
249 Ok(result)
250 }
251}