1use sklears_core::error::Result;
8use std::collections::VecDeque;
9
/// Tuning parameters shared by all early-stopping strategies.
#[derive(Debug, Clone, PartialEq)]
pub struct EarlyStoppingConfig {
    /// Minimum number of recorded iterations before any strategy may request a stop.
    pub min_iterations: usize,
    /// Number of consecutive non-improving iterations tolerated by the `Patience`
    /// strategy; a stop is requested once the counter exceeds this value.
    pub patience: usize,
    /// Minimum margin over the best score that counts as an improvement; also the
    /// tolerance used by the exponential-moving-average convergence check.
    pub min_delta: f64,
    /// Whether the weights from the best iteration should be restored.
    /// NOTE(review): not read by the logic in this module — confirm how callers use it.
    pub restore_best_weights: bool,
    /// `true` when higher scores are better, `false` when lower scores are better.
    pub maximize: bool,
    /// Optional reference score.
    /// NOTE(review): not consulted by the logic in this module.
    pub baseline: Option<f64>,
    /// Weight of the previous EMA value in the update
    /// `ema = smoothing_factor * ema + (1 - smoothing_factor) * score`.
    pub smoothing_factor: f64,
}
28
29impl Default for EarlyStoppingConfig {
30 fn default() -> Self {
31 Self {
32 min_iterations: 10,
33 patience: 10,
34 min_delta: 1e-4,
35 restore_best_weights: true,
36 maximize: true,
37 baseline: None,
38 smoothing_factor: 0.9,
39 }
40 }
41}
42
/// Rule used by an early-stopping monitor to decide when training should stop.
#[derive(Debug, Clone, PartialEq)]
pub enum EarlyStoppingStrategy {
    /// Stop once more than `patience` consecutive iterations have failed to beat
    /// the best score by at least `min_delta`.
    Patience,
    /// Stop once the absolute relative change between the first and last of the
    /// most recent (up to five) scores drops below the given threshold.
    ImprovementRate(f64),
    /// Stop once the exponential moving average of the scores is within
    /// `min_delta` of the mean of the (full) recent-score window.
    ExponentialMovingAverage,
    /// Compares the mean of the most recent window of scores with the mean of the
    /// window before it, to detect a deteriorating validation score (overfitting).
    ValidationLoss,
    /// Stop once `|current - previous| / |previous|` drops below the given threshold.
    RelativeImprovement(f64),
    /// Stop once `|current - previous|` drops below the given threshold.
    AbsoluteImprovement(f64),
    /// Stop as soon as any of the contained strategies asks to stop.
    Combined(Vec<EarlyStoppingStrategy>),
}
61
/// Mutable bookkeeping consulted by every early-stopping strategy.
#[derive(Clone, Debug)]
pub struct EarlyStoppingState {
    /// Best score recorded so far.
    pub best_score: f64,
    /// Iteration index at which `best_score` was recorded.
    pub best_iteration: usize,
    /// Consecutive updates that did not improve on `best_score`.
    pub patience_counter: usize,
    /// Every score recorded, oldest first.
    pub score_history: Vec<f64>,
    /// Exponential moving average of the recorded scores.
    pub ema_score: f64,
    /// Whether `ema_score` has been seeded with a real score yet.
    pub ema_initialized: bool,
    /// Sliding window holding at most `window_size` of the latest scores.
    pub recent_scores: VecDeque<f64>,
    /// Capacity of the `recent_scores` window.
    pub window_size: usize,
}
82
83impl EarlyStoppingState {
84 fn new(window_size: usize) -> Self {
85 Self {
86 best_score: f64::NEG_INFINITY,
87 best_iteration: 0,
88 patience_counter: 0,
89 score_history: Vec::new(),
90 ema_score: 0.0,
91 ema_initialized: false,
92 recent_scores: VecDeque::with_capacity(window_size),
93 window_size,
94 }
95 }
96
97 fn update(&mut self, score: f64, iteration: usize, config: &EarlyStoppingConfig) {
98 self.score_history.push(score);
99
100 if self.recent_scores.len() >= self.window_size {
102 self.recent_scores.pop_front();
103 }
104 self.recent_scores.push_back(score);
105
106 if !self.ema_initialized {
108 self.ema_score = score;
109 self.ema_initialized = true;
110 } else {
111 self.ema_score =
112 config.smoothing_factor * self.ema_score + (1.0 - config.smoothing_factor) * score;
113 }
114
115 let is_improvement = if config.maximize {
117 score > self.best_score + config.min_delta
118 } else {
119 score < self.best_score - config.min_delta
120 };
121
122 if is_improvement {
123 self.best_score = score;
124 self.best_iteration = iteration;
125 self.patience_counter = 0;
126 } else {
127 self.patience_counter += 1;
128 }
129 }
130
131 fn improvement_rate(&self) -> f64 {
132 if self.score_history.len() < 2 {
133 return 0.0;
134 }
135
136 let recent_window = self.score_history.len().min(5);
137 let recent_scores = &self.score_history[self.score_history.len() - recent_window..];
138
139 if recent_scores.len() < 2 {
140 return 0.0;
141 }
142
143 let start_score = recent_scores[0];
144 let end_score = recent_scores[recent_scores.len() - 1];
145
146 if start_score.abs() < 1e-8 {
147 return 0.0;
148 }
149
150 (end_score - start_score) / start_score.abs()
151 }
152
153 fn relative_improvement(&self) -> f64 {
154 if self.score_history.len() < 2 {
155 return f64::INFINITY;
156 }
157
158 let current = self.score_history[self.score_history.len() - 1];
159 let previous = self.score_history[self.score_history.len() - 2];
160
161 if previous.abs() < 1e-8 {
162 return f64::INFINITY;
163 }
164
165 (current - previous).abs() / previous.abs()
166 }
167
168 fn absolute_improvement(&self) -> f64 {
169 if self.score_history.len() < 2 {
170 return f64::INFINITY;
171 }
172
173 let current = self.score_history[self.score_history.len() - 1];
174 let previous = self.score_history[self.score_history.len() - 2];
175
176 (current - previous).abs()
177 }
178
179 fn ema_convergence(&self, threshold: f64) -> bool {
180 if self.recent_scores.len() < self.window_size {
181 return false;
182 }
183
184 let recent_avg: f64 =
185 self.recent_scores.iter().sum::<f64>() / self.recent_scores.len() as f64;
186 (self.ema_score - recent_avg).abs() < threshold
187 }
188
189 fn is_overfitting(&self, lookback: usize) -> bool {
190 if self.score_history.len() < lookback + 2 {
191 return false;
192 }
193
194 let len = self.score_history.len();
195 let recent_scores = &self.score_history[len - lookback..];
196 let previous_scores = &self.score_history[len - lookback - lookback..len - lookback];
197
198 let recent_avg: f64 = recent_scores.iter().sum::<f64>() / recent_scores.len() as f64;
199 let previous_avg: f64 = previous_scores.iter().sum::<f64>() / previous_scores.len() as f64;
200
201 recent_avg < previous_avg
204 }
205}
206
207pub struct EarlyStoppingMonitor {
209 strategy: EarlyStoppingStrategy,
210 config: EarlyStoppingConfig,
211 state: EarlyStoppingState,
212 current_iteration: usize,
213}
214
215impl EarlyStoppingMonitor {
216 pub fn new(strategy: EarlyStoppingStrategy, config: EarlyStoppingConfig) -> Self {
218 Self {
219 strategy,
220 config,
221 state: EarlyStoppingState::new(10), current_iteration: 0,
223 }
224 }
225
226 pub fn update(&mut self, score: f64) -> Result<()> {
228 self.state
229 .update(score, self.current_iteration, &self.config);
230 self.current_iteration += 1;
231 Ok(())
232 }
233
234 pub fn should_stop(&self) -> bool {
236 if self.current_iteration < self.config.min_iterations {
237 return false;
238 }
239
240 self.check_strategy(&self.strategy)
241 }
242
243 fn check_strategy(&self, strategy: &EarlyStoppingStrategy) -> bool {
244 match strategy {
245 EarlyStoppingStrategy::Patience => self.state.patience_counter > self.config.patience,
246 EarlyStoppingStrategy::ImprovementRate(threshold) => {
247 self.state.improvement_rate().abs() < *threshold
248 }
249 EarlyStoppingStrategy::ExponentialMovingAverage => {
250 self.state.ema_convergence(self.config.min_delta)
251 }
252 EarlyStoppingStrategy::ValidationLoss => {
253 self.state.is_overfitting(5) }
255 EarlyStoppingStrategy::RelativeImprovement(threshold) => {
256 self.state.relative_improvement() < *threshold
257 }
258 EarlyStoppingStrategy::AbsoluteImprovement(threshold) => {
259 self.state.absolute_improvement() < *threshold
260 }
261 EarlyStoppingStrategy::Combined(strategies) => {
262 strategies.iter().any(|s| self.check_strategy(s))
263 }
264 }
265 }
266
267 pub fn state(&self) -> &EarlyStoppingState {
269 &self.state
270 }
271
272 pub fn best_result(&self) -> (f64, usize) {
274 (self.state.best_score, self.state.best_iteration)
275 }
276
277 pub fn reset(&mut self) {
279 self.state = EarlyStoppingState::new(self.state.window_size);
280 self.current_iteration = 0;
281 }
282
283 pub fn min_iterations_reached(&self) -> bool {
285 self.current_iteration >= self.config.min_iterations
286 }
287
288 pub fn convergence_metrics(&self) -> ConvergenceMetrics {
290 ConvergenceMetrics {
291 improvement_rate: self.state.improvement_rate(),
292 relative_improvement: self.state.relative_improvement(),
293 absolute_improvement: self.state.absolute_improvement(),
294 patience_remaining: self
295 .config
296 .patience
297 .saturating_sub(self.state.patience_counter),
298 iterations_since_best: self
299 .current_iteration
300 .saturating_sub(self.state.best_iteration),
301 ema_score: self.state.ema_score,
302 current_score: self.state.score_history.last().copied().unwrap_or(0.0),
303 }
304 }
305}
306
/// Snapshot of convergence indicators produced by an early-stopping monitor.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvergenceMetrics {
    /// Signed relative change across the most recent (up to five) scores.
    pub improvement_rate: f64,
    /// `|current - previous| / |previous|`; infinite when it cannot be computed.
    pub relative_improvement: f64,
    /// `|current - previous|`; infinite with fewer than two scores.
    pub absolute_improvement: f64,
    /// Patience minus the current non-improvement streak, floored at zero.
    pub patience_remaining: usize,
    /// Iterations recorded since the best score was observed.
    pub iterations_since_best: usize,
    /// Exponential moving average of the scores.
    pub ema_score: f64,
    /// Most recently recorded score (0.0 when none has been recorded).
    pub current_score: f64,
}
325
326pub trait EarlyStoppingCallback {
328 fn on_iteration(&mut self, score: f64) -> Result<bool>;
329
330 fn on_early_stop(&mut self, reason: &str) -> Result<()>;
331
332 fn best_score(&self) -> f64;
333
334 fn convergence_info(&self) -> String;
335}
336
337impl EarlyStoppingCallback for EarlyStoppingMonitor {
338 fn on_iteration(&mut self, score: f64) -> Result<bool> {
339 self.update(score)?;
340 Ok(self.should_stop())
341 }
342
343 fn on_early_stop(&mut self, _reason: &str) -> Result<()> {
344 Ok(())
346 }
347
348 fn best_score(&self) -> f64 {
349 self.state.best_score
350 }
351
352 fn convergence_info(&self) -> String {
353 let metrics = self.convergence_metrics();
354 format!(
355 "Best: {:.6}, Current: {:.6}, Improvement Rate: {:.6}, Patience: {}/{}",
356 self.state.best_score,
357 metrics.current_score,
358 metrics.improvement_rate,
359 self.state.patience_counter,
360 self.config.patience
361 )
362 }
363}
364
365pub struct AdaptiveEarlyStopping {
367 base_monitor: EarlyStoppingMonitor,
368 adaptation_config: AdaptationConfig,
369 adaptation_state: AdaptationState,
370}
371
/// Rules for adapting the patience of an adaptive early-stopping monitor.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptationConfig {
    /// Number of recorded iterations between adaptation passes.
    pub adaptation_frequency: usize,
    /// Multiplier applied to the patience when progress is good.
    pub patience_increase_factor: f64,
    /// Multiplier applied to the patience when progress is poor.
    pub patience_decrease_factor: f64,
    /// Upper bound applied when the patience is increased.
    pub max_patience: usize,
    /// Lower bound applied when the patience is decreased.
    pub min_patience: usize,
    /// Improvement rate above which progress counts as good.
    pub good_progress_threshold: f64,
    /// Improvement rate below which progress counts as poor.
    pub poor_progress_threshold: f64,
}
389
390impl Default for AdaptationConfig {
391 fn default() -> Self {
392 Self {
393 adaptation_frequency: 20,
394 patience_increase_factor: 1.5,
395 patience_decrease_factor: 0.8,
396 max_patience: 50,
397 min_patience: 5,
398 good_progress_threshold: 0.01,
399 poor_progress_threshold: 0.001,
400 }
401 }
402}
403
/// Internal bookkeeping of an adaptive early-stopping monitor.
#[derive(Clone, Debug)]
struct AdaptationState {
    /// Iteration at which the patience was last re-evaluated.
    last_adaptation_iteration: usize,
    /// `(iteration, new_patience)` for every adaptation that changed the patience.
    adaptation_history: Vec<(usize, usize)>,
}
409
410impl AdaptiveEarlyStopping {
411 pub fn new(
413 strategy: EarlyStoppingStrategy,
414 config: EarlyStoppingConfig,
415 adaptation_config: AdaptationConfig,
416 ) -> Self {
417 Self {
418 base_monitor: EarlyStoppingMonitor::new(strategy, config),
419 adaptation_config,
420 adaptation_state: AdaptationState {
421 last_adaptation_iteration: 0,
422 adaptation_history: Vec::new(),
423 },
424 }
425 }
426
427 pub fn update_adaptive(&mut self, score: f64) -> Result<()> {
429 self.base_monitor.update(score)?;
430
431 if self.base_monitor.current_iteration
433 >= self.adaptation_state.last_adaptation_iteration
434 + self.adaptation_config.adaptation_frequency
435 {
436 self.adapt_parameters();
437 }
438
439 Ok(())
440 }
441
442 fn adapt_parameters(&mut self) {
443 let metrics = self.base_monitor.convergence_metrics();
444 let current_patience = self.base_monitor.config.patience;
445
446 let new_patience =
447 if metrics.improvement_rate > self.adaptation_config.good_progress_threshold {
448 let increased = (current_patience as f64
450 * self.adaptation_config.patience_increase_factor)
451 as usize;
452 increased.min(self.adaptation_config.max_patience)
453 } else if metrics.improvement_rate < self.adaptation_config.poor_progress_threshold {
454 let decreased = (current_patience as f64
456 * self.adaptation_config.patience_decrease_factor)
457 as usize;
458 decreased.max(self.adaptation_config.min_patience)
459 } else {
460 current_patience };
462
463 if new_patience != current_patience {
464 self.base_monitor.config.patience = new_patience;
465 self.adaptation_state
466 .adaptation_history
467 .push((self.base_monitor.current_iteration, new_patience));
468 }
469
470 self.adaptation_state.last_adaptation_iteration = self.base_monitor.current_iteration;
471 }
472
473 pub fn monitor(&self) -> &EarlyStoppingMonitor {
475 &self.base_monitor
476 }
477
478 pub fn monitor_mut(&mut self) -> &mut EarlyStoppingMonitor {
480 &mut self.base_monitor
481 }
482
483 pub fn adaptation_history(&self) -> &[(usize, usize)] {
485 &self.adaptation_state.adaptation_history
486 }
487}
488
489impl EarlyStoppingCallback for AdaptiveEarlyStopping {
490 fn on_iteration(&mut self, score: f64) -> Result<bool> {
491 self.update_adaptive(score)?;
492 Ok(self.base_monitor.should_stop())
493 }
494
495 fn on_early_stop(&mut self, reason: &str) -> Result<()> {
496 self.base_monitor.on_early_stop(reason)
497 }
498
499 fn best_score(&self) -> f64 {
500 self.base_monitor.best_score()
501 }
502
503 fn convergence_info(&self) -> String {
504 format!(
505 "{} | Adaptations: {}",
506 self.base_monitor.convergence_info(),
507 self.adaptation_state.adaptation_history.len()
508 )
509 }
510}
511
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_early_stopping_patience() {
        let mut monitor = EarlyStoppingMonitor::new(
            EarlyStoppingStrategy::Patience,
            EarlyStoppingConfig {
                min_iterations: 5,
                patience: 3,
                min_delta: 0.01,
                maximize: true,
                ..Default::default()
            },
        );

        // Five steadily improving scores, all inside the `min_iterations` warm-up.
        for step in 0..5 {
            monitor.update(1.0 + step as f64 * 0.02).unwrap();
            assert!(!monitor.should_stop(), "Should not stop at iteration {}", step);
        }

        // Three non-improving scores bring the counter up to, but not past, patience.
        for score in [1.0, 0.99, 0.98] {
            monitor.update(score).unwrap();
            assert!(!monitor.should_stop());
        }

        // A fourth non-improving score pushes the counter past patience.
        monitor.update(0.97).unwrap();
        assert!(monitor.should_stop());
    }

    #[test]
    fn test_early_stopping_improvement_rate() {
        let mut monitor = EarlyStoppingMonitor::new(
            EarlyStoppingStrategy::ImprovementRate(0.01),
            EarlyStoppingConfig {
                min_iterations: 3,
                maximize: true,
                ..Default::default()
            },
        );

        // Rapid early progress keeps the relative change well above 1%.
        for score in [1.0, 1.1, 1.2] {
            monitor.update(score).unwrap();
        }
        assert!(!monitor.should_stop());

        // On a plateau, the most recent scores barely move.
        for score in [1.2001, 1.2002, 1.2003, 1.2003, 1.2003] {
            monitor.update(score).unwrap();
        }
        assert!(monitor.should_stop());
    }

    #[test]
    fn test_early_stopping_combined_strategy() {
        let strategy = EarlyStoppingStrategy::Combined(vec![
            EarlyStoppingStrategy::Patience,
            EarlyStoppingStrategy::ImprovementRate(0.001),
        ]);
        let config = EarlyStoppingConfig {
            min_iterations: 2,
            patience: 5,
            maximize: true,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(strategy, config);

        // Near-identical scores trip the improvement-rate arm long before patience runs out.
        for score in [1.0, 1.0001, 1.0002] {
            monitor.update(score).unwrap();
        }
        assert!(monitor.should_stop());
    }

    #[test]
    fn test_convergence_metrics() {
        let mut monitor = EarlyStoppingMonitor::new(
            EarlyStoppingStrategy::Patience,
            EarlyStoppingConfig::default(),
        );

        for score in [1.0, 1.1, 1.05] {
            monitor.update(score).unwrap();
        }

        let metrics = monitor.convergence_metrics();
        assert!(metrics.improvement_rate.is_finite());
        assert!(metrics.relative_improvement >= 0.0);
        let expected_remaining = monitor.config.patience - monitor.state.patience_counter;
        assert_eq!(metrics.patience_remaining, expected_remaining);
    }

    #[test]
    fn test_adaptive_early_stopping() {
        let config = EarlyStoppingConfig {
            min_iterations: 5,
            patience: 10,
            maximize: true,
            ..Default::default()
        };
        let adaptation_config = AdaptationConfig {
            adaptation_frequency: 5,
            good_progress_threshold: 0.1,
            poor_progress_threshold: 0.01,
            ..Default::default()
        };
        let mut adaptive =
            AdaptiveEarlyStopping::new(EarlyStoppingStrategy::Patience, config, adaptation_config);

        // Strong, steady progress should make the monitor more patient.
        (0..10)
            .map(|i| 1.0 + i as f64 * 0.2)
            .for_each(|score| adaptive.update_adaptive(score).unwrap());

        assert!(adaptive.monitor().config.patience > 10);
        assert!(!adaptive.adaptation_history().is_empty());
    }

    #[test]
    fn test_early_stopping_callback() {
        let config = EarlyStoppingConfig {
            min_iterations: 2,
            patience: 2,
            maximize: true,
            min_delta: 0.0,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(EarlyStoppingStrategy::Patience, config);

        // (score fed to the callback, expected stop decision)
        let expectations = [(1.0, false), (1.0, false), (0.9, false), (0.8, true)];
        for (score, should_stop) in expectations {
            assert_eq!(monitor.on_iteration(score).unwrap(), should_stop);
        }

        assert_eq!(monitor.best_score(), 1.0);
        assert!(monitor.convergence_info().contains("Best: 1.000000"));
    }
}