scirs2_transform/monitoring/
adwin.rs1use crate::error::{Result, TransformError};
29
/// Summary of a run of consecutive stream values held in the window.
///
/// A freshly inserted value becomes a bucket with `count == 1`; older
/// buckets are pairwise merged into larger ones.
#[derive(Debug, Clone)]
struct Bucket {
    /// Number of raw values summarised by this bucket.
    count: usize,
    /// Sum of the summarised values.
    total: f64,
    /// Sum of squared deviations from the bucket mean (not divided by
    /// `count`); `0.0` for a single-value bucket.
    variance: f64,
}
40
41#[derive(Debug, Clone)]
65pub struct Adwin {
66 delta: f64,
68 buckets: Vec<Vec<Bucket>>,
70 max_buckets: usize,
72 total_count: usize,
74 total_sum: f64,
76 total_variance: f64,
78 last_change_detected: bool,
80 min_window_length: usize,
82}
83
84impl Adwin {
85 pub fn new(delta: f64) -> Result<Self> {
91 if delta <= 0.0 || delta >= 1.0 {
92 return Err(TransformError::InvalidInput(
93 "delta must be in (0, 1)".to_string(),
94 ));
95 }
96 Ok(Self {
97 delta,
98 buckets: Vec::new(),
99 max_buckets: 5, total_count: 0,
101 total_sum: 0.0,
102 total_variance: 0.0,
103 last_change_detected: false,
104 min_window_length: 10,
105 })
106 }
107
108 pub fn set_min_window_length(&mut self, min_len: usize) {
110 self.min_window_length = min_len;
111 }
112
113 pub fn add_element(&mut self, value: f64) -> Result<bool> {
117 if !value.is_finite() {
118 return Err(TransformError::InvalidInput(
119 "Value must be finite".to_string(),
120 ));
121 }
122
123 self.last_change_detected = false;
124
125 let new_bucket = Bucket {
127 count: 1,
128 total: value,
129 variance: 0.0,
130 };
131
132 if self.buckets.is_empty() {
133 self.buckets.push(Vec::new());
134 }
135 self.buckets[0].push(new_bucket);
136 self.total_count += 1;
137 self.total_sum += value;
138 self.total_variance += value * value;
139
140 self.compress();
142
143 if self.total_count >= self.min_window_length {
145 self.last_change_detected = self.check_and_cut();
146 }
147
148 Ok(self.last_change_detected)
149 }
150
151 fn compress(&mut self) {
154 let mut level = 0;
155 while level < self.buckets.len() {
156 if self.buckets[level].len() > self.max_buckets + 1 {
157 if self.buckets[level].len() >= 2 {
159 let b1 = self.buckets[level].remove(0);
160 let b2 = self.buckets[level].remove(0);
161
162 let merged_count = b1.count + b2.count;
163 let merged_total = b1.total + b2.total;
164 let delta_mean =
166 b2.total / b2.count.max(1) as f64 - b1.total / b1.count.max(1) as f64;
167 let merged_variance = b1.variance
168 + b2.variance
169 + delta_mean * delta_mean * (b1.count * b2.count) as f64
170 / merged_count.max(1) as f64;
171
172 let merged = Bucket {
173 count: merged_count,
174 total: merged_total,
175 variance: merged_variance,
176 };
177
178 if level + 1 >= self.buckets.len() {
180 self.buckets.push(Vec::new());
181 }
182 self.buckets[level + 1].push(merged);
183 }
184 }
185 level += 1;
186 }
187 }
188
189 fn check_and_cut(&mut self) -> bool {
192 let mut w1_count: usize = 0;
195 let mut w1_sum: f64 = 0.0;
196 let mut _w1_var: f64 = 0.0;
197
198 let n_levels = self.buckets.len();
201
202 let mut ordered_buckets: Vec<(usize, usize)> = Vec::new(); for level in 0..n_levels {
205 for idx in (0..self.buckets[level].len()).rev() {
206 ordered_buckets.push((level, idx));
207 }
208 }
209
210 for &(level, idx) in ordered_buckets.iter() {
211 let bucket = &self.buckets[level][idx];
212 w1_count += bucket.count;
213 w1_sum += bucket.total;
214 _w1_var += bucket.variance;
215
216 let w0_count = self.total_count - w1_count;
217 if w0_count < 1 || w1_count < 1 {
218 continue;
219 }
220
221 let w0_sum = self.total_sum - w1_sum;
222
223 let mean0 = w0_sum / w0_count as f64;
224 let mean1 = w1_sum / w1_count as f64;
225 let diff = (mean0 - mean1).abs();
226
227 let n = self.total_count as f64;
229 let m = (1.0 / w0_count as f64 + 1.0 / w1_count as f64).min(1.0);
230 let delta_prime = self.delta / n.ln().max(1.0);
231 let epsilon = ((m / (2.0 * delta_prime)).ln().max(0.0) * m / 2.0).sqrt();
232
233 if diff >= epsilon && w0_count >= 2 && w1_count >= 2 {
234 self.drop_oldest(w0_count);
236 return true;
237 }
238 }
239
240 false
241 }
242
243 fn drop_oldest(&mut self, count: usize) {
245 let mut remaining = count;
246
247 let mut level = self.buckets.len();
249 while level > 0 && remaining > 0 {
250 level -= 1;
251 while !self.buckets[level].is_empty() && remaining > 0 {
252 let bucket = &self.buckets[level][0];
253 if bucket.count <= remaining {
254 let removed = self.buckets[level].remove(0);
255 remaining -= removed.count;
256 self.total_count -= removed.count;
257 self.total_sum -= removed.total;
258 self.total_variance -=
259 removed.total * removed.total / removed.count.max(1) as f64;
260 } else {
261 break;
262 }
263 }
264 }
265
266 while let Some(last) = self.buckets.last() {
268 if last.is_empty() {
269 self.buckets.pop();
270 } else {
271 break;
272 }
273 }
274 }
275
276 pub fn detected_change(&self) -> bool {
278 self.last_change_detected
279 }
280
281 pub fn current_mean(&self) -> f64 {
283 if self.total_count == 0 {
284 0.0
285 } else {
286 self.total_sum / self.total_count as f64
287 }
288 }
289
290 pub fn current_length(&self) -> usize {
292 self.total_count
293 }
294
295 pub fn current_sum(&self) -> f64 {
297 self.total_sum
298 }
299
300 pub fn delta(&self) -> f64 {
302 self.delta
303 }
304
305 pub fn reset(&mut self) {
307 self.buckets.clear();
308 self.total_count = 0;
309 self.total_sum = 0.0;
310 self.total_variance = 0.0;
311 self.last_change_detected = false;
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes `value` into the detector `times` times, ignoring the outcome.
    fn feed_all(adwin: &mut Adwin, value: f64, times: usize) {
        for _ in 0..times {
            adwin.add_element(value).expect("add");
        }
    }

    /// Pushes `value` up to `times` times, stopping at the first reported
    /// change. Returns whether a change was reported.
    fn feed_until_change(adwin: &mut Adwin, value: f64, times: usize) -> bool {
        for _ in 0..times {
            if adwin.add_element(value).expect("add") {
                return true;
            }
        }
        false
    }

    /// A very slowly rising stream keeps a sensible mean and a non-empty window.
    #[test]
    fn test_adwin_no_change_stable_data() {
        let mut adwin = Adwin::new(0.01).expect("valid delta");

        let any_change = (0..500).fold(false, |seen, i| {
            let val = 5.0 + (i as f64) * 0.0001;
            let changed = adwin.add_element(val).expect("add");
            seen || changed
        });

        let mean = adwin.current_mean();
        assert!(
            mean > 4.0 && mean < 6.0,
            "Mean should be around 5.0: {}",
            mean
        );
        assert!(adwin.current_length() > 0);
        // Whether a change was flagged is deliberately not asserted here.
        let _ = any_change;
    }

    /// A jump of the mean from 0 to 100 must be reported.
    #[test]
    fn test_adwin_detect_abrupt_change() {
        let mut adwin = Adwin::new(0.002).expect("valid delta");
        adwin.set_min_window_length(5);

        feed_all(&mut adwin, 0.0, 200);
        let detected = feed_until_change(&mut adwin, 100.0, 200);

        assert!(
            detected,
            "ADWIN should detect abrupt mean shift from 0 to 100"
        );
    }

    /// After a shift the window must be shorter than before it.
    #[test]
    fn test_adwin_window_shrinks_on_change() {
        let mut adwin = Adwin::new(0.01).expect("valid delta");
        adwin.set_min_window_length(5);

        feed_all(&mut adwin, 0.0, 200);
        let len_before = adwin.current_length();
        assert!(len_before > 100, "Window should have grown: {}", len_before);

        let _ = feed_until_change(&mut adwin, 50.0, 200);

        let len_after = adwin.current_length();
        assert!(
            len_after < len_before,
            "Window should shrink after drift: {} -> {}",
            len_before,
            len_after
        );
    }

    /// The reported mean follows a constant stream.
    #[test]
    fn test_adwin_mean_tracking() {
        let mut adwin = Adwin::new(0.05).expect("valid delta");

        feed_all(&mut adwin, 10.0, 100);

        let mean = adwin.current_mean();
        assert!(
            (mean - 10.0).abs() < 1.0,
            "Mean should be close to 10.0: {}",
            mean
        );
    }

    /// `reset` empties the window.
    #[test]
    fn test_adwin_reset() {
        let mut adwin = Adwin::new(0.01).expect("valid delta");
        feed_all(&mut adwin, 1.0, 50);
        assert!(adwin.current_length() > 0);

        adwin.reset();
        assert_eq!(adwin.current_length(), 0);
        assert!((adwin.current_mean()).abs() < 1e-15);
    }

    /// `delta` outside the open interval (0, 1) is rejected.
    #[test]
    fn test_adwin_invalid_delta() {
        for bad in [0.0, 1.0, -0.5] {
            assert!(Adwin::new(bad).is_err());
        }
    }

    /// Non-finite observations are rejected.
    #[test]
    fn test_adwin_nan_input() {
        let mut adwin = Adwin::new(0.01).expect("valid delta");
        for bad in [f64::NAN, f64::INFINITY] {
            assert!(adwin.add_element(bad).is_err());
        }
    }

    /// A fresh detector exposes its configuration and an empty window.
    #[test]
    fn test_adwin_accessors() {
        let adwin = Adwin::new(0.05).expect("valid delta");
        assert!((adwin.delta() - 0.05).abs() < 1e-15);
        assert_eq!(adwin.current_length(), 0);
        assert!(!adwin.detected_change());
    }
}