sklears_model_selection/
incremental_evaluation.rs

1//! Incremental Cross-Validation for Streaming/Online Learning
2//!
3//! This module provides specialized cross-validation and evaluation methods for incremental
4//! and online learning scenarios where data arrives in streams and models are updated continuously.
5//! It includes methods for handling concept drift, adaptive evaluation windows, and performance
6//! monitoring in non-stationary environments.
7
8use scirs2_core::ndarray::Array1;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::SeedableRng;
11use sklears_core::types::Float;
12use std::collections::{HashMap, VecDeque};
13use std::time::{Duration, Instant};
14
/// Incremental evaluation strategies.
///
/// `IncrementalEvaluator::update` dispatches every sample to the `evaluate_*`
/// method matching the configured variant. The evaluators score binary
/// correctness: a prediction counts as correct when it falls on the same side
/// of 0.5 as the label.
#[derive(Debug, Clone)]
pub enum IncrementalEvaluationStrategy {
    /// Sliding window evaluation
    SlidingWindow {
        /// Number of most recent samples scored in each evaluation.
        window_size: usize,
        /// Evaluate only when the processed-sample count is a multiple of this
        /// value (must be non-zero).
        step_size: usize,
        /// Not currently read by `evaluate_sliding_window`.
        overlap_ratio: Float,
    },
    /// Prequential evaluation (test-then-train)
    Prequential {
        /// Weight given to the newest sample's correctness in the moving average.
        adaptation_rate: Float,
        /// Not currently read by `evaluate_prequential`.
        forgetting_factor: Float,
    },
    /// Holdout evaluation with periodic updates
    HoldoutEvaluation {
        /// Fraction of the buffered window (the most recent samples) that is scored.
        holdout_ratio: Float,
        /// Evaluate only when the processed-sample count is a multiple of this
        /// value (must be non-zero).
        update_frequency: usize,
        /// Not currently read by `evaluate_holdout`.
        drift_detection: bool,
    },
    /// Block-based evaluation for chunk learning
    BlockBased {
        /// Number of most recent samples scored; an evaluation happens every
        /// `block_size` samples (must be non-zero).
        block_size: usize,
        /// Not currently read by `evaluate_block_based`.
        evaluation_blocks: usize,
        /// Not currently read by `evaluate_block_based`.
        overlap_blocks: usize,
    },
    /// Adaptive window evaluation
    AdaptiveWindow {
        /// Lower bound used when adaptation halves the window.
        min_window_size: usize,
        /// Upper bound the window is trimmed to when adaptation runs.
        max_window_size: usize,
        /// Not currently read by `evaluate_adaptive_window`, which adapts on the
        /// variance of the last three recorded scores instead.
        adaptation_criterion: AdaptationCriterion,
    },
    /// Fading factor evaluation
    FadingFactor {
        /// Weight given to the newest sample's correctness.
        alpha: Float,
        /// Not currently read by `evaluate_fading_factor`.
        minimum_weight: Float,
    },
    /// Cross-validation for data streams
    StreamingCrossValidation {
        /// Number of contiguous folds the buffered window is split into.
        n_folds: usize,
        /// NOTE(review): bound as `_fold_update_strategy` in `evaluate_streaming_cv`,
        /// so it appears unused there — confirm.
        fold_update_strategy: FoldUpdateStrategy,
    },
}
58
/// Criteria for adaptive window sizing.
///
/// Carried by `IncrementalEvaluationStrategy::AdaptiveWindow`.
/// NOTE(review): `evaluate_adaptive_window` does not currently consult the
/// criterion; the variant descriptions below follow their names only — confirm.
#[derive(Debug, Clone)]
pub enum AdaptationCriterion {
    /// Performance-based adaptation (presumably triggered around `threshold`).
    PerformanceBased { threshold: Float },
    /// Drift detection-based adaptation using the given detector configuration.
    DriftBased { drift_detector: DriftDetectorType },
    /// Variance-based adaptation (presumably triggered above `variance_threshold`).
    VarianceBased { variance_threshold: Float },
    /// Hybrid approach combining several nested criteria.
    Hybrid { criteria: Vec<AdaptationCriterion> },
}
71
/// Types of drift detectors, together with their tuning parameters.
///
/// Used as the payload of `AdaptationCriterion::DriftBased`.
/// NOTE(review): `IncrementalEvaluator::create_drift_detector` does not read any
/// `DriftDetectorType`; it always builds an ADWIN detector. Parameter meanings
/// below are presumed from the usual definitions of these tests — confirm.
#[derive(Debug, Clone)]
pub enum DriftDetectorType {
    /// ADWIN (Adaptive Windowing); `confidence` is presumably the δ parameter.
    ADWIN { confidence: Float },
    /// Page-Hinkley test; presumably `threshold` is the alarm level and `alpha`
    /// the tolerated magnitude of change.
    PageHinkley { threshold: Float, alpha: Float },
    /// EDDM (Early Drift Detection Method); presumably `alpha`/`beta` are the
    /// warning/drift ratios.
    EDDM { alpha: Float, beta: Float },
    /// DDM (Drift Detection Method)
    DDM {
        /// Presumably the warning level (in standard deviations).
        warning_level: Float,
        /// Presumably the drift level (in standard deviations).
        drift_level: Float,
    },
    /// Statistical test-based detection; `test_type` names the test and
    /// `p_value` is presumably its significance level.
    StatisticalTest { test_type: String, p_value: Float },
}
90
/// Strategies for updating folds in streaming CV.
///
/// Selected via `IncrementalEvaluationStrategy::StreamingCrossValidation`.
/// NOTE(review): the parameter descriptions below follow the field names only.
#[derive(Debug, Clone)]
pub enum FoldUpdateStrategy {
    /// Replace oldest data
    ReplaceOldest,
    /// Weighted update (presumably older data decays by `decay_factor`).
    WeightedUpdate { decay_factor: Float },
    /// Selective update based on performance relative to `performance_threshold`.
    SelectiveUpdate { performance_threshold: Float },
    /// Random replacement of a `replacement_rate` fraction of data.
    RandomReplacement { replacement_rate: Float },
}
103
/// Incremental evaluation configuration.
///
/// The `Default` impl in this module provides a prequential configuration.
#[derive(Debug, Clone)]
pub struct IncrementalEvaluationConfig {
    /// Evaluation scheme `IncrementalEvaluator::update` applies to every sample.
    pub strategy: IncrementalEvaluationStrategy,
    /// Names of requested metrics.
    /// NOTE(review): not read by the evaluators reviewed, which all compute accuracy.
    pub performance_metrics: Vec<String>,
    /// When `true`, `IncrementalEvaluator::new` creates a drift detector that is
    /// fed the absolute prediction error of every sample.
    pub drift_detection_enabled: bool,
    /// NOTE(review): not read by the evaluator methods reviewed — confirm.
    pub adaptive_thresholds: bool,
    /// Action taken by the evaluator when drift is detected.
    pub concept_drift_handling: ConceptDriftHandling,
    /// Maximum number of buffered samples (a sample count, not bytes); the
    /// oldest samples are evicted first. `None` leaves the buffer unbounded.
    pub memory_budget: Option<usize>,
    /// A snapshot is recorded and returned every `evaluation_frequency` samples
    /// when the strategy produced one at that sample; must be non-zero.
    pub evaluation_frequency: usize,
    /// Seed for the evaluator's RNG; `None` seeds from the thread RNG.
    pub random_state: Option<u64>,
}
116
/// Concept drift handling strategies.
///
/// Applied by the evaluator after recording a drift event. Only the buffered
/// evaluation window and the drift detector are affected; the model itself is
/// never touched by the evaluator.
#[derive(Debug, Clone)]
pub enum ConceptDriftHandling {
    /// Ignore drift, continue with current model (the event is still recorded).
    Ignore,
    /// Clear the buffered samples and predictions and reset the drift detector.
    Reset,
    /// Evict the oldest `adaptation_rate` fraction of the buffered window.
    GradualAdaptation { adaptation_rate: Float },
    /// Ensemble-based handling. NOTE(review): no action is taken yet.
    EnsembleBased { ensemble_size: usize },
    /// Active learning approach. NOTE(review): no action is taken yet.
    ActiveLearning { uncertainty_threshold: Float },
}
131
/// Incremental evaluation result, produced by `IncrementalEvaluator::finalize`.
#[derive(Debug, Clone)]
pub struct IncrementalEvaluationResult {
    /// Snapshots recorded at evaluation boundaries, in processing order.
    pub performance_history: Vec<PerformanceSnapshot>,
    /// Every drift event detected during the stream.
    pub concept_drift_events: Vec<DriftEvent>,
    /// Window-size history and adaptation events collected while evaluating.
    pub adaptive_parameters: AdaptiveParameters,
    /// Aggregate counters and timings for the whole stream.
    pub streaming_statistics: StreamingStatistics,
    /// `Some` only when at least one window size was recorded.
    pub window_evolution: Option<WindowEvolution>,
    /// Wall-clock timing and throughput figures.
    pub computational_metrics: ComputationalMetrics,
}
142
/// Performance snapshot at a specific time.
#[derive(Debug, Clone)]
pub struct PerformanceSnapshot {
    /// When the snapshot was created.
    pub timestamp: Instant,
    /// Processed-sample count at the time of the snapshot.
    pub sample_index: usize,
    /// Score in `[0, 1]` (accuracy or a moving average of 0/1 correctness).
    pub performance_score: Float,
    /// Number of samples the score was computed over (1 for the per-sample
    /// prequential and fading-factor scores).
    pub window_size: usize,
    /// The evaluators reviewed set this to the processed-sample count.
    pub model_age: usize,
    /// Not filled in (`None`) by the evaluators reviewed.
    pub confidence_interval: Option<(Float, Float)>,
    /// Extra named metrics; left empty by the evaluators reviewed.
    pub additional_metrics: HashMap<String, Float>,
}
154
/// Detected drift event.
#[derive(Debug, Clone)]
pub struct DriftEvent {
    /// When the drift was handled.
    pub timestamp: Instant,
    /// Processed-sample count at which the drift was detected.
    pub sample_index: usize,
    /// Drift classification; the evaluator currently always records `Sudden`.
    pub drift_type: DriftType,
    /// Detector confidence at detection time (0.0 when no detector exists).
    pub confidence: Float,
    /// Name of the detector; the evaluator currently records `"ADWIN"`.
    pub detection_method: String,
    /// Feature indices implicated in the drift; currently always `None`.
    pub affected_features: Option<Vec<usize>>,
}
165
/// Types of detected drift.
///
/// Derives `PartialEq`, `Eq` and `Hash` so drift events can be compared
/// (`event.drift_type == DriftType::Sudden`) or grouped as hash-map keys.
/// The evaluator currently classifies every detected drift as `Sudden`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DriftType {
    /// Gradual drift
    Gradual,
    /// Sudden drift
    Sudden,
    /// Incremental drift
    Incremental,
    /// Recurring concept
    Recurring,
    /// Unknown drift pattern
    Unknown,
}
180
/// Adaptive parameters tracking.
#[derive(Debug, Clone)]
pub struct AdaptiveParameters {
    /// Window sizes recorded by the adaptive-window strategy; `finalize` reports
    /// them as `WindowEvolution::size_evolution`.
    pub window_size_history: Vec<usize>,
    /// NOTE(review): no writer found in the evaluator methods reviewed — confirm.
    pub learning_rate_history: Vec<Float>,
    /// NOTE(review): no writer found in the evaluator methods reviewed — confirm.
    pub threshold_history: Vec<Float>,
    /// Adaptation events; `finalize` converts their timestamps into
    /// `WindowEvolution::adaptation_points`.
    pub adaptation_events: Vec<AdaptationEvent>,
}
189
/// Adaptation event: one change of an adaptive parameter.
#[derive(Debug, Clone)]
pub struct AdaptationEvent {
    /// When the adaptation happened; `finalize` reports it as milliseconds
    /// since the evaluator was created.
    pub timestamp: Instant,
    /// Which kind of parameter changed.
    pub event_type: AdaptationType,
    /// Value before the adaptation.
    pub old_value: Float,
    /// Value after the adaptation.
    pub new_value: Float,
    /// Human-readable reason for the change.
    pub trigger_reason: String,
}
199
/// Types of adaptations recorded in an [`AdaptationEvent`].
///
/// Derives `PartialEq`, `Eq` and `Hash` so events can be filtered by kind
/// (`event.event_type == AdaptationType::ModelReset`) or counted per kind in a
/// hash map.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AdaptationType {
    /// The evaluation window size changed.
    WindowSizeChange,
    /// A learning rate changed.
    LearningRateChange,
    /// A threshold changed.
    ThresholdChange,
    /// The model was reset.
    ModelReset,
    /// Some other parameter was adjusted.
    ParameterAdjustment,
}
214
/// Streaming statistics, computed by `IncrementalEvaluator::finalize`.
#[derive(Debug, Clone)]
pub struct StreamingStatistics {
    /// Number of samples processed by `update`.
    pub total_samples_processed: usize,
    /// Number of recorded snapshots (same value as `evaluation_count`).
    pub total_batches_processed: usize,
    /// Mean wall-clock time per processed sample, measured from evaluator
    /// construction to `finalize`.
    pub average_processing_time: Duration,
    /// NOTE(review): filled with the buffer length at `finalize` time, not a
    /// tracked peak.
    pub memory_usage_peak: usize,
    /// Set to the processed-sample count (one model update per sample).
    pub model_updates_count: usize,
    /// Number of recorded snapshots.
    pub evaluation_count: usize,
    /// Drift events per processed sample.
    pub drift_rate: Float,
}
226
/// Window evolution tracking; only produced when window sizes were recorded.
#[derive(Debug, Clone)]
pub struct WindowEvolution {
    /// Copy of `AdaptiveParameters::window_size_history`.
    pub size_evolution: Vec<usize>,
    /// Scores of all recorded snapshots, in order.
    pub performance_evolution: Vec<Float>,
    /// Millisecond offsets (from evaluator creation) of the adaptation events.
    pub adaptation_points: Vec<usize>,
    /// Placeholder: currently a constant 0.8 per recorded snapshot.
    pub efficiency_scores: Vec<Float>,
}
235
/// Computational metrics, computed by `IncrementalEvaluator::finalize`.
///
/// Times are wall-clock since the evaluator was created, so they include any
/// time the caller spent between `update` calls.
#[derive(Debug, Clone)]
pub struct ComputationalMetrics {
    /// Elapsed time from evaluator creation to `finalize`.
    pub total_computation_time: Duration,
    /// `total_computation_time` divided by the processed-sample count.
    pub average_update_time: Duration,
    /// Placeholder: currently a constant 0.8.
    pub memory_efficiency: Float,
    /// Processed samples per second of `total_computation_time`.
    pub throughput: Float,
    /// Currently always empty.
    pub latency_percentiles: HashMap<String, Duration>,
}
245
/// Incremental evaluator.
///
/// Feed samples one at a time through `update` (test-then-train) and call
/// `finalize` to obtain the accumulated [`IncrementalEvaluationResult`].
pub struct IncrementalEvaluator {
    /// Configuration supplied to `new`.
    config: IncrementalEvaluationConfig,
    /// Snapshots recorded at evaluation boundaries.
    performance_history: Vec<PerformanceSnapshot>,
    /// Drift events recorded while processing the stream.
    drift_events: Vec<DriftEvent>,
    /// Buffered samples, oldest first.
    current_window: VecDeque<(Array1<Float>, Float)>, // (features, label)
    /// Predictions for the buffered samples; pushed and evicted together with
    /// `current_window` so both stay index-aligned.
    current_predictions: VecDeque<Float>,
    /// Window-size history and adaptation events.
    adaptive_parameters: AdaptiveParameters,
    /// Present only when drift detection is enabled in the config.
    drift_detector: Option<Box<dyn DriftDetector>>,
    /// Seeded from `config.random_state`.
    /// NOTE(review): not read by the methods reviewed.
    rng: StdRng,
    /// Creation time; `finalize` measures elapsed time from here.
    start_time: Instant,
    /// Number of samples processed by `update`.
    sample_count: usize,
}
259
/// Trait for drift detection.
///
/// The evaluator feeds one value per sample: the absolute prediction error.
trait DriftDetector: Send + Sync {
    /// Ingest the next observation; returns `true` when drift is signalled.
    fn update(&mut self, value: Float) -> bool;
    /// Clear internal state; the evaluator calls this under `ConceptDriftHandling::Reset`.
    fn reset(&mut self);
    /// Confidence value the evaluator stores on each recorded drift event.
    fn get_confidence(&self) -> Float;
}
266
/// ADWIN drift detector implementation.
///
/// `IncrementalEvaluator::create_drift_detector` builds this detector with
/// `0.002`. NOTE(review): the `impl` is outside the reviewed section; the field
/// descriptions below are inferred from their names — confirm.
#[derive(Debug)]
struct ADWINDetector {
    /// Presumably ADWIN's δ confidence parameter.
    confidence: Float,
    /// Presumably the adaptive window of recent observations.
    window: VecDeque<Float>,
    /// Presumably the running sum of `window`.
    total: Float,
    /// Presumably the running variance of `window`.
    variance: Float,
    /// Presumably the current number of observations in `window`.
    width: usize,
}
276
/// Page-Hinkley drift detector implementation.
///
/// NOTE(review): not constructed by the evaluator methods reviewed, and its
/// `impl` is outside the reviewed section; field descriptions are inferred from
/// their names — confirm.
#[derive(Debug)]
struct PageHinkleyDetector {
    /// Presumably the alarm threshold on the cumulative statistic.
    threshold: Float,
    /// Presumably the tolerated magnitude of change.
    alpha: Float,
    /// Presumably the running mean of observations.
    x_mean: Float,
    /// Presumably the number of observations seen.
    sample_count: usize,
    /// Presumably the cumulative Page-Hinkley sum.
    sum: Float,
    /// Presumably whether drift has been signalled.
    drift_detected: bool,
}
287
288impl Default for IncrementalEvaluationConfig {
289    fn default() -> Self {
290        Self {
291            strategy: IncrementalEvaluationStrategy::Prequential {
292                adaptation_rate: 0.01,
293                forgetting_factor: 0.95,
294            },
295            performance_metrics: vec!["accuracy".to_string()],
296            drift_detection_enabled: true,
297            adaptive_thresholds: true,
298            concept_drift_handling: ConceptDriftHandling::GradualAdaptation {
299                adaptation_rate: 0.1,
300            },
301            memory_budget: Some(10000),
302            evaluation_frequency: 100,
303            random_state: None,
304        }
305    }
306}
307
308impl IncrementalEvaluator {
309    /// Create a new incremental evaluator
310    pub fn new(config: IncrementalEvaluationConfig) -> Self {
311        let rng = match config.random_state {
312            Some(seed) => StdRng::seed_from_u64(seed),
313            None => {
314                use scirs2_core::random::thread_rng;
315                StdRng::from_rng(&mut thread_rng())
316            }
317        };
318
319        let drift_detector = if config.drift_detection_enabled {
320            Some(Self::create_drift_detector(&config.strategy))
321        } else {
322            None
323        };
324
325        Self {
326            config,
327            performance_history: Vec::new(),
328            drift_events: Vec::new(),
329            current_window: VecDeque::new(),
330            current_predictions: VecDeque::new(),
331            adaptive_parameters: AdaptiveParameters {
332                window_size_history: Vec::new(),
333                learning_rate_history: Vec::new(),
334                threshold_history: Vec::new(),
335                adaptation_events: Vec::new(),
336            },
337            drift_detector,
338            rng,
339            start_time: Instant::now(),
340            sample_count: 0,
341        }
342    }
343
344    /// Process a new data point and evaluate incrementally
345    pub fn update<F>(
346        &mut self,
347        features: Array1<Float>,
348        true_label: Float,
349        prediction: Float,
350        model_update_fn: F,
351    ) -> Result<Option<PerformanceSnapshot>, Box<dyn std::error::Error>>
352    where
353        F: FnOnce(&Array1<Float>, Float),
354    {
355        let _update_start = Instant::now();
356        self.sample_count += 1;
357
358        // Add to current window
359        self.current_window
360            .push_back((features.clone(), true_label));
361        self.current_predictions.push_back(prediction);
362
363        // Handle memory budget
364        if let Some(budget) = self.config.memory_budget {
365            while self.current_window.len() > budget {
366                self.current_window.pop_front();
367                self.current_predictions.pop_front();
368            }
369        }
370
371        // Check for concept drift
372        let error = (prediction - true_label).abs();
373        let drift_detected = if let Some(ref mut detector) = self.drift_detector {
374            detector.update(error)
375        } else {
376            false
377        };
378
379        if drift_detected {
380            self.handle_concept_drift()?;
381        }
382
383        // Perform evaluation based on strategy
384        let performance_snapshot = match &self.config.strategy {
385            IncrementalEvaluationStrategy::SlidingWindow { .. } => {
386                self.evaluate_sliding_window()?
387            }
388            IncrementalEvaluationStrategy::Prequential { .. } => {
389                self.evaluate_prequential(prediction, true_label)?
390            }
391            IncrementalEvaluationStrategy::HoldoutEvaluation { .. } => self.evaluate_holdout()?,
392            IncrementalEvaluationStrategy::BlockBased { .. } => self.evaluate_block_based()?,
393            IncrementalEvaluationStrategy::AdaptiveWindow { .. } => {
394                self.evaluate_adaptive_window()?
395            }
396            IncrementalEvaluationStrategy::FadingFactor { .. } => {
397                self.evaluate_fading_factor(prediction, true_label)?
398            }
399            IncrementalEvaluationStrategy::StreamingCrossValidation { .. } => {
400                self.evaluate_streaming_cv()?
401            }
402        };
403
404        // Update model (test-then-train paradigm)
405        model_update_fn(&features, true_label);
406
407        // Check if evaluation should be performed
408        if self.sample_count % self.config.evaluation_frequency == 0 {
409            if let Some(snapshot) = performance_snapshot {
410                self.performance_history.push(snapshot.clone());
411                Ok(Some(snapshot))
412            } else {
413                Ok(None)
414            }
415        } else {
416            Ok(None)
417        }
418    }
419
420    /// Get the final evaluation result
421    pub fn finalize(self) -> IncrementalEvaluationResult {
422        let total_time = self.start_time.elapsed();
423
424        let streaming_statistics = StreamingStatistics {
425            total_samples_processed: self.sample_count,
426            total_batches_processed: self.performance_history.len(),
427            average_processing_time: if self.sample_count > 0 {
428                total_time / self.sample_count as u32
429            } else {
430                Duration::from_secs(0)
431            },
432            memory_usage_peak: self.current_window.len(),
433            model_updates_count: self.sample_count,
434            evaluation_count: self.performance_history.len(),
435            drift_rate: self.drift_events.len() as Float / self.sample_count.max(1) as Float,
436        };
437
438        let window_evolution = if !self.adaptive_parameters.window_size_history.is_empty() {
439            Some(WindowEvolution {
440                size_evolution: self.adaptive_parameters.window_size_history.clone(),
441                performance_evolution: self
442                    .performance_history
443                    .iter()
444                    .map(|s| s.performance_score)
445                    .collect(),
446                adaptation_points: self
447                    .adaptive_parameters
448                    .adaptation_events
449                    .iter()
450                    .map(|e| e.timestamp.duration_since(self.start_time).as_millis() as usize)
451                    .collect(),
452                efficiency_scores: vec![0.8; self.performance_history.len()], // Placeholder
453            })
454        } else {
455            None
456        };
457
458        let computational_metrics = ComputationalMetrics {
459            total_computation_time: total_time,
460            average_update_time: if self.sample_count > 0 {
461                total_time / self.sample_count as u32
462            } else {
463                Duration::from_secs(0)
464            },
465            memory_efficiency: 0.8, // Placeholder
466            throughput: self.sample_count as Float / total_time.as_secs_f64() as Float,
467            latency_percentiles: HashMap::new(), // Could add percentile calculations
468        };
469
470        IncrementalEvaluationResult {
471            performance_history: self.performance_history,
472            concept_drift_events: self.drift_events,
473            adaptive_parameters: self.adaptive_parameters,
474            streaming_statistics,
475            window_evolution,
476            computational_metrics,
477        }
478    }
479
480    /// Create appropriate drift detector based on strategy
481    fn create_drift_detector(_strategy: &IncrementalEvaluationStrategy) -> Box<dyn DriftDetector> {
482        // Default to ADWIN detector
483        Box::new(ADWINDetector::new(0.002))
484    }
485
486    /// Handle concept drift detection
487    fn handle_concept_drift(&mut self) -> Result<(), Box<dyn std::error::Error>> {
488        let drift_event = DriftEvent {
489            timestamp: Instant::now(),
490            sample_index: self.sample_count,
491            drift_type: DriftType::Sudden, // Simplified classification
492            confidence: self
493                .drift_detector
494                .as_ref()
495                .map_or(0.0, |d| d.get_confidence()),
496            detection_method: "ADWIN".to_string(),
497            affected_features: None,
498        };
499
500        self.drift_events.push(drift_event);
501
502        // Handle drift based on configuration
503        match &self.config.concept_drift_handling {
504            ConceptDriftHandling::Reset => {
505                self.current_window.clear();
506                self.current_predictions.clear();
507                if let Some(ref mut detector) = self.drift_detector {
508                    detector.reset();
509                }
510            }
511            ConceptDriftHandling::GradualAdaptation { adaptation_rate } => {
512                // Reduce window size gradually
513                let new_size =
514                    (self.current_window.len() as Float * (1.0 - adaptation_rate)) as usize;
515                while self.current_window.len() > new_size {
516                    self.current_window.pop_front();
517                    self.current_predictions.pop_front();
518                }
519            }
520            _ => {
521                // Other strategies would be implemented here
522            }
523        }
524
525        Ok(())
526    }
527
528    /// Sliding window evaluation
529    fn evaluate_sliding_window(
530        &mut self,
531    ) -> Result<Option<PerformanceSnapshot>, Box<dyn std::error::Error>> {
532        let (window_size, step_size, _overlap_ratio) = match &self.config.strategy {
533            IncrementalEvaluationStrategy::SlidingWindow {
534                window_size,
535                step_size,
536                overlap_ratio,
537            } => (*window_size, *step_size, *overlap_ratio),
538            _ => unreachable!(),
539        };
540
541        if self.current_window.len() >= window_size && self.sample_count % step_size == 0 {
542            let recent_predictions: Vec<Float> = self
543                .current_predictions
544                .iter()
545                .rev()
546                .take(window_size)
547                .cloned()
548                .collect();
549
550            let recent_labels: Vec<Float> = self
551                .current_window
552                .iter()
553                .rev()
554                .take(window_size)
555                .map(|(_, label)| *label)
556                .collect();
557
558            let accuracy = recent_predictions
559                .iter()
560                .zip(recent_labels.iter())
561                .map(|(&pred, &label)| {
562                    if (pred > 0.5) == (label > 0.5) {
563                        1.0
564                    } else {
565                        0.0
566                    }
567                })
568                .sum::<Float>()
569                / recent_predictions.len() as Float;
570
571            Ok(Some(PerformanceSnapshot {
572                timestamp: Instant::now(),
573                sample_index: self.sample_count,
574                performance_score: accuracy,
575                window_size,
576                model_age: self.sample_count,
577                confidence_interval: None,
578                additional_metrics: HashMap::new(),
579            }))
580        } else {
581            Ok(None)
582        }
583    }
584
585    /// Prequential evaluation (test-then-train)
586    fn evaluate_prequential(
587        &mut self,
588        prediction: Float,
589        true_label: Float,
590    ) -> Result<Option<PerformanceSnapshot>, Box<dyn std::error::Error>> {
591        let (adaptation_rate, _forgetting_factor) = match &self.config.strategy {
592            IncrementalEvaluationStrategy::Prequential {
593                adaptation_rate,
594                forgetting_factor,
595            } => (*adaptation_rate, *forgetting_factor),
596            _ => unreachable!(),
597        };
598
599        // Simple prequential accuracy calculation
600        let is_correct = (prediction > 0.5) == (true_label > 0.5);
601        let current_accuracy = if is_correct { 1.0 } else { 0.0 };
602
603        // Update running average (exponential moving average)
604        let previous_performance = self
605            .performance_history
606            .last()
607            .map(|s| s.performance_score)
608            .unwrap_or(0.5);
609
610        let updated_performance =
611            (1.0 - adaptation_rate) * previous_performance + adaptation_rate * current_accuracy;
612
613        Ok(Some(PerformanceSnapshot {
614            timestamp: Instant::now(),
615            sample_index: self.sample_count,
616            performance_score: updated_performance,
617            window_size: 1,
618            model_age: self.sample_count,
619            confidence_interval: None,
620            additional_metrics: HashMap::new(),
621        }))
622    }
623
624    /// Holdout evaluation
625    fn evaluate_holdout(
626        &mut self,
627    ) -> Result<Option<PerformanceSnapshot>, Box<dyn std::error::Error>> {
628        let (holdout_ratio, update_frequency, _drift_detection) = match &self.config.strategy {
629            IncrementalEvaluationStrategy::HoldoutEvaluation {
630                holdout_ratio,
631                update_frequency,
632                drift_detection,
633            } => (*holdout_ratio, *update_frequency, *drift_detection),
634            _ => unreachable!(),
635        };
636
637        if self.sample_count % update_frequency == 0 && !self.current_window.is_empty() {
638            let holdout_size = (self.current_window.len() as Float * holdout_ratio) as usize;
639
640            if holdout_size > 0 {
641                let holdout_predictions: Vec<Float> = self
642                    .current_predictions
643                    .iter()
644                    .rev()
645                    .take(holdout_size)
646                    .cloned()
647                    .collect();
648
649                let holdout_labels: Vec<Float> = self
650                    .current_window
651                    .iter()
652                    .rev()
653                    .take(holdout_size)
654                    .map(|(_, label)| *label)
655                    .collect();
656
657                let accuracy = holdout_predictions
658                    .iter()
659                    .zip(holdout_labels.iter())
660                    .map(|(&pred, &label)| {
661                        if (pred > 0.5) == (label > 0.5) {
662                            1.0
663                        } else {
664                            0.0
665                        }
666                    })
667                    .sum::<Float>()
668                    / holdout_predictions.len() as Float;
669
670                Ok(Some(PerformanceSnapshot {
671                    timestamp: Instant::now(),
672                    sample_index: self.sample_count,
673                    performance_score: accuracy,
674                    window_size: holdout_size,
675                    model_age: self.sample_count,
676                    confidence_interval: None,
677                    additional_metrics: HashMap::new(),
678                }))
679            } else {
680                Ok(None)
681            }
682        } else {
683            Ok(None)
684        }
685    }
686
687    /// Block-based evaluation
688    fn evaluate_block_based(
689        &mut self,
690    ) -> Result<Option<PerformanceSnapshot>, Box<dyn std::error::Error>> {
691        let (block_size, _evaluation_blocks, _overlap_blocks) = match &self.config.strategy {
692            IncrementalEvaluationStrategy::BlockBased {
693                block_size,
694                evaluation_blocks,
695                overlap_blocks,
696            } => (*block_size, *evaluation_blocks, *overlap_blocks),
697            _ => unreachable!(),
698        };
699
700        if self.sample_count % block_size == 0 && self.current_window.len() >= block_size {
701            let block_predictions: Vec<Float> = self
702                .current_predictions
703                .iter()
704                .rev()
705                .take(block_size)
706                .cloned()
707                .collect();
708
709            let block_labels: Vec<Float> = self
710                .current_window
711                .iter()
712                .rev()
713                .take(block_size)
714                .map(|(_, label)| *label)
715                .collect();
716
717            let accuracy = block_predictions
718                .iter()
719                .zip(block_labels.iter())
720                .map(|(&pred, &label)| {
721                    if (pred > 0.5) == (label > 0.5) {
722                        1.0
723                    } else {
724                        0.0
725                    }
726                })
727                .sum::<Float>()
728                / block_predictions.len() as Float;
729
730            Ok(Some(PerformanceSnapshot {
731                timestamp: Instant::now(),
732                sample_index: self.sample_count,
733                performance_score: accuracy,
734                window_size: block_size,
735                model_age: self.sample_count,
736                confidence_interval: None,
737                additional_metrics: HashMap::new(),
738            }))
739        } else {
740            Ok(None)
741        }
742    }
743
744    /// Adaptive window evaluation
745    fn evaluate_adaptive_window(
746        &mut self,
747    ) -> Result<Option<PerformanceSnapshot>, Box<dyn std::error::Error>> {
748        let (min_window_size, max_window_size, _adaptation_criterion) = match &self.config.strategy
749        {
750            IncrementalEvaluationStrategy::AdaptiveWindow {
751                min_window_size,
752                max_window_size,
753                adaptation_criterion,
754            } => (*min_window_size, *max_window_size, adaptation_criterion),
755            _ => unreachable!(),
756        };
757
758        // Simple adaptive logic: adjust window size based on recent performance variance
759        if self.performance_history.len() >= 3 {
760            let recent_scores: Vec<Float> = self
761                .performance_history
762                .iter()
763                .rev()
764                .take(3)
765                .map(|s| s.performance_score)
766                .collect();
767
768            let mean_score = recent_scores.iter().sum::<Float>() / recent_scores.len() as Float;
769            let variance = recent_scores
770                .iter()
771                .map(|&score| (score - mean_score).powi(2))
772                .sum::<Float>()
773                / recent_scores.len() as Float;
774
775            // Adjust window size based on variance
776            let current_window_size = self.current_window.len();
777            let new_window_size = if variance > 0.1 {
778                // High variance, reduce window size
779                (current_window_size / 2).max(min_window_size)
780            } else {
781                // Low variance, can increase window size
782                (current_window_size * 2).min(max_window_size)
783            };
784
785            // Record window size change
786            if new_window_size != current_window_size {
787                self.adaptive_parameters
788                    .window_size_history
789                    .push(new_window_size);
790
791                // Adjust actual window
792                while self.current_window.len() > new_window_size {
793                    self.current_window.pop_front();
794                    self.current_predictions.pop_front();
795                }
796            }
797        }
798
799        // Evaluate using current window
800        if !self.current_window.is_empty() {
801            let predictions: Vec<Float> = self.current_predictions.iter().cloned().collect();
802            let labels: Vec<Float> = self
803                .current_window
804                .iter()
805                .map(|(_, label)| *label)
806                .collect();
807
808            let accuracy = predictions
809                .iter()
810                .zip(labels.iter())
811                .map(|(&pred, &label)| {
812                    if (pred > 0.5) == (label > 0.5) {
813                        1.0
814                    } else {
815                        0.0
816                    }
817                })
818                .sum::<Float>()
819                / predictions.len() as Float;
820
821            Ok(Some(PerformanceSnapshot {
822                timestamp: Instant::now(),
823                sample_index: self.sample_count,
824                performance_score: accuracy,
825                window_size: self.current_window.len(),
826                model_age: self.sample_count,
827                confidence_interval: None,
828                additional_metrics: HashMap::new(),
829            }))
830        } else {
831            Ok(None)
832        }
833    }
834
835    /// Fading factor evaluation
836    fn evaluate_fading_factor(
837        &mut self,
838        prediction: Float,
839        true_label: Float,
840    ) -> Result<Option<PerformanceSnapshot>, Box<dyn std::error::Error>> {
841        let (alpha, _minimum_weight) = match &self.config.strategy {
842            IncrementalEvaluationStrategy::FadingFactor {
843                alpha,
844                minimum_weight,
845            } => (*alpha, *minimum_weight),
846            _ => unreachable!(),
847        };
848
849        // Weighted evaluation with fading factor
850        let is_correct = (prediction > 0.5) == (true_label > 0.5);
851        let current_accuracy = if is_correct { 1.0 } else { 0.0 };
852
853        let previous_performance = self
854            .performance_history
855            .last()
856            .map(|s| s.performance_score)
857            .unwrap_or(0.5);
858
859        let faded_performance = alpha * current_accuracy + (1.0 - alpha) * previous_performance;
860
861        Ok(Some(PerformanceSnapshot {
862            timestamp: Instant::now(),
863            sample_index: self.sample_count,
864            performance_score: faded_performance,
865            window_size: 1,
866            model_age: self.sample_count,
867            confidence_interval: None,
868            additional_metrics: HashMap::new(),
869        }))
870    }
871
872    /// Streaming cross-validation evaluation
873    fn evaluate_streaming_cv(
874        &mut self,
875    ) -> Result<Option<PerformanceSnapshot>, Box<dyn std::error::Error>> {
876        let (n_folds, _fold_update_strategy) = match &self.config.strategy {
877            IncrementalEvaluationStrategy::StreamingCrossValidation {
878                n_folds,
879                fold_update_strategy,
880            } => (*n_folds, fold_update_strategy),
881            _ => unreachable!(),
882        };
883
884        if self.current_window.len() >= n_folds && self.sample_count % 10 == 0 {
885            // Simple streaming CV: evaluate on rotating folds
886            let fold_size = self.current_window.len() / n_folds;
887            let mut fold_scores = Vec::new();
888
889            for fold in 0..n_folds {
890                let test_start = fold * fold_size;
891                let test_end = if fold == n_folds - 1 {
892                    self.current_window.len()
893                } else {
894                    (fold + 1) * fold_size
895                };
896
897                if test_end <= self.current_predictions.len() {
898                    let fold_predictions: Vec<Float> = self
899                        .current_predictions
900                        .iter()
901                        .skip(test_start)
902                        .take(test_end - test_start)
903                        .cloned()
904                        .collect();
905
906                    let fold_labels: Vec<Float> = self
907                        .current_window
908                        .iter()
909                        .skip(test_start)
910                        .take(test_end - test_start)
911                        .map(|(_, label)| *label)
912                        .collect();
913
914                    let fold_accuracy = fold_predictions
915                        .iter()
916                        .zip(fold_labels.iter())
917                        .map(|(&pred, &label)| {
918                            if (pred > 0.5) == (label > 0.5) {
919                                1.0
920                            } else {
921                                0.0
922                            }
923                        })
924                        .sum::<Float>()
925                        / fold_predictions.len() as Float;
926
927                    fold_scores.push(fold_accuracy);
928                }
929            }
930
931            if !fold_scores.is_empty() {
932                let mean_accuracy = fold_scores.iter().sum::<Float>() / fold_scores.len() as Float;
933                let std_accuracy = {
934                    let variance = fold_scores
935                        .iter()
936                        .map(|&score| (score - mean_accuracy).powi(2))
937                        .sum::<Float>()
938                        / fold_scores.len() as Float;
939                    variance.sqrt()
940                };
941
942                let confidence_interval = (
943                    mean_accuracy - 1.96 * std_accuracy,
944                    mean_accuracy + 1.96 * std_accuracy,
945                );
946
947                Ok(Some(PerformanceSnapshot {
948                    timestamp: Instant::now(),
949                    sample_index: self.sample_count,
950                    performance_score: mean_accuracy,
951                    window_size: self.current_window.len(),
952                    model_age: self.sample_count,
953                    confidence_interval: Some(confidence_interval),
954                    additional_metrics: {
955                        let mut metrics = HashMap::new();
956                        metrics.insert("std_accuracy".to_string(), std_accuracy);
957                        metrics
958                    },
959                }))
960            } else {
961                Ok(None)
962            }
963        } else {
964            Ok(None)
965        }
966    }
967}
968
969impl ADWINDetector {
970    fn new(confidence: Float) -> Self {
971        Self {
972            confidence,
973            window: VecDeque::new(),
974            total: 0.0,
975            variance: 0.0,
976            width: 0,
977        }
978    }
979}
980
981impl DriftDetector for ADWINDetector {
982    fn update(&mut self, value: Float) -> bool {
983        self.window.push_back(value);
984        self.total += value;
985        self.width += 1;
986
987        // Simplified ADWIN logic
988        if self.width > 100 {
989            // Check for significant difference between window halves
990            let half = self.width / 2;
991            let first_half_mean = self.window.iter().take(half).sum::<Float>() / half as Float;
992            let second_half_mean =
993                self.window.iter().skip(half).sum::<Float>() / (self.width - half) as Float;
994
995            let difference = (first_half_mean - second_half_mean).abs();
996            difference > self.confidence
997        } else {
998            false
999        }
1000    }
1001
1002    fn reset(&mut self) {
1003        self.window.clear();
1004        self.total = 0.0;
1005        self.variance = 0.0;
1006        self.width = 0;
1007    }
1008
1009    fn get_confidence(&self) -> Float {
1010        if self.width > 0 {
1011            self.variance / self.width as Float
1012        } else {
1013            0.0
1014        }
1015    }
1016}
1017
1018impl PageHinkleyDetector {
1019    fn new(threshold: Float, alpha: Float) -> Self {
1020        Self {
1021            threshold,
1022            alpha,
1023            x_mean: 0.0,
1024            sample_count: 0,
1025            sum: 0.0,
1026            drift_detected: false,
1027        }
1028    }
1029}
1030
1031impl DriftDetector for PageHinkleyDetector {
1032    fn update(&mut self, value: Float) -> bool {
1033        self.sample_count += 1;
1034
1035        if self.sample_count == 1 {
1036            self.x_mean = value;
1037            return false;
1038        }
1039
1040        // Update mean
1041        self.x_mean = self.x_mean + (value - self.x_mean) / self.sample_count as Float;
1042
1043        // Update Page-Hinkley statistic
1044        self.sum = (self.sum + value - self.x_mean - self.alpha).max(0.0);
1045
1046        // Check for drift
1047        if self.sum > self.threshold {
1048            self.drift_detected = true;
1049            true
1050        } else {
1051            false
1052        }
1053    }
1054
1055    fn reset(&mut self) {
1056        self.x_mean = 0.0;
1057        self.sample_count = 0;
1058        self.sum = 0.0;
1059        self.drift_detected = false;
1060    }
1061
1062    fn get_confidence(&self) -> Float {
1063        if self.threshold > 0.0 {
1064            (self.sum / self.threshold).min(1.0)
1065        } else {
1066            0.0
1067        }
1068    }
1069}
1070
1071/// Convenience function for incremental evaluation
1072pub fn evaluate_incremental_stream<F>(
1073    data_stream: impl Iterator<Item = (Array1<Float>, Float, Float)>, // (features, label, prediction)
1074    model_update_fn: F,
1075    config: Option<IncrementalEvaluationConfig>,
1076) -> Result<IncrementalEvaluationResult, Box<dyn std::error::Error>>
1077where
1078    F: Fn(&Array1<Float>, Float),
1079{
1080    let config = config.unwrap_or_default();
1081    let mut evaluator = IncrementalEvaluator::new(config);
1082
1083    for (features, label, prediction) in data_stream {
1084        evaluator.update(features, label, prediction, &model_update_fn)?;
1085    }
1086
1087    Ok(evaluator.finalize())
1088}
1089
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_incremental_evaluator_creation() {
        let config = IncrementalEvaluationConfig::default();
        let evaluator = IncrementalEvaluator::new(config);
        assert_eq!(evaluator.sample_count, 0);
    }

    #[test]
    fn test_prequential_evaluation() {
        let config = IncrementalEvaluationConfig {
            strategy: IncrementalEvaluationStrategy::Prequential {
                adaptation_rate: 0.1,
                forgetting_factor: 0.9,
            },
            evaluation_frequency: 1,
            ..Default::default()
        };

        let mut evaluator = IncrementalEvaluator::new(config);

        let features = Array1::from_vec(vec![0.5, 0.3, 0.8]);
        let model_update_fn = |_: &Array1<Float>, _: Float| {}; // No-op update

        let result = evaluator
            .update(features, 1.0, 0.8, model_update_fn)
            .unwrap();
        assert!(result.is_some());

        let snapshot = result.unwrap();
        assert!(snapshot.performance_score >= 0.0 && snapshot.performance_score <= 1.0);
    }

    #[test]
    fn test_sliding_window_evaluation() {
        let config = IncrementalEvaluationConfig {
            strategy: IncrementalEvaluationStrategy::SlidingWindow {
                window_size: 5,
                step_size: 5,
                overlap_ratio: 0.0,
            },
            evaluation_frequency: 5,
            ..Default::default()
        };

        let mut evaluator = IncrementalEvaluator::new(config);
        let model_update_fn = |_: &Array1<Float>, _: Float| {};

        // Add 5 samples
        for i in 0..5 {
            let features = Array1::from_vec(vec![i as Float * 0.1]);
            let label = if i % 2 == 0 { 1.0 } else { 0.0 };
            let prediction = if i % 2 == 0 { 0.8 } else { 0.2 };

            let result = evaluator
                .update(features, label, prediction, &model_update_fn)
                .unwrap();

            if i == 4 {
                // Last sample should trigger evaluation
                assert!(result.is_some());
            }
        }
    }

    #[test]
    fn test_drift_detector() {
        let mut detector = ADWINDetector::new(0.1);

        // Stable values: the window stays at or below 100 entries, so nothing is tested.
        for _ in 0..50 {
            assert!(!detector.update(0.1));
        }

        // A jump from 0.1 to 0.9 must be flagged once the window exceeds 100 values.
        // Every update is performed; nothing short-circuits the loop.
        let mut drift_detected = false;
        for _ in 0..60 {
            drift_detected |= detector.update(0.9);
        }
        assert!(drift_detected);

        assert!(detector.get_confidence() >= 0.0);
    }

    #[test]
    fn test_page_hinkley_detector() {
        let mut detector = PageHinkleyDetector::new(5.0, 0.005);

        // A constant stream keeps the cumulative statistic at zero.
        for _ in 0..100 {
            assert!(!detector.update(0.1));
        }

        // A sustained upward shift accumulates past the threshold.
        let mut drift_detected = false;
        for _ in 0..50 {
            drift_detected |= detector.update(0.9);
        }
        assert!(drift_detected);
        assert!(detector.get_confidence() > 0.0);

        // After a reset the statistic is cleared, and the next sample only seeds the mean.
        detector.reset();
        assert_eq!(detector.get_confidence(), 0.0);
        assert!(!detector.update(0.5));
    }

    #[test]
    fn test_streaming_evaluation() {
        let data_stream = (0..20).map(|i| {
            let features = Array1::from_vec(vec![i as Float * 0.05]);
            let label = if i % 2 == 0 { 1.0 } else { 0.0 };
            let prediction = if i % 2 == 0 { 0.8 } else { 0.3 };
            (features, label, prediction)
        });

        let model_update_fn = |_: &Array1<Float>, _: Float| {};

        let config = IncrementalEvaluationConfig {
            evaluation_frequency: 10, // Evaluate every 10 samples instead of default 100
            ..Default::default()
        };

        let result =
            evaluate_incremental_stream(data_stream, model_update_fn, Some(config)).unwrap();

        assert!(result.streaming_statistics.total_samples_processed == 20);
        assert!(!result.performance_history.is_empty());
    }
}