1use crate::gradient_boosting::{GradientBoostingRegressor, TrainedGradientBoostingRegressor};
8use scirs2_core::ndarray::{s, Array1, Array2};
9use sklears_core::{
10 error::Result as SklResult,
11 prelude::SklearsError,
12 traits::{Estimator, Fit, Predict},
13};
14use std::f64;
15
/// Configuration for time-series ensemble models.
#[derive(Debug, Clone)]
pub struct TimeSeriesEnsembleConfig {
    /// Number of consecutive samples used to build lagged features.
    pub window_size: usize,
    /// Number of base models trained on (overlapping) temporal segments.
    pub n_estimators: usize,
    /// Fraction of overlap between consecutive training segments (0.0-1.0).
    pub temporal_overlap: f64,
    /// Whether to run trend/seasonal/residual decomposition during fit.
    pub use_seasonal_decomposition: bool,
    /// Candidate seasonal period lengths considered during decomposition.
    pub seasonal_periods: Vec<usize>,
    /// Multiplicative decay applied per model index when weighting predictions.
    pub temporal_decay: f64,
    /// Strategy used to adapt to concept drift.
    pub drift_adaptation: DriftAdaptationStrategy,
    /// Cross-validation scheme appropriate for temporal data.
    pub cv_strategy: TimeSeriesCVStrategy,
    /// How per-model predictions are combined into one forecast.
    pub aggregation_method: TemporalAggregationMethod,
}
38
39impl Default for TimeSeriesEnsembleConfig {
40 fn default() -> Self {
41 Self {
42 window_size: 24,
43 n_estimators: 10,
44 temporal_overlap: 0.5,
45 use_seasonal_decomposition: true,
46 seasonal_periods: vec![7, 30, 365], temporal_decay: 0.95,
48 drift_adaptation: DriftAdaptationStrategy::SlidingWindow,
49 cv_strategy: TimeSeriesCVStrategy::TimeSeriesSplit,
50 aggregation_method: TemporalAggregationMethod::WeightedAverage,
51 }
52 }
53}
54
/// Strategy for adapting the ensemble when the data distribution drifts.
///
/// NOTE(review): only `SlidingWindow` is acted upon in the visible `fit`
/// implementation (it installs an ADWIN detector); the other variants are not
/// referenced anywhere in this chunk — confirm they are handled elsewhere.
#[derive(Debug, Clone, PartialEq)]
pub enum DriftAdaptationStrategy {
    /// Keep a sliding window of recent data (pairs with ADWIN detection).
    SlidingWindow,
    /// Exponentially down-weight old observations (name-derived; unused here).
    ExponentialForgetting,
    /// Adjust ensemble membership dynamically (name-derived; unused here).
    DynamicEnsemble,
    /// Update model weights online (name-derived; unused here).
    OnlineWeightUpdate,
    /// Adapt using seasonal structure (name-derived; unused here).
    SeasonalAdaptation,
}
69
/// Cross-validation scheme for temporal data.
///
/// NOTE(review): this setting is stored in the config but not referenced in
/// this chunk — confirm it is consumed elsewhere.
#[derive(Debug, Clone, PartialEq)]
pub enum TimeSeriesCVStrategy {
    /// Forward-chaining expanding-window splits.
    TimeSeriesSplit,
    /// Contiguous blocks held out per fold.
    BlockedCV,
    /// Forward-chaining with a purge gap between train and test.
    PurgedTimeSeriesSplit,
    /// Walk-forward evaluation.
    WalkForward,
    /// Fixed-size sliding-window evaluation.
    SlidingWindow,
}
84
/// How per-model predictions are combined (see `aggregate_predictions`).
#[derive(Debug, Clone, PartialEq)]
pub enum TemporalAggregationMethod {
    /// Unweighted mean of all model outputs.
    SimpleAverage,
    /// Mean weighted by the ensemble's temporal weights.
    WeightedAverage,
    /// Per-sample median across models.
    MedianAggregation,
    /// Exponential smoothing across models (fixed alpha = 0.3).
    ExponentialSmoothing,
    /// NOTE(review): not implemented — falls back to a simple average.
    KalmanFilter,
    /// NOTE(review): not implemented — falls back to a simple average.
    BayesianTemporal,
}
101
/// Time-series ensemble for classification tasks.
///
/// NOTE(review): mirrors `TimeSeriesEnsembleRegressor` field-for-field and no
/// trait implementations for this type are visible in this chunk — confirm it
/// is wired up elsewhere in the file.
pub struct TimeSeriesEnsembleClassifier {
    // Ensemble hyperparameters.
    config: TimeSeriesEnsembleConfig,
    // Trained base learners, one per temporal segment.
    base_models: Vec<TrainedGradientBoostingRegressor>,
    // Per-model aggregation weights.
    temporal_weights: Vec<f64>,
    // Trend/seasonal/residual decomposition, if enabled.
    seasonal_components: Option<SeasonalComponents>,
    // Optional concept-drift detector.
    drift_detector: Option<Box<dyn ConceptDriftDetector>>,
    // Whether fitting has completed.
    is_fitted: bool,
}
111
/// Time-series ensemble regressor: trains several gradient-boosting models on
/// overlapping temporal segments and aggregates their forecasts.
pub struct TimeSeriesEnsembleRegressor {
    // Ensemble hyperparameters.
    config: TimeSeriesEnsembleConfig,
    // Trained base learners, one per temporal segment.
    base_models: Vec<TrainedGradientBoostingRegressor>,
    // Per-model aggregation weights (uniform right after fit).
    temporal_weights: Vec<f64>,
    // Trend/seasonal/residual decomposition of the target, if enabled.
    seasonal_components: Option<SeasonalComponents>,
    // Optional concept-drift detector (installed for SlidingWindow strategy).
    drift_detector: Option<Box<dyn ConceptDriftDetector>>,
    // Whether fitting has completed.
    is_fitted: bool,
}
121
/// Result of additive trend/seasonal/residual decomposition of a target series.
#[derive(Debug, Clone)]
pub struct SeasonalComponents {
    /// Moving-average trend estimate, one value per sample (0.0 at the edges,
    /// which are not covered by the centered window).
    pub trend: Vec<f64>,
    /// One averaged seasonal profile per usable period (inner length = period).
    pub seasonal: Vec<Vec<f64>>,
    /// Series minus trend and all seasonal components.
    pub residual: Vec<f64>,
}
132
/// Detects concept drift (distribution change) in a stream of observations.
///
/// Implementations must be `Send + Sync` so ensembles can share them across
/// threads.
pub trait ConceptDriftDetector: Send + Sync {
    /// Consumes a batch of observations and returns `true` as soon as a drift
    /// point is found within the batch.
    fn detect_drift(&mut self, observations: &[f64]) -> bool;

    /// Clears all internal state so the detector can be reused.
    fn reset(&mut self);

    /// Returns summary statistics about drifts observed so far.
    fn get_statistics(&self) -> DriftStatistics;
}

/// Summary statistics reported by a [`ConceptDriftDetector`].
#[derive(Debug, Clone)]
pub struct DriftStatistics {
    /// Total number of drift points detected.
    pub drift_points: usize,
    /// Mean number of observations between detected drifts (0.0 when none).
    pub avg_drift_interval: f64,
    /// Heuristic confidence in the current assessment.
    pub drift_confidence: f64,
    /// Observations seen since the most recent drift.
    pub time_since_drift: usize,
}

/// ADWIN (ADaptive WINdowing) drift detector.
///
/// Maintains a window of recent observations and signals drift when two
/// adjacent sub-windows have significantly different means
/// (Bifet & Gavaldà, SDM 2007).
pub struct AdwinDriftDetector {
    /// Current window of observations, oldest first.
    window: Vec<f64>,
    /// Hard cap on the window length.
    max_window_size: usize,
    /// Confidence parameter δ: smaller values detect more conservatively.
    delta: f64,
    /// Minimum size either sub-window must reach before a cut is considered.
    min_window_size: usize,
    /// Number of drifts detected so far.
    drift_count: usize,
    /// Observation index of the most recent drift.
    last_drift_time: usize,
    /// Total number of observations consumed.
    current_time: usize,
}

impl AdwinDriftDetector {
    /// Creates a detector with confidence `delta` and the given window bounds.
    pub fn new(delta: f64, max_window_size: usize, min_window_size: usize) -> Self {
        Self {
            window: Vec::new(),
            max_window_size,
            delta,
            min_window_size,
            drift_count: 0,
            last_drift_time: 0,
            current_time: 0,
        }
    }

    /// Hoeffding-style cut threshold ε_cut for sub-windows of sizes `n0`/`n1`.
    ///
    /// Per the ADWIN paper: with 1/m = 1/n0 + 1/n1 and δ' = δ/n,
    /// ε_cut = sqrt((1/(2m)) · ln(4/δ')). The threshold *shrinks* as both
    /// sub-windows grow, so longer windows detect smaller mean shifts.
    /// (The previous implementation divided by the harmonic term instead of
    /// multiplying, making the threshold grow like sqrt(n) and suppressing
    /// detection of moderate drifts.)
    ///
    /// NOTE(review): the bound is derived for observations in [0, 1]; inputs
    /// here are not normalized, so `delta` is only an approximate confidence.
    fn cut_expression(&self, n0: f64, n1: f64, _u0: f64, _u1: f64) -> f64 {
        let n = n0 + n1;
        // Bonferroni-style correction: n possible cut points are tested.
        let delta_prime = self.delta / n;
        // 1/(2m) with 1/m = 1/n0 + 1/n1 (harmonic combination).
        let half_inv_m = 1.0 / (2.0 * n0) + 1.0 / (2.0 * n1);
        // Guard against a non-positive log for degenerate delta values.
        let log_term = (4.0 / delta_prime).ln().max(0.0);
        (half_inv_m * log_term).sqrt()
    }
}

impl ConceptDriftDetector for AdwinDriftDetector {
    fn detect_drift(&mut self, observations: &[f64]) -> bool {
        for &obs in observations {
            self.window.push(obs);
            self.current_time += 1;

            // Enforce the window cap by discarding the oldest observation.
            if self.window.len() > self.max_window_size {
                self.window.remove(0);
            }

            if self.window.len() < self.min_window_size {
                continue;
            }

            // Try every cut point; both halves must satisfy the minimum size.
            let n = self.window.len();
            for i in 1..n {
                let w0 = &self.window[0..i];
                let w1 = &self.window[i..];

                if w0.len() < self.min_window_size || w1.len() < self.min_window_size {
                    continue;
                }

                let u0: f64 = w0.iter().sum::<f64>() / w0.len() as f64;
                let u1: f64 = w1.iter().sum::<f64>() / w1.len() as f64;
                let cut_val = self.cut_expression(w0.len() as f64, w1.len() as f64, u0, u1);

                if (u0 - u1).abs() > cut_val {
                    // Drop the stale prefix and report the drift immediately.
                    self.window.drain(0..i);
                    self.drift_count += 1;
                    self.last_drift_time = self.current_time;
                    return true;
                }
            }
        }
        false
    }

    fn reset(&mut self) {
        self.window.clear();
        self.drift_count = 0;
        self.last_drift_time = 0;
        self.current_time = 0;
    }

    fn get_statistics(&self) -> DriftStatistics {
        let avg_drift_interval = if self.drift_count > 0 {
            self.current_time as f64 / self.drift_count as f64
        } else {
            0.0
        };

        DriftStatistics {
            drift_points: self.drift_count,
            avg_drift_interval,
            // Heuristic: high confidence once enough data has been seen.
            drift_confidence: if self.window.len() >= self.min_window_size {
                0.95
            } else {
                0.0
            },
            time_since_drift: self.current_time - self.last_drift_time,
        }
    }
}
260
impl TimeSeriesEnsembleConfig {
    /// Returns a fluent builder initialized with the default configuration.
    pub fn builder() -> TimeSeriesEnsembleConfigBuilder {
        TimeSeriesEnsembleConfigBuilder::default()
    }
}
266
/// Fluent builder for [`TimeSeriesEnsembleConfig`].
#[derive(Default)]
pub struct TimeSeriesEnsembleConfigBuilder {
    // Accumulated configuration, starting from `Default`.
    config: TimeSeriesEnsembleConfig,
}
271
272impl TimeSeriesEnsembleConfigBuilder {
273 pub fn window_size(mut self, window_size: usize) -> Self {
274 self.config.window_size = window_size;
275 self
276 }
277
278 pub fn n_estimators(mut self, n_estimators: usize) -> Self {
279 self.config.n_estimators = n_estimators;
280 self
281 }
282
283 pub fn temporal_overlap(mut self, temporal_overlap: f64) -> Self {
284 self.config.temporal_overlap = temporal_overlap;
285 self
286 }
287
288 pub fn seasonal_periods(mut self, periods: Vec<usize>) -> Self {
289 self.config.seasonal_periods = periods;
290 self
291 }
292
293 pub fn temporal_decay(mut self, decay: f64) -> Self {
294 self.config.temporal_decay = decay;
295 self
296 }
297
298 pub fn drift_adaptation(mut self, strategy: DriftAdaptationStrategy) -> Self {
299 self.config.drift_adaptation = strategy;
300 self
301 }
302
303 pub fn cv_strategy(mut self, strategy: TimeSeriesCVStrategy) -> Self {
304 self.config.cv_strategy = strategy;
305 self
306 }
307
308 pub fn aggregation_method(mut self, method: TemporalAggregationMethod) -> Self {
309 self.config.aggregation_method = method;
310 self
311 }
312
313 pub fn use_seasonal_decomposition(mut self, use_seasonal: bool) -> Self {
314 self.config.use_seasonal_decomposition = use_seasonal;
315 self
316 }
317
318 pub fn build(self) -> TimeSeriesEnsembleConfig {
319 self.config
320 }
321}
322
323impl TimeSeriesEnsembleRegressor {
324 pub fn new(config: TimeSeriesEnsembleConfig) -> Self {
325 Self {
326 config,
327 base_models: Vec::new(),
328 temporal_weights: Vec::new(),
329 seasonal_components: None,
330 drift_detector: None,
331 is_fitted: false,
332 }
333 }
334
335 pub fn builder() -> TimeSeriesEnsembleRegressorBuilder {
336 TimeSeriesEnsembleRegressorBuilder::new()
337 }
338
339 fn create_time_features(&self, data: &Array2<f64>) -> SklResult<Array2<f64>> {
341 let shape = data.shape();
342 let (n_samples, n_features) = (shape[0], shape[1]);
343 let window_size = self.config.window_size;
344
345 if n_samples < window_size {
346 return Err(SklearsError::InvalidInput(format!(
347 "Not enough samples ({}) for window size ({})",
348 n_samples, window_size
349 )));
350 }
351
352 let n_output_samples = n_samples - window_size + 1;
353 let n_output_features = n_features * window_size;
354
355 let mut features = Array2::zeros((n_output_samples, n_output_features));
356
357 for i in 0..n_output_samples {
358 for j in 0..window_size {
359 for k in 0..n_features {
360 features[[i, j * n_features + k]] = data[[i + j, k]];
361 }
362 }
363 }
364
365 Ok(features)
366 }
367
368 fn extract_seasonal_components(&mut self, y: &[f64]) -> SklResult<()> {
370 if !self.config.use_seasonal_decomposition {
371 return Ok(());
372 }
373
374 let n = y.len();
375 let mut trend = vec![0.0; n];
376 let mut seasonal = Vec::new();
377 let mut residual = vec![0.0; n];
378
379 let ma_window = 12.min(n / 4);
381 for i in ma_window / 2..n - ma_window / 2 {
382 let sum: f64 = y[i - ma_window / 2..i + ma_window / 2 + 1].iter().sum();
383 trend[i] = sum / (ma_window + 1) as f64;
384 }
385
386 for &period in &self.config.seasonal_periods {
388 if period > n {
389 continue;
390 }
391
392 let mut period_seasonal = vec![0.0; period];
393 let mut counts = vec![0; period];
394
395 for i in 0..n {
396 let seasonal_idx = i % period;
397 period_seasonal[seasonal_idx] += y[i] - trend[i];
398 counts[seasonal_idx] += 1;
399 }
400
401 for i in 0..period {
403 if counts[i] > 0 {
404 period_seasonal[i] /= counts[i] as f64;
405 }
406 }
407
408 seasonal.push(period_seasonal);
409 }
410
411 for i in 0..n {
413 residual[i] = y[i] - trend[i];
414 for (idx, &period) in self.config.seasonal_periods.iter().enumerate() {
415 if idx < seasonal.len() && period <= n {
416 residual[i] -= seasonal[idx][i % period];
417 }
418 }
419 }
420
421 self.seasonal_components = Some(SeasonalComponents {
422 trend,
423 seasonal,
424 residual,
425 });
426
427 Ok(())
428 }
429
430 fn update_temporal_weights(&mut self, recent_errors: &[f64]) {
432 let n_models = self.base_models.len();
433 if n_models == 0 {
434 return;
435 }
436
437 self.temporal_weights = vec![1.0 / n_models as f64; n_models];
438
439 if recent_errors.len() != n_models {
440 return;
441 }
442
443 let total_error: f64 = recent_errors.iter().sum();
445 if total_error > 0.0 {
446 for i in 0..n_models {
447 let inv_error = 1.0 / (recent_errors[i] + 1e-8);
448 let temporal_factor = self.config.temporal_decay.powi(i as i32);
449 self.temporal_weights[i] = inv_error * temporal_factor;
450 }
451
452 let sum_weights: f64 = self.temporal_weights.iter().sum();
454 if sum_weights > 0.0 {
455 for weight in &mut self.temporal_weights {
456 *weight /= sum_weights;
457 }
458 }
459 }
460 }
461
462 fn aggregate_predictions(&self, predictions: &[Vec<f64>]) -> SklResult<Vec<f64>> {
464 if predictions.is_empty() {
465 return Err(SklearsError::InvalidInput(
466 "No predictions to aggregate".to_string(),
467 ));
468 }
469
470 let n_samples = predictions[0].len();
471 let mut result = vec![0.0; n_samples];
472
473 match self.config.aggregation_method {
474 TemporalAggregationMethod::SimpleAverage => {
475 for pred in predictions {
476 for (i, &p) in pred.iter().enumerate() {
477 result[i] += p;
478 }
479 }
480 for r in &mut result {
481 *r /= predictions.len() as f64;
482 }
483 }
484 TemporalAggregationMethod::WeightedAverage => {
485 for (j, pred) in predictions.iter().enumerate() {
486 let weight = if j < self.temporal_weights.len() {
487 self.temporal_weights[j]
488 } else {
489 1.0 / predictions.len() as f64
490 };
491 for (i, &p) in pred.iter().enumerate() {
492 result[i] += p * weight;
493 }
494 }
495 }
496 TemporalAggregationMethod::MedianAggregation => {
497 for i in 0..n_samples {
498 let mut values: Vec<f64> = predictions.iter().map(|p| p[i]).collect();
499 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
500 result[i] = if values.len() % 2 == 0 {
501 (values[values.len() / 2 - 1] + values[values.len() / 2]) / 2.0
502 } else {
503 values[values.len() / 2]
504 };
505 }
506 }
507 TemporalAggregationMethod::ExponentialSmoothing => {
508 let alpha = 0.3; for i in 0..n_samples {
510 result[i] = predictions[0][i];
511 for j in 1..predictions.len() {
512 result[i] = alpha * predictions[j][i] + (1.0 - alpha) * result[i];
513 }
514 }
515 }
516 _ => {
517 for pred in predictions {
519 for (i, &p) in pred.iter().enumerate() {
520 result[i] += p;
521 }
522 }
523 for r in &mut result {
524 *r /= predictions.len() as f64;
525 }
526 }
527 }
528
529 Ok(result)
530 }
531}
532
/// Fluent builder for [`TimeSeriesEnsembleRegressor`].
pub struct TimeSeriesEnsembleRegressorBuilder {
    // Accumulated configuration, starting from `Default`.
    config: TimeSeriesEnsembleConfig,
}
536
537impl Default for TimeSeriesEnsembleRegressorBuilder {
538 fn default() -> Self {
539 Self::new()
540 }
541}
542
543impl TimeSeriesEnsembleRegressorBuilder {
544 pub fn new() -> Self {
545 Self {
546 config: TimeSeriesEnsembleConfig::default(),
547 }
548 }
549
550 pub fn config(mut self, config: TimeSeriesEnsembleConfig) -> Self {
551 self.config = config;
552 self
553 }
554
555 pub fn window_size(mut self, window_size: usize) -> Self {
556 self.config.window_size = window_size;
557 self
558 }
559
560 pub fn n_estimators(mut self, n_estimators: usize) -> Self {
561 self.config.n_estimators = n_estimators;
562 self
563 }
564
565 pub fn temporal_decay(mut self, decay: f64) -> Self {
566 self.config.temporal_decay = decay;
567 self
568 }
569
570 pub fn drift_adaptation(mut self, strategy: DriftAdaptationStrategy) -> Self {
571 self.config.drift_adaptation = strategy;
572 self
573 }
574
575 pub fn build(self) -> TimeSeriesEnsembleRegressor {
576 TimeSeriesEnsembleRegressor::new(self.config)
577 }
578}
579
impl Estimator for TimeSeriesEnsembleRegressor {
    type Config = TimeSeriesEnsembleConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Returns the configuration this estimator was constructed with.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
589
impl Fit<Array2<f64>, Vec<f64>> for TimeSeriesEnsembleRegressor {
    type Fitted = Self;

    /// Fits the ensemble: decomposes `y`, builds lagged features from `X`,
    /// then trains `n_estimators` gradient-boosting models on overlapping
    /// temporal segments and initializes uniform aggregation weights.
    ///
    /// # Errors
    /// Returns `ShapeMismatch` when `X` and `y` disagree on sample count, and
    /// propagates errors from feature creation or base-model training.
    #[allow(non_snake_case)]
    fn fit(mut self, X: &Array2<f64>, y: &Vec<f64>) -> SklResult<Self::Fitted> {
        if X.shape()[0] != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("X and y must have same number of samples: {}", y.len()),
                actual: format!(
                    "X has {} samples but y has {} samples",
                    X.shape()[0],
                    y.len()
                ),
            });
        }

        // Optional trend/seasonal/residual decomposition of the target.
        self.extract_seasonal_components(y)?;

        // Lagged features: one row per window position.
        let time_features = self.create_time_features(X)?;
        let n_time_samples = time_features.shape()[0];

        // Targets aligned with the window end points.
        // NOTE(review): underflows (panics) if window_size == 0 — confirm the
        // config is validated before reaching fit.
        let y_time = &y[self.config.window_size - 1..];

        // Drift detection is currently only wired up for SlidingWindow;
        // detector parameters (delta, window bounds) are hard-coded here.
        if matches!(
            self.config.drift_adaptation,
            DriftAdaptationStrategy::SlidingWindow
        ) {
            self.drift_detector = Some(Box::new(AdwinDriftDetector::new(0.002, 1000, 30)));
        }

        // Train each base model on a shifted, overlapping segment.
        // NOTE(review): temporal_overlap > 1.0 makes the subtraction below
        // underflow, and step_size can be 0 for short series (duplicated
        // segments) — confirm both are acceptable or validated upstream.
        self.base_models.clear();
        let overlap_size = (n_time_samples as f64 * self.config.temporal_overlap) as usize;
        let step_size = (n_time_samples - overlap_size) / self.config.n_estimators.max(1);

        for i in 0..self.config.n_estimators {
            let start_idx = i * step_size;
            let end_idx = (start_idx + overlap_size + step_size).min(n_time_samples);

            // Stop once segments would be empty.
            if end_idx <= start_idx {
                break;
            }

            let X_subset = time_features.slice(s![start_idx..end_idx, ..]);
            let y_subset = &y_time[start_idx..end_idx];

            let X_subset_owned = X_subset.to_owned();
            let y_subset_owned = Array1::from_vec(y_subset.to_vec());

            // Fixed base-learner hyperparameters for every segment.
            let model = GradientBoostingRegressor::builder()
                .n_estimators(50)
                .learning_rate(0.1)
                .max_depth(6)
                .build()
                .fit(&X_subset_owned, &y_subset_owned)?;

            self.base_models.push(model);
        }

        // Start from uniform weights; refined later via update_temporal_weights.
        // NOTE(review): if no models were trained this is an empty vec but
        // is_fitted is still set — confirm downstream handling.
        self.temporal_weights = vec![1.0 / self.base_models.len() as f64; self.base_models.len()];
        self.is_fitted = true;

        Ok(self)
    }
}
662
663impl Predict<Array2<f64>, Vec<f64>> for TimeSeriesEnsembleRegressor {
664 fn predict(&self, X: &Array2<f64>) -> SklResult<Vec<f64>> {
665 if !self.is_fitted {
666 return Err(SklearsError::NotFitted {
667 operation: "prediction".to_string(),
668 });
669 }
670
671 let time_features = self.create_time_features(X)?;
672 let mut predictions = Vec::new();
673
674 for model in &self.base_models {
676 let pred = model.predict(&time_features)?;
677 let pred_vec: Vec<f64> = pred.to_vec();
678 predictions.push(pred_vec);
679 }
680
681 self.aggregate_predictions(&predictions)
683 }
684}
685
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builder should pass values through to the config unchanged.
    #[test]
    fn test_time_series_ensemble_config() {
        let config = TimeSeriesEnsembleConfig::builder()
            .window_size(12)
            .n_estimators(5)
            .temporal_decay(0.9)
            .build();

        assert_eq!(config.window_size, 12);
        assert_eq!(config.n_estimators, 5);
        assert_eq!(config.temporal_decay, 0.9);
    }

    /// ADWIN must stay quiet on a constant stream and fire on a large shift.
    #[test]
    fn test_adwin_drift_detector() {
        let mut detector = AdwinDriftDetector::new(0.01, 1000, 30);

        // Constant stream: no drift possible.
        let stable_data = vec![1.0; 50];
        assert!(!detector.detect_drift(&stable_data));

        // Large mean shift (1.0 -> 100.0) must be detected.
        let mut drift_data = vec![1.0; 30];
        drift_data.extend(vec![100.0; 30]);
        let drift_detected = detector.detect_drift(&drift_data);
        assert!(
            drift_detected,
            "ADWIN should detect significant drift from 1.0 to 100.0"
        );
    }

    /// Construction stores the config; fixture shapes are checked so the
    /// previously-unused `X`/`y` locals no longer trigger warnings.
    #[test]
    #[allow(non_snake_case)]
    fn test_time_series_ensemble_basic() {
        let config = TimeSeriesEnsembleConfig::builder()
            .window_size(3)
            .n_estimators(3)
            .build();

        let ensemble = TimeSeriesEnsembleRegressor::new(config);

        let data = vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0];
        let X = Array2::from_shape_vec((5, 2), data).unwrap();
        let y: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        // Exercise the fixture instead of leaving it unused.
        assert_eq!(X.shape(), &[5, 2]);
        assert_eq!(y.len(), 5);

        assert_eq!(ensemble.config.window_size, 3);
        assert_eq!(ensemble.config.n_estimators, 3);
    }

    /// Decomposition of a weekly sinusoid yields correctly-sized components.
    #[test]
    fn test_seasonal_decomposition() {
        let config = TimeSeriesEnsembleConfig::builder()
            .window_size(4)
            .use_seasonal_decomposition(true)
            .seasonal_periods(vec![7])
            .build();

        let mut ensemble = TimeSeriesEnsembleRegressor::new(config);

        // Four full cycles of a period-7 sinusoid.
        let y: Vec<f64> = (0..28)
            .map(|i| (i as f64 / 7.0 * 2.0 * std::f64::consts::PI).sin())
            .collect();

        ensemble.extract_seasonal_components(&y).unwrap();
        assert!(ensemble.seasonal_components.is_some());

        let components = ensemble.seasonal_components.unwrap();
        assert_eq!(components.trend.len(), 28);
        assert_eq!(components.seasonal.len(), 1);
        assert_eq!(components.seasonal[0].len(), 7);
        assert_eq!(components.residual.len(), 28);
    }
}