sklears_model_selection/
early_stopping.rs

1//! Early stopping strategies for hyperparameter optimization
2//!
3//! This module implements various early stopping criteria that can be used to terminate
4//! optimization algorithms early when certain conditions are met, saving computational
5//! resources while maintaining optimization quality.
6
7use sklears_core::error::Result;
8use std::collections::VecDeque;
9
/// Tunable parameters shared by the early stopping strategies.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Number of iterations that must have been recorded before a stop can be signalled
    pub min_iterations: usize,
    /// How many iterations without improvement are tolerated
    pub patience: usize,
    /// Margin by which a score has to beat the best one to count as an improvement
    pub min_delta: f64,
    /// Whether the best weights should be restored when stopping
    pub restore_best_weights: bool,
    /// Direction of optimization: `true` to maximize the score, `false` to minimize it
    pub maximize: bool,
    /// Optional baseline score to compare against
    pub baseline: Option<f64>,
    /// Weight given to the previous value when updating the exponential moving average
    pub smoothing_factor: f64,
}

impl Default for EarlyStoppingConfig {
    /// Builds the stock configuration: maximize the score, wait at least ten
    /// iterations, and tolerate ten iterations without improvement.
    fn default() -> Self {
        EarlyStoppingConfig {
            maximize: true,
            restore_best_weights: true,
            baseline: None,
            min_iterations: 10,
            patience: 10,
            min_delta: 1e-4,
            smoothing_factor: 0.9,
        }
    }
}
42
/// The stopping criteria a monitor can evaluate.
#[derive(Debug, Clone)]
pub enum EarlyStoppingStrategy {
    /// Stop once the number of non-improving iterations exceeds the configured patience
    Patience,
    /// Stop when the magnitude of the recent improvement rate drops below the given threshold
    ImprovementRate(f64),
    /// Stop when the exponential moving average has converged to the recent average
    ExponentialMovingAverage,
    /// Stop when recent scores get worse than the preceding ones (overfitting detection)
    ValidationLoss,
    /// Stop when the relative change between the last two scores falls below the threshold
    RelativeImprovement(f64),
    /// Stop when the absolute change between the last two scores falls below the threshold
    AbsoluteImprovement(f64),
    /// Stop as soon as any of the contained strategies asks to stop (OR logic)
    Combined(Vec<EarlyStoppingStrategy>),
}
61
62/// Early stopping state tracker
63#[derive(Debug, Clone)]
64pub struct EarlyStoppingState {
65    /// Best score seen so far
66    pub best_score: f64,
67    /// Iteration when best score was achieved
68    pub best_iteration: usize,
69    /// Number of iterations without improvement
70    pub patience_counter: usize,
71    /// All scores seen so far
72    pub score_history: Vec<f64>,
73    /// Exponential moving average of scores
74    pub ema_score: f64,
75    /// Whether EMA has been initialized
76    pub ema_initialized: bool,
77    /// Recent scores for trend analysis
78    pub recent_scores: VecDeque<f64>,
79    /// Maximum window size for recent scores
80    pub window_size: usize,
81}
82
83impl EarlyStoppingState {
84    fn new(window_size: usize) -> Self {
85        Self {
86            best_score: f64::NEG_INFINITY,
87            best_iteration: 0,
88            patience_counter: 0,
89            score_history: Vec::new(),
90            ema_score: 0.0,
91            ema_initialized: false,
92            recent_scores: VecDeque::with_capacity(window_size),
93            window_size,
94        }
95    }
96
97    fn update(&mut self, score: f64, iteration: usize, config: &EarlyStoppingConfig) {
98        self.score_history.push(score);
99
100        // Update recent scores window
101        if self.recent_scores.len() >= self.window_size {
102            self.recent_scores.pop_front();
103        }
104        self.recent_scores.push_back(score);
105
106        // Update exponential moving average
107        if !self.ema_initialized {
108            self.ema_score = score;
109            self.ema_initialized = true;
110        } else {
111            self.ema_score =
112                config.smoothing_factor * self.ema_score + (1.0 - config.smoothing_factor) * score;
113        }
114
115        // Check if this is the best score
116        let is_improvement = if config.maximize {
117            score > self.best_score + config.min_delta
118        } else {
119            score < self.best_score - config.min_delta
120        };
121
122        if is_improvement {
123            self.best_score = score;
124            self.best_iteration = iteration;
125            self.patience_counter = 0;
126        } else {
127            self.patience_counter += 1;
128        }
129    }
130
131    fn improvement_rate(&self) -> f64 {
132        if self.score_history.len() < 2 {
133            return 0.0;
134        }
135
136        let recent_window = self.score_history.len().min(5);
137        let recent_scores = &self.score_history[self.score_history.len() - recent_window..];
138
139        if recent_scores.len() < 2 {
140            return 0.0;
141        }
142
143        let start_score = recent_scores[0];
144        let end_score = recent_scores[recent_scores.len() - 1];
145
146        if start_score.abs() < 1e-8 {
147            return 0.0;
148        }
149
150        (end_score - start_score) / start_score.abs()
151    }
152
153    fn relative_improvement(&self) -> f64 {
154        if self.score_history.len() < 2 {
155            return f64::INFINITY;
156        }
157
158        let current = self.score_history[self.score_history.len() - 1];
159        let previous = self.score_history[self.score_history.len() - 2];
160
161        if previous.abs() < 1e-8 {
162            return f64::INFINITY;
163        }
164
165        (current - previous).abs() / previous.abs()
166    }
167
168    fn absolute_improvement(&self) -> f64 {
169        if self.score_history.len() < 2 {
170            return f64::INFINITY;
171        }
172
173        let current = self.score_history[self.score_history.len() - 1];
174        let previous = self.score_history[self.score_history.len() - 2];
175
176        (current - previous).abs()
177    }
178
179    fn ema_convergence(&self, threshold: f64) -> bool {
180        if self.recent_scores.len() < self.window_size {
181            return false;
182        }
183
184        let recent_avg: f64 =
185            self.recent_scores.iter().sum::<f64>() / self.recent_scores.len() as f64;
186        (self.ema_score - recent_avg).abs() < threshold
187    }
188
189    fn is_overfitting(&self, lookback: usize) -> bool {
190        if self.score_history.len() < lookback + 2 {
191            return false;
192        }
193
194        let len = self.score_history.len();
195        let recent_scores = &self.score_history[len - lookback..];
196        let previous_scores = &self.score_history[len - lookback - lookback..len - lookback];
197
198        let recent_avg: f64 = recent_scores.iter().sum::<f64>() / recent_scores.len() as f64;
199        let previous_avg: f64 = previous_scores.iter().sum::<f64>() / previous_scores.len() as f64;
200
201        // For maximization: overfitting if recent scores are decreasing
202        // For minimization: overfitting if recent scores are increasing
203        recent_avg < previous_avg
204    }
205}
206
207/// Early stopping monitor
208pub struct EarlyStoppingMonitor {
209    strategy: EarlyStoppingStrategy,
210    config: EarlyStoppingConfig,
211    state: EarlyStoppingState,
212    current_iteration: usize,
213}
214
215impl EarlyStoppingMonitor {
216    /// Create a new early stopping monitor
217    pub fn new(strategy: EarlyStoppingStrategy, config: EarlyStoppingConfig) -> Self {
218        Self {
219            strategy,
220            config,
221            state: EarlyStoppingState::new(10), // Default window size
222            current_iteration: 0,
223        }
224    }
225
226    /// Update the monitor with a new score
227    pub fn update(&mut self, score: f64) -> Result<()> {
228        self.state
229            .update(score, self.current_iteration, &self.config);
230        self.current_iteration += 1;
231        Ok(())
232    }
233
234    /// Check if early stopping criteria are met
235    pub fn should_stop(&self) -> bool {
236        if self.current_iteration < self.config.min_iterations {
237            return false;
238        }
239
240        self.check_strategy(&self.strategy)
241    }
242
243    fn check_strategy(&self, strategy: &EarlyStoppingStrategy) -> bool {
244        match strategy {
245            EarlyStoppingStrategy::Patience => self.state.patience_counter > self.config.patience,
246            EarlyStoppingStrategy::ImprovementRate(threshold) => {
247                self.state.improvement_rate().abs() < *threshold
248            }
249            EarlyStoppingStrategy::ExponentialMovingAverage => {
250                self.state.ema_convergence(self.config.min_delta)
251            }
252            EarlyStoppingStrategy::ValidationLoss => {
253                self.state.is_overfitting(5) // Lookback of 5 iterations
254            }
255            EarlyStoppingStrategy::RelativeImprovement(threshold) => {
256                self.state.relative_improvement() < *threshold
257            }
258            EarlyStoppingStrategy::AbsoluteImprovement(threshold) => {
259                self.state.absolute_improvement() < *threshold
260            }
261            EarlyStoppingStrategy::Combined(strategies) => {
262                strategies.iter().any(|s| self.check_strategy(s))
263            }
264        }
265    }
266
267    /// Get the current state
268    pub fn state(&self) -> &EarlyStoppingState {
269        &self.state
270    }
271
272    /// Get the best score and iteration
273    pub fn best_result(&self) -> (f64, usize) {
274        (self.state.best_score, self.state.best_iteration)
275    }
276
277    /// Reset the monitor
278    pub fn reset(&mut self) {
279        self.state = EarlyStoppingState::new(self.state.window_size);
280        self.current_iteration = 0;
281    }
282
283    /// Check if minimum iterations have been reached
284    pub fn min_iterations_reached(&self) -> bool {
285        self.current_iteration >= self.config.min_iterations
286    }
287
288    /// Get convergence metrics
289    pub fn convergence_metrics(&self) -> ConvergenceMetrics {
290        ConvergenceMetrics {
291            improvement_rate: self.state.improvement_rate(),
292            relative_improvement: self.state.relative_improvement(),
293            absolute_improvement: self.state.absolute_improvement(),
294            patience_remaining: self
295                .config
296                .patience
297                .saturating_sub(self.state.patience_counter),
298            iterations_since_best: self
299                .current_iteration
300                .saturating_sub(self.state.best_iteration),
301            ema_score: self.state.ema_score,
302            current_score: self.state.score_history.last().copied().unwrap_or(0.0),
303        }
304    }
305}
306
/// Snapshot of how an optimization run is progressing.
#[derive(Debug, Clone)]
pub struct ConvergenceMetrics {
    /// Signed relative change across the most recent scores
    pub improvement_rate: f64,
    /// Relative size of the change made by the last iteration
    pub relative_improvement: f64,
    /// Absolute size of the change made by the last iteration
    pub absolute_improvement: f64,
    /// Patience left before the configured limit is used up (never below zero)
    pub patience_remaining: usize,
    /// Iterations recorded since the best score was achieved
    pub iterations_since_best: usize,
    /// Exponential moving average of the scores at the time of the snapshot
    pub ema_score: f64,
    /// The latest score that was reported
    pub current_score: f64,
}
325
326/// Early stopping callback trait for use with optimizers
327pub trait EarlyStoppingCallback {
328    fn on_iteration(&mut self, score: f64) -> Result<bool>;
329
330    fn on_early_stop(&mut self, reason: &str) -> Result<()>;
331
332    fn best_score(&self) -> f64;
333
334    fn convergence_info(&self) -> String;
335}
336
337impl EarlyStoppingCallback for EarlyStoppingMonitor {
338    fn on_iteration(&mut self, score: f64) -> Result<bool> {
339        self.update(score)?;
340        Ok(self.should_stop())
341    }
342
343    fn on_early_stop(&mut self, _reason: &str) -> Result<()> {
344        // Default implementation does nothing
345        Ok(())
346    }
347
348    fn best_score(&self) -> f64 {
349        self.state.best_score
350    }
351
352    fn convergence_info(&self) -> String {
353        let metrics = self.convergence_metrics();
354        format!(
355            "Best: {:.6}, Current: {:.6}, Improvement Rate: {:.6}, Patience: {}/{}",
356            self.state.best_score,
357            metrics.current_score,
358            metrics.improvement_rate,
359            self.state.patience_counter,
360            self.config.patience
361        )
362    }
363}
364
365/// Adaptive early stopping that adjusts parameters based on optimization progress
366pub struct AdaptiveEarlyStopping {
367    base_monitor: EarlyStoppingMonitor,
368    adaptation_config: AdaptationConfig,
369    adaptation_state: AdaptationState,
370}
371
/// Rules governing how [`AdaptiveEarlyStopping`] adjusts patience.
#[derive(Debug, Clone)]
pub struct AdaptationConfig {
    /// Number of iterations between two adaptation steps
    pub adaptation_frequency: usize,
    /// Multiplier applied to patience while progress is good
    pub patience_increase_factor: f64,
    /// Multiplier applied to patience while progress is poor
    pub patience_decrease_factor: f64,
    /// Upper bound applied when patience is increased
    pub max_patience: usize,
    /// Lower bound applied when patience is decreased
    pub min_patience: usize,
    /// Improvement rate above which progress counts as good
    pub good_progress_threshold: f64,
    /// Improvement rate below which progress counts as poor
    pub poor_progress_threshold: f64,
}

impl Default for AdaptationConfig {
    /// Adapts every 20 iterations and keeps patience between 5 and 50.
    fn default() -> Self {
        AdaptationConfig {
            adaptation_frequency: 20,
            min_patience: 5,
            max_patience: 50,
            patience_decrease_factor: 0.8,
            patience_increase_factor: 1.5,
            poor_progress_threshold: 0.001,
            good_progress_threshold: 0.01,
        }
    }
}
403
/// Bookkeeping for the adaptations performed by [`AdaptiveEarlyStopping`].
#[derive(Debug, Clone)]
struct AdaptationState {
    /// Iteration count at which parameters were last reconsidered
    last_adaptation_iteration: usize,
    /// One `(iteration, patience)` entry per adaptation that changed the patience
    adaptation_history: Vec<(usize, usize)>,
}
409
410impl AdaptiveEarlyStopping {
411    /// Create a new adaptive early stopping monitor
412    pub fn new(
413        strategy: EarlyStoppingStrategy,
414        config: EarlyStoppingConfig,
415        adaptation_config: AdaptationConfig,
416    ) -> Self {
417        Self {
418            base_monitor: EarlyStoppingMonitor::new(strategy, config),
419            adaptation_config,
420            adaptation_state: AdaptationState {
421                last_adaptation_iteration: 0,
422                adaptation_history: Vec::new(),
423            },
424        }
425    }
426
427    /// Update with adaptive behavior
428    pub fn update_adaptive(&mut self, score: f64) -> Result<()> {
429        self.base_monitor.update(score)?;
430
431        // Check if it's time to adapt
432        if self.base_monitor.current_iteration
433            >= self.adaptation_state.last_adaptation_iteration
434                + self.adaptation_config.adaptation_frequency
435        {
436            self.adapt_parameters();
437        }
438
439        Ok(())
440    }
441
442    fn adapt_parameters(&mut self) {
443        let metrics = self.base_monitor.convergence_metrics();
444        let current_patience = self.base_monitor.config.patience;
445
446        let new_patience =
447            if metrics.improvement_rate > self.adaptation_config.good_progress_threshold {
448                // Good progress: increase patience
449                let increased = (current_patience as f64
450                    * self.adaptation_config.patience_increase_factor)
451                    as usize;
452                increased.min(self.adaptation_config.max_patience)
453            } else if metrics.improvement_rate < self.adaptation_config.poor_progress_threshold {
454                // Poor progress: decrease patience
455                let decreased = (current_patience as f64
456                    * self.adaptation_config.patience_decrease_factor)
457                    as usize;
458                decreased.max(self.adaptation_config.min_patience)
459            } else {
460                current_patience // No change
461            };
462
463        if new_patience != current_patience {
464            self.base_monitor.config.patience = new_patience;
465            self.adaptation_state
466                .adaptation_history
467                .push((self.base_monitor.current_iteration, new_patience));
468        }
469
470        self.adaptation_state.last_adaptation_iteration = self.base_monitor.current_iteration;
471    }
472
473    /// Get the underlying monitor
474    pub fn monitor(&self) -> &EarlyStoppingMonitor {
475        &self.base_monitor
476    }
477
478    /// Get the underlying monitor mutably
479    pub fn monitor_mut(&mut self) -> &mut EarlyStoppingMonitor {
480        &mut self.base_monitor
481    }
482
483    /// Get adaptation history
484    pub fn adaptation_history(&self) -> &[(usize, usize)] {
485        &self.adaptation_state.adaptation_history
486    }
487}
488
489impl EarlyStoppingCallback for AdaptiveEarlyStopping {
490    fn on_iteration(&mut self, score: f64) -> Result<bool> {
491        self.update_adaptive(score)?;
492        Ok(self.base_monitor.should_stop())
493    }
494
495    fn on_early_stop(&mut self, reason: &str) -> Result<()> {
496        self.base_monitor.on_early_stop(reason)
497    }
498
499    fn best_score(&self) -> f64 {
500        self.base_monitor.best_score()
501    }
502
503    fn convergence_info(&self) -> String {
504        format!(
505            "{} | Adaptations: {}",
506            self.base_monitor.convergence_info(),
507            self.adaptation_state.adaptation_history.len()
508        )
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// The patience strategy must stay quiet until `min_iterations` is reached
    /// and then stop once the patience counter exceeds `patience`.
    #[test]
    fn test_early_stopping_patience() {
        let config = EarlyStoppingConfig {
            min_iterations: 5,
            patience: 3,
            min_delta: 0.01,
            maximize: true,
            ..Default::default()
        };

        let mut monitor = EarlyStoppingMonitor::new(EarlyStoppingStrategy::Patience, config);

        // Should not stop before min_iterations
        for i in 0..5 {
            // Provide improving scores initially to avoid early patience trigger
            monitor.update(1.0 + i as f64 * 0.02).unwrap();
            assert!(!monitor.should_stop(), "Should not stop at iteration {}", i);
        }

        // Add scores without improvement
        monitor.update(1.0).unwrap(); // No improvement, counter = 1
        assert!(!monitor.should_stop());

        monitor.update(0.99).unwrap(); // Worse, counter = 2
        assert!(!monitor.should_stop());

        monitor.update(0.98).unwrap(); // Worse, counter = 3 (not yet above patience)
        assert!(!monitor.should_stop());

        monitor.update(0.97).unwrap(); // Worse, counter = 4 > patience: stop
        assert!(monitor.should_stop());
    }

    /// The improvement-rate strategy stops once the relative change across the
    /// last five scores drops below the threshold.
    #[test]
    fn test_early_stopping_improvement_rate() {
        let config = EarlyStoppingConfig {
            min_iterations: 3,
            maximize: true,
            ..Default::default()
        };

        let mut monitor = EarlyStoppingMonitor::new(
            EarlyStoppingStrategy::ImprovementRate(0.01), // 1% threshold
            config,
        );

        // Rapid improvement
        monitor.update(1.0).unwrap();
        monitor.update(1.1).unwrap();
        monitor.update(1.2).unwrap();
        assert!(!monitor.should_stop()); // Good improvement rate

        // Follow up with very small improvements.
        monitor.update(1.2001).unwrap();
        monitor.update(1.2002).unwrap();
        monitor.update(1.2003).unwrap();
        // The last five scores are now [1.1, 1.2, 1.2001, 1.2002, 1.2003]:
        // rate = (1.2003 - 1.1) / 1.1 = 0.0912... > 0.01, still too high.

        // Push the fast-improving scores out of the five-score window.
        monitor.update(1.2003).unwrap(); // No improvement
        monitor.update(1.2003).unwrap(); // No improvement
        // The last five scores are now [1.2001, 1.2002, 1.2003, 1.2003, 1.2003]:
        // rate = (1.2003 - 1.2001) / 1.2001 = 0.00017 < 0.01.
        assert!(monitor.should_stop()); // Should stop with poor improvement rate
    }

    /// A combined strategy stops as soon as any member asks to stop.
    #[test]
    fn test_early_stopping_combined_strategy() {
        let config = EarlyStoppingConfig {
            min_iterations: 2,
            patience: 5,
            maximize: true,
            ..Default::default()
        };

        let strategy = EarlyStoppingStrategy::Combined(vec![
            EarlyStoppingStrategy::Patience,
            EarlyStoppingStrategy::ImprovementRate(0.001),
        ]);

        let mut monitor = EarlyStoppingMonitor::new(strategy, config);

        monitor.update(1.0).unwrap();
        monitor.update(1.0001).unwrap(); // Tiny improvement
        monitor.update(1.0002).unwrap(); // Tiny improvement

        // Should stop due to low improvement rate, even though patience hasn't run out
        assert!(monitor.should_stop());
    }

    /// Convergence metrics must reflect the recorded scores.
    #[test]
    fn test_convergence_metrics() {
        let config = EarlyStoppingConfig::default();
        let mut monitor = EarlyStoppingMonitor::new(EarlyStoppingStrategy::Patience, config);

        monitor.update(1.0).unwrap(); // best at iteration 0
        monitor.update(1.1).unwrap(); // best at iteration 1
        monitor.update(1.05).unwrap(); // no improvement, counter = 1

        let metrics = monitor.convergence_metrics();
        assert!(metrics.improvement_rate.is_finite());
        assert!(metrics.relative_improvement >= 0.0);
        assert_eq!(
            metrics.patience_remaining,
            monitor.config.patience - monitor.state.patience_counter
        );

        // Pin the concrete values as well (floats compared with a tolerance).
        assert!((metrics.improvement_rate - 0.05).abs() < 1e-12);
        assert!((metrics.absolute_improvement - 0.05).abs() < 1e-12);
        assert!((metrics.relative_improvement - 0.05 / 1.1).abs() < 1e-12);
        assert_eq!(metrics.patience_remaining, 9);
        assert_eq!(metrics.iterations_since_best, 2);
        assert_eq!(metrics.current_score, 1.05);
    }

    /// Sustained good progress must raise the patience of the adaptive monitor.
    #[test]
    fn test_adaptive_early_stopping() {
        let config = EarlyStoppingConfig {
            min_iterations: 5,
            patience: 10,
            maximize: true,
            ..Default::default()
        };

        let adaptation_config = AdaptationConfig {
            adaptation_frequency: 5,
            good_progress_threshold: 0.1,
            poor_progress_threshold: 0.01,
            ..Default::default()
        };

        let mut adaptive =
            AdaptiveEarlyStopping::new(EarlyStoppingStrategy::Patience, config, adaptation_config);

        // Good progress should increase patience
        for i in 0..10 {
            adaptive.update_adaptive(1.0 + i as f64 * 0.2).unwrap();
        }

        assert!(adaptive.monitor().config.patience > 10); // Should have increased
        assert!(!adaptive.adaptation_history().is_empty());
    }

    /// The callback interface must drive the monitor exactly like direct calls.
    #[test]
    fn test_early_stopping_callback() {
        let config = EarlyStoppingConfig {
            min_iterations: 2,
            patience: 2,
            maximize: true,
            min_delta: 0.0, // No minimum delta required for improvement
            ..Default::default()
        };

        let mut monitor = EarlyStoppingMonitor::new(EarlyStoppingStrategy::Patience, config);

        assert!(!monitor.on_iteration(1.0).unwrap()); // iteration 1: best = 1.0, patience = 0
        assert!(!monitor.on_iteration(1.0).unwrap()); // iteration 2: no improvement, patience = 1
        assert!(!monitor.on_iteration(0.9).unwrap()); // iteration 3: no improvement, patience = 2
        assert!(monitor.on_iteration(0.8).unwrap()); // iteration 4: patience = 3 > 2, should stop

        assert_eq!(monitor.best_score(), 1.0);
        assert!(monitor.convergence_info().contains("Best: 1.000000"));
    }
}
675}