1use scirs2_core::ndarray::distributions::Distribution;
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::{prelude::*, Rng};
10use sklears_core::error::{Result, SklearsError};
11use sklears_core::traits::{Estimator, Fit, Predict, Trained};
12use sklears_core::types::Float;
13use std::collections::{HashMap, VecDeque};
14
/// Concept-drift detector that [`OnlineStrategy::OnlineMean`] can run after each update.
///
/// The enum carries no data, so it is cheap to copy and can be used as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftDetectionMethod {
    /// ADWIN-style check comparing the most recent samples against the overall mean.
    ADWIN,
    /// Page-Hinkley cumulative-deviation test.
    PageHinkley,
    /// Early Drift Detection Method.
    EDDM,
    /// Z-score test of the newest sample against the running mean and variance.
    StatisticalTest,
}
27
/// How an online estimator retains recently seen targets.
///
/// In the estimators defined in this file only [`WindowStrategy::FixedWindow`] changes
/// behavior; the remaining variants are accepted and stored but not acted upon.
#[derive(Clone, Debug, PartialEq)]
pub enum WindowStrategy {
    /// Keep at most this many of the most recent samples.
    FixedWindow(usize),
    /// Exponential decay with the given rate.
    ExponentialDecay(f64),
    /// Window whose size adapts to the data.
    AdaptiveWindow,
    /// Forgetting-factor weighting with the given factor.
    ForgettingFactor(f64),
}
40
41#[derive(Debug, Clone, PartialEq)]
43pub enum OnlineStrategy {
44 OnlineMean {
46 drift_detection: Option<DriftDetectionMethod>,
47 },
48 EWMA { alpha: f64 },
50 AdaptiveWindow {
52 max_window_size: usize,
53 drift_threshold: f64,
54 },
55 ForgettingFactor { lambda: f64 },
57 OnlineQuantile { quantile: f64, learning_rate: f64 },
59}
60
61#[derive(Debug, Clone)]
63pub struct OnlineDummyRegressor<State = sklears_core::traits::Untrained> {
64 strategy: OnlineStrategy,
65 window_strategy: WindowStrategy,
66 random_state: Option<u64>,
67 running_mean: f64,
69 running_variance: f64,
70 sample_count: usize,
71 ewma_mean: f64,
72 forgetting_weight_sum: f64,
73 quantile_estimate: f64,
74 window_data: VecDeque<f64>,
76 drift_detector_state: DriftDetectorState,
78 _state: std::marker::PhantomData<State>,
80}
81
/// Mutable bookkeeping shared by the drift detectors of the online regressor.
#[derive(Clone, Debug)]
struct DriftDetectorState {
    /// ADWIN: `(value, count)` entries for recent samples, oldest first.
    adwin_buckets: VecDeque<(f64, usize)>,
    /// ADWIN: sum of every sample fed to the detector since the last reset.
    adwin_total: f64,
    /// ADWIN: number of samples fed to the detector since the last reset.
    adwin_count: usize,
    /// Page-Hinkley: cumulative sum of deviations.
    ph_sum: f64,
    /// Page-Hinkley: smallest cumulative sum observed so far.
    ph_min: f64,
    /// Page-Hinkley: alarm threshold on `ph_sum - ph_min`.
    ph_threshold: f64,
    /// EDDM: recent error indicators.
    eddm_errors: VecDeque<bool>,
    /// EDDM: distances between consecutive errors.
    eddm_distances: VecDeque<usize>,
    /// EDDM: mean distance between errors.
    eddm_mean_distance: f64,
    /// EDDM: standard deviation of the distance between errors.
    eddm_std_distance: f64,
}
99
100impl Default for DriftDetectorState {
101 fn default() -> Self {
102 Self {
103 adwin_buckets: VecDeque::new(),
104 adwin_total: 0.0,
105 adwin_count: 0,
106 ph_sum: 0.0,
107 ph_min: 0.0,
108 ph_threshold: 50.0,
109 eddm_errors: VecDeque::new(),
110 eddm_distances: VecDeque::new(),
111 eddm_mean_distance: 0.0,
112 eddm_std_distance: 0.0,
113 }
114 }
115}
116
117impl<State> OnlineDummyRegressor<State> {
118 pub fn new(strategy: OnlineStrategy) -> Self {
120 Self {
121 strategy,
122 window_strategy: WindowStrategy::FixedWindow(1000),
123 random_state: None,
124 running_mean: 0.0,
125 running_variance: 0.0,
126 sample_count: 0,
127 ewma_mean: 0.0,
128 forgetting_weight_sum: 0.0,
129 quantile_estimate: 0.0,
130 window_data: VecDeque::new(),
131 drift_detector_state: DriftDetectorState::default(),
132 _state: std::marker::PhantomData,
133 }
134 }
135
136 pub fn with_window_strategy(mut self, window_strategy: WindowStrategy) -> Self {
138 self.window_strategy = window_strategy;
139 self
140 }
141
142 pub fn with_random_state(mut self, random_state: u64) -> Self {
144 self.random_state = Some(random_state);
145 self
146 }
147
148 pub fn partial_fit(&mut self, target: f64) -> Result<()> {
150 self.sample_count += 1;
151
152 match &self.strategy {
154 OnlineStrategy::OnlineMean { drift_detection } => {
155 let drift_detection = drift_detection.clone();
156 self.update_online_mean(target);
157 if let Some(detection_method) = &drift_detection {
158 if self.detect_drift(target, detection_method)? {
159 self.handle_drift();
160 }
161 }
162 }
163 OnlineStrategy::EWMA { alpha } => {
164 self.update_ewma(target, *alpha);
165 }
166 OnlineStrategy::AdaptiveWindow {
167 max_window_size,
168 drift_threshold,
169 } => {
170 self.update_adaptive_window(target, *max_window_size, *drift_threshold)?;
171 }
172 OnlineStrategy::ForgettingFactor { lambda } => {
173 self.update_forgetting_factor(target, *lambda);
174 }
175 OnlineStrategy::OnlineQuantile {
176 quantile,
177 learning_rate,
178 } => {
179 self.update_online_quantile(target, *quantile, *learning_rate);
180 }
181 }
182
183 match &self.window_strategy {
185 WindowStrategy::FixedWindow(size) => {
186 self.window_data.push_back(target);
187 if self.window_data.len() > *size {
188 self.window_data.pop_front();
189 }
190 }
191 _ => {} }
193
194 Ok(())
195 }
196
197 pub fn predict_single(&self) -> f64 {
199 match &self.strategy {
200 OnlineStrategy::OnlineMean { .. } => self.running_mean,
201 OnlineStrategy::EWMA { .. } => self.ewma_mean,
202 OnlineStrategy::AdaptiveWindow { .. } => {
203 if self.window_data.is_empty() {
204 0.0
205 } else {
206 self.window_data.iter().sum::<f64>() / self.window_data.len() as f64
207 }
208 }
209 OnlineStrategy::ForgettingFactor { .. } => self.running_mean,
210 OnlineStrategy::OnlineQuantile { .. } => self.quantile_estimate,
211 }
212 }
213
214 pub fn sample_count(&self) -> usize {
216 self.sample_count
217 }
218
219 pub fn get_statistics(&self) -> (f64, f64) {
221 (self.running_mean, self.running_variance)
222 }
223
224 pub fn drift_detected(&self) -> bool {
226 false }
229
230 fn update_online_mean(&mut self, target: f64) {
231 let delta = target - self.running_mean;
232 self.running_mean += delta / self.sample_count as f64;
233
234 if self.sample_count > 1 {
235 let delta2 = target - self.running_mean;
236 self.running_variance +=
237 (delta * delta2 - self.running_variance) / (self.sample_count - 1) as f64;
238 }
239 }
240
241 fn update_ewma(&mut self, target: f64, alpha: f64) {
242 if self.sample_count == 1 {
243 self.ewma_mean = target;
244 } else {
245 self.ewma_mean = alpha * target + (1.0 - alpha) * self.ewma_mean;
246 }
247 }
248
249 fn update_adaptive_window(
250 &mut self,
251 target: f64,
252 max_size: usize,
253 drift_threshold: f64,
254 ) -> Result<()> {
255 self.window_data.push_back(target);
256
257 if self.window_data.len() > 10 {
259 let recent_mean: f64 = self.window_data.iter().rev().take(5).sum::<f64>() / 5.0;
260 let overall_mean: f64 =
261 self.window_data.iter().sum::<f64>() / self.window_data.len() as f64;
262
263 if (recent_mean - overall_mean).abs() > drift_threshold {
264 let new_size = std::cmp::max(self.window_data.len() / 2, 10);
266 while self.window_data.len() > new_size {
267 self.window_data.pop_front();
268 }
269 }
270 }
271
272 if self.window_data.len() > max_size {
273 self.window_data.pop_front();
274 }
275
276 Ok(())
277 }
278
279 fn update_forgetting_factor(&mut self, target: f64, lambda: f64) {
280 self.forgetting_weight_sum = lambda * self.forgetting_weight_sum + 1.0;
281 self.running_mean = (lambda * self.running_mean * (self.forgetting_weight_sum - 1.0)
282 + target)
283 / self.forgetting_weight_sum;
284 }
285
286 fn update_online_quantile(&mut self, target: f64, quantile: f64, learning_rate: f64) {
287 if self.sample_count == 1 {
288 self.quantile_estimate = target;
289 } else {
290 let error = if target > self.quantile_estimate {
291 quantile
292 } else {
293 quantile - 1.0
294 };
295 self.quantile_estimate += learning_rate * error;
296 }
297 }
298
299 fn detect_drift(&mut self, target: f64, method: &DriftDetectionMethod) -> Result<bool> {
300 match method {
301 DriftDetectionMethod::ADWIN => self.adwin_drift_detection(target),
302 DriftDetectionMethod::PageHinkley => self.page_hinkley_drift_detection(target),
303 DriftDetectionMethod::EDDM => self.eddm_drift_detection(target),
304 DriftDetectionMethod::StatisticalTest => self.statistical_drift_detection(target),
305 }
306 }
307
308 fn adwin_drift_detection(&mut self, target: f64) -> Result<bool> {
309 self.drift_detector_state.adwin_total += target;
311 self.drift_detector_state.adwin_count += 1;
312 self.drift_detector_state
313 .adwin_buckets
314 .push_back((target, 1));
315
316 if self.drift_detector_state.adwin_buckets.len() > 5 {
318 let recent_sum: f64 = self
319 .drift_detector_state
320 .adwin_buckets
321 .iter()
322 .rev()
323 .take(3)
324 .map(|(v, _)| v)
325 .sum();
326 let recent_mean = recent_sum / 3.0;
327 let overall_mean = self.drift_detector_state.adwin_total
328 / self.drift_detector_state.adwin_count as f64;
329
330 Ok((recent_mean - overall_mean).abs() > 2.0) } else {
332 Ok(false)
333 }
334 }
335
336 fn page_hinkley_drift_detection(&mut self, target: f64) -> Result<bool> {
337 let mean_estimate = self.running_mean;
338 self.drift_detector_state.ph_sum += target - mean_estimate - 0.5; self.drift_detector_state.ph_min = self
340 .drift_detector_state
341 .ph_min
342 .min(self.drift_detector_state.ph_sum);
343
344 let test_statistic = self.drift_detector_state.ph_sum - self.drift_detector_state.ph_min;
345 Ok(test_statistic > self.drift_detector_state.ph_threshold)
346 }
347
348 fn eddm_drift_detection(&mut self, _target: f64) -> Result<bool> {
349 Ok(false)
351 }
352
353 fn statistical_drift_detection(&mut self, target: f64) -> Result<bool> {
354 if self.sample_count < 30 {
355 return Ok(false);
356 }
357
358 let z_score = (target - self.running_mean) / self.running_variance.sqrt();
360 Ok(z_score.abs() > 3.0) }
362
363 fn handle_drift(&mut self) {
364 self.running_mean = 0.0;
366 self.running_variance = 0.0;
367 self.sample_count = 0;
368 self.ewma_mean = 0.0;
369 self.forgetting_weight_sum = 0.0;
370 self.window_data.clear();
371 self.drift_detector_state = DriftDetectorState::default();
372 }
373}
374
375#[derive(Debug, Clone)]
377pub struct OnlineDummyClassifier<State = sklears_core::traits::Untrained> {
378 strategy: OnlineClassificationStrategy,
379 class_counts: HashMap<i32, usize>,
380 total_samples: usize,
381 window_strategy: WindowStrategy,
382 class_window: VecDeque<i32>,
383 random_state: Option<u64>,
384 _state: std::marker::PhantomData<State>,
385}
386
/// Prediction rule used by [`OnlineDummyClassifier`].
#[derive(Clone, Debug, PartialEq)]
pub enum OnlineClassificationStrategy {
    /// Predict the class with the highest count.
    OnlineMostFrequent,
    /// Exponentially weighted class frequencies.
    ExponentiallyWeighted {
        /// Weight given to the newest sample.
        alpha: f64,
    },
    /// Class distribution tracked over a window.
    AdaptiveDistribution {
        /// Number of samples in the window.
        window_size: usize,
    },
    /// Pick uniformly among the classes seen so far.
    UniformWithForgetting {
        /// Forgetting factor.
        lambda: f64,
    },
}
399
400impl<State> OnlineDummyClassifier<State> {
401 pub fn new(strategy: OnlineClassificationStrategy) -> Self {
403 Self {
404 strategy,
405 class_counts: HashMap::new(),
406 total_samples: 0,
407 window_strategy: WindowStrategy::FixedWindow(1000),
408 class_window: VecDeque::new(),
409 random_state: None,
410 _state: std::marker::PhantomData,
411 }
412 }
413
414 pub fn with_window_strategy(mut self, window_strategy: WindowStrategy) -> Self {
416 self.window_strategy = window_strategy;
417 self
418 }
419
420 pub fn with_random_state(mut self, random_state: u64) -> Self {
422 self.random_state = Some(random_state);
423 self
424 }
425
426 pub fn partial_fit(&mut self, target: i32) {
428 self.total_samples += 1;
429 *self.class_counts.entry(target).or_insert(0) += 1;
430
431 match &self.window_strategy {
432 WindowStrategy::FixedWindow(size) => {
433 self.class_window.push_back(target);
434 if self.class_window.len() > *size {
435 if let Some(old_class) = self.class_window.pop_front() {
436 if let Some(count) = self.class_counts.get_mut(&old_class) {
437 *count = count.saturating_sub(1);
438 if *count == 0 {
439 self.class_counts.remove(&old_class);
440 }
441 }
442 self.total_samples = self.total_samples.saturating_sub(1);
443 }
444 }
445 }
446 _ => {} }
448 }
449
450 pub fn predict_single(&self) -> Option<i32> {
452 match &self.strategy {
453 OnlineClassificationStrategy::OnlineMostFrequent => self
454 .class_counts
455 .iter()
456 .max_by_key(|(_, &count)| count)
457 .map(|(&class, _)| class),
458 OnlineClassificationStrategy::ExponentiallyWeighted { .. } => {
459 self.class_counts
461 .iter()
462 .max_by_key(|(_, &count)| count)
463 .map(|(&class, _)| class)
464 }
465 OnlineClassificationStrategy::AdaptiveDistribution { .. } => self
466 .class_counts
467 .iter()
468 .max_by_key(|(_, &count)| count)
469 .map(|(&class, _)| class),
470 OnlineClassificationStrategy::UniformWithForgetting { .. } => {
471 if self.class_counts.is_empty() {
472 None
473 } else {
474 let classes: Vec<i32> = self.class_counts.keys().cloned().collect();
475 let mut rng = if let Some(seed) = self.random_state {
476 StdRng::seed_from_u64(seed)
477 } else {
478 StdRng::seed_from_u64(0)
479 };
480 Some(classes[rng.gen_range(0..classes.len())])
481 }
482 }
483 }
484 }
485
486 pub fn get_class_distribution(&self) -> HashMap<i32, f64> {
488 if self.total_samples == 0 {
489 return HashMap::new();
490 }
491
492 self.class_counts
493 .iter()
494 .map(|(&class, &count)| (class, count as f64 / self.total_samples as f64))
495 .collect()
496 }
497
498 pub fn sample_count(&self) -> usize {
500 self.total_samples
501 }
502}
503
504impl Estimator for OnlineDummyRegressor {
505 type Config = ();
506 type Error = SklearsError;
507 type Float = Float;
508
509 fn config(&self) -> &Self::Config {
510 &()
511 }
512}
513
514impl Fit<Array2<Float>, Array1<Float>> for OnlineDummyRegressor {
515 type Fitted = OnlineDummyRegressor<Trained>;
516
517 fn fit(self, _x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
518 let mut regressor = self;
519
520 for &target in y.iter() {
521 regressor.partial_fit(target)?;
522 }
523
524 Ok(OnlineDummyRegressor {
525 strategy: regressor.strategy,
526 window_strategy: regressor.window_strategy,
527 random_state: regressor.random_state,
528 running_mean: regressor.running_mean,
529 running_variance: regressor.running_variance,
530 sample_count: regressor.sample_count,
531 ewma_mean: regressor.ewma_mean,
532 forgetting_weight_sum: regressor.forgetting_weight_sum,
533 quantile_estimate: regressor.quantile_estimate,
534 window_data: regressor.window_data,
535 drift_detector_state: regressor.drift_detector_state,
536 _state: std::marker::PhantomData::<Trained>,
537 })
538 }
539}
540
541impl Predict<Array2<Float>, Array1<Float>> for OnlineDummyRegressor<Trained> {
542 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
543 let n_samples = x.nrows();
544 let prediction = self.predict_single();
545 Ok(Array1::from_elem(n_samples, prediction))
546 }
547}
548
549impl Estimator for OnlineDummyClassifier {
550 type Config = ();
551 type Error = SklearsError;
552 type Float = Float;
553
554 fn config(&self) -> &Self::Config {
555 &()
556 }
557}
558
559impl Fit<Array2<Float>, Array1<i32>> for OnlineDummyClassifier {
560 type Fitted = OnlineDummyClassifier<Trained>;
561
562 fn fit(self, _x: &Array2<Float>, y: &Array1<i32>) -> Result<Self::Fitted> {
563 let mut classifier = self;
564
565 for &target in y.iter() {
566 classifier.partial_fit(target);
567 }
568
569 Ok(OnlineDummyClassifier {
570 strategy: classifier.strategy,
571 class_counts: classifier.class_counts,
572 total_samples: classifier.total_samples,
573 window_strategy: classifier.window_strategy,
574 class_window: classifier.class_window,
575 random_state: classifier.random_state,
576 _state: std::marker::PhantomData::<Trained>,
577 })
578 }
579}
580
581impl Predict<Array2<Float>, Array1<i32>> for OnlineDummyClassifier<Trained> {
582 fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
583 let n_samples = x.nrows();
584 let prediction = self.predict_single().unwrap_or(0);
585 Ok(Array1::from_elem(n_samples, prediction))
586 }
587}
588
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The online mean must equal the arithmetic mean after every sample.
    #[test]
    fn test_online_dummy_regressor_mean() {
        let strategy = OnlineStrategy::OnlineMean {
            drift_detection: None,
        };
        let mut model: OnlineDummyRegressor = OnlineDummyRegressor::new(strategy);

        // (sample, expected prediction after the sample)
        let steps: [(f64, f64); 3] = [(1.0, 1.0), (3.0, 2.0), (2.0, 2.0)];
        for (sample, expected) in steps {
            model.partial_fit(sample).unwrap();
            assert_abs_diff_eq!(model.predict_single(), expected, epsilon = 1e-10);
        }
    }

    /// With `alpha = 0.5` each update averages the new sample with the old estimate.
    #[test]
    fn test_online_dummy_regressor_ewma() {
        let strategy = OnlineStrategy::EWMA { alpha: 0.5 };
        let mut model: OnlineDummyRegressor = OnlineDummyRegressor::new(strategy);

        // (sample, expected prediction after the sample)
        let steps: [(f64, f64); 3] = [(1.0, 1.0), (3.0, 2.0), (1.0, 1.5)];
        for (sample, expected) in steps {
            model.partial_fit(sample).unwrap();
            assert_abs_diff_eq!(model.predict_single(), expected, epsilon = 1e-10);
        }
    }

    /// The median estimate must stay strictly inside the observed range.
    #[test]
    fn test_online_dummy_regressor_quantile() {
        let strategy = OnlineStrategy::OnlineQuantile {
            quantile: 0.5,
            learning_rate: 0.1,
        };
        let mut model: OnlineDummyRegressor = OnlineDummyRegressor::new(strategy);

        for sample in [1.0, 2.0, 3.0, 4.0, 5.0] {
            model.partial_fit(sample).unwrap();
        }

        let estimate = model.predict_single();
        assert!(1.0 < estimate && estimate < 5.0);
    }

    /// The most frequent class is predicted and the distribution reflects the counts.
    #[test]
    fn test_online_dummy_classifier() {
        let strategy = OnlineClassificationStrategy::OnlineMostFrequent;
        let mut model: OnlineDummyClassifier = OnlineDummyClassifier::new(strategy);

        model.partial_fit(0);
        assert_eq!(model.predict_single(), Some(0));

        for _ in 0..2 {
            model.partial_fit(1);
        }
        assert_eq!(model.predict_single(), Some(1));

        let shares = model.get_class_distribution();
        assert_abs_diff_eq!(shares[&0], 1.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(shares[&1], 2.0 / 3.0, epsilon = 1e-10);
    }

    /// The adaptive window must remain bounded after a jump in the stream.
    #[test]
    fn test_adaptive_window() {
        let strategy = OnlineStrategy::AdaptiveWindow {
            max_window_size: 5,
            drift_threshold: 1.0,
        };
        let mut model: OnlineDummyRegressor = OnlineDummyRegressor::new(strategy);

        // Five stable samples followed by one outlier.
        for sample in [1.0, 1.1, 0.9, 1.0, 1.1, 5.0] {
            model.partial_fit(sample).unwrap();
        }

        assert!(model.window_data.len() <= 10);
    }

    /// A larger second sample pulls the estimate up without reaching it.
    #[test]
    fn test_forgetting_factor() {
        let strategy = OnlineStrategy::ForgettingFactor { lambda: 0.9 };
        let mut model: OnlineDummyRegressor = OnlineDummyRegressor::new(strategy);

        model.partial_fit(1.0).unwrap();
        let before = model.predict_single();

        model.partial_fit(10.0).unwrap();
        let after = model.predict_single();

        assert!(after > before);
        assert!(after < 10.0);
    }

    /// A level shift in the stream must leave the model in a usable state.
    #[test]
    fn test_drift_detection() {
        let strategy = OnlineStrategy::OnlineMean {
            drift_detection: Some(DriftDetectionMethod::ADWIN),
        };
        let mut model: OnlineDummyRegressor = OnlineDummyRegressor::new(strategy);

        // Ten samples at level 1 followed by five at level 5.
        let stream = std::iter::repeat(1.0)
            .take(10)
            .chain(std::iter::repeat(5.0).take(5));
        for sample in stream {
            model.partial_fit(sample).unwrap();
        }

        assert!(model.sample_count() > 0);
    }

    /// A fixed window of three keeps exactly the three newest samples, oldest first.
    #[test]
    fn test_window_strategy_fixed() {
        let strategy = OnlineStrategy::OnlineMean {
            drift_detection: None,
        };
        let mut model: OnlineDummyRegressor = OnlineDummyRegressor::new(strategy)
            .with_window_strategy(WindowStrategy::FixedWindow(3));

        for sample in [1.0, 2.0, 3.0, 4.0, 5.0] {
            model.partial_fit(sample).unwrap();
        }

        let retained: Vec<f64> = model.window_data.iter().copied().collect();
        assert_eq!(retained, vec![3.0, 4.0, 5.0]);
    }

    /// `fit` followed by `predict` yields the target mean for every row.
    #[test]
    fn test_online_estimator_trait() {
        let features =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let targets = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let untrained = OnlineDummyRegressor::new(OnlineStrategy::OnlineMean {
            drift_detection: None,
        });
        let trained = untrained.fit(&features, &targets).unwrap();
        let predictions = trained.predict(&features).unwrap();

        assert_eq!(predictions.len(), 4);
        assert_abs_diff_eq!(predictions[0], 2.5, epsilon = 1e-10);
    }
}