Skip to main content

tensorlogic_train/
early_stopping.rs

//! Early stopping monitor for training loops.
//!
//! Provides configurable early stopping based on metric monitoring,
//! multi-metric policies, plateau detection, and training progress tracking.
5
6use std::collections::HashMap;
7
8/// Configuration for early stopping.
9#[derive(Debug, Clone)]
10pub struct EarlyStoppingConfig {
11    /// Number of epochs to wait after last improvement before stopping.
12    pub patience: usize,
13    /// Minimum change to qualify as an improvement.
14    pub min_delta: f64,
15    /// Whether to minimize (loss) or maximize (accuracy) the metric.
16    pub mode: MonitorMode,
17    /// If `Some`, the metric must beat this baseline to count as improvement.
18    pub baseline: Option<f64>,
19    /// Whether to signal restoring the best model state on stop.
20    pub restore_best: bool,
21    /// Minimum number of epochs before early stopping can trigger.
22    pub min_epochs: usize,
23}
24
25impl Default for EarlyStoppingConfig {
26    fn default() -> Self {
27        Self {
28            patience: 10,
29            min_delta: 0.0,
30            mode: MonitorMode::Minimize,
31            baseline: None,
32            restore_best: true,
33            min_epochs: 1,
34        }
35    }
36}
37
/// Direction in which a monitored metric is expected to improve.
///
/// The enum is a field-less tag, so it is `Copy` and can be compared, hashed,
/// and passed by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MonitorMode {
    /// Improvement means the metric is decreasing (e.g. loss).
    Minimize,
    /// Improvement means the metric is increasing (e.g. accuracy).
    Maximize,
}
46
/// Outcome reported by an early stopping monitor after processing one step.
#[derive(Debug, Clone, PartialEq)]
pub enum EarlyStoppingDecision {
    /// Nothing noteworthy happened; keep training.
    Continue,
    /// Training should end now.
    Stop {
        /// Human-readable explanation of why training is being stopped.
        reason: String,
    },
    /// The metric reached a new best value; training keeps going.
    NewBest {
        /// The metric value that became the new best.
        value: f64,
        /// Epoch at which that value was observed.
        epoch: usize,
    },
}
65
66/// The early stopping monitor.
67///
68/// Tracks a single metric over epochs and decides when to stop training
69/// based on patience, minimum delta, baseline, and minimum epoch constraints.
70#[derive(Debug, Clone)]
71pub struct EarlyStoppingMonitor {
72    config: EarlyStoppingConfig,
73    best_value: Option<f64>,
74    best_epoch: usize,
75    epochs_without_improvement: usize,
76    current_epoch: usize,
77    history: Vec<f64>,
78    stopped: bool,
79}
80
81impl EarlyStoppingMonitor {
82    /// Create a new monitor with the given configuration.
83    pub fn new(config: EarlyStoppingConfig) -> Self {
84        Self {
85            config,
86            best_value: None,
87            best_epoch: 0,
88            epochs_without_improvement: 0,
89            current_epoch: 0,
90            history: Vec::new(),
91            stopped: false,
92        }
93    }
94
95    /// Create a new monitor with default configuration.
96    pub fn with_default() -> Self {
97        Self::new(EarlyStoppingConfig::default())
98    }
99
100    /// Report the metric for the current epoch and return a decision.
101    ///
102    /// This advances the epoch counter, records the metric value,
103    /// and evaluates whether training should continue or stop.
104    pub fn step(&mut self, metric_value: f64) -> EarlyStoppingDecision {
105        self.current_epoch += 1;
106        self.history.push(metric_value);
107
108        // Check baseline constraint: if a baseline is set and the metric
109        // doesn't beat it, we don't consider it an improvement.
110        if let Some(baseline) = self.config.baseline {
111            let beats_baseline = match self.config.mode {
112                MonitorMode::Minimize => metric_value < baseline,
113                MonitorMode::Maximize => metric_value > baseline,
114            };
115            if !beats_baseline {
116                self.epochs_without_improvement += 1;
117                return self.evaluate_stop();
118            }
119        }
120
121        // Check if this is an improvement over the best value seen so far.
122        let is_new_best = match self.best_value {
123            None => true,
124            Some(best) => self.is_improvement(metric_value, best),
125        };
126
127        if is_new_best {
128            self.best_value = Some(metric_value);
129            self.best_epoch = self.current_epoch;
130            self.epochs_without_improvement = 0;
131            EarlyStoppingDecision::NewBest {
132                value: metric_value,
133                epoch: self.current_epoch,
134            }
135        } else {
136            self.epochs_without_improvement += 1;
137            self.evaluate_stop()
138        }
139    }
140
141    /// Evaluate whether training should stop based on patience and min_epochs.
142    fn evaluate_stop(&mut self) -> EarlyStoppingDecision {
143        if self.current_epoch < self.config.min_epochs {
144            return EarlyStoppingDecision::Continue;
145        }
146
147        if self.epochs_without_improvement >= self.config.patience {
148            self.stopped = true;
149            let best_str = self
150                .best_value
151                .map(|v| format!("{v:.6}"))
152                .unwrap_or_else(|| "N/A".to_string());
153            EarlyStoppingDecision::Stop {
154                reason: format!(
155                    "No improvement for {} epochs. Best value: {} at epoch {}.",
156                    self.config.patience, best_str, self.best_epoch
157                ),
158            }
159        } else {
160            EarlyStoppingDecision::Continue
161        }
162    }
163
164    /// Check if training should stop (without advancing epoch).
165    pub fn should_stop(&self) -> bool {
166        self.stopped
167    }
168
169    /// Best metric value seen so far.
170    pub fn best_value(&self) -> Option<f64> {
171        self.best_value
172    }
173
174    /// Epoch at which the best value was seen.
175    pub fn best_epoch(&self) -> usize {
176        self.best_epoch
177    }
178
179    /// Current epoch number.
180    pub fn current_epoch(&self) -> usize {
181        self.current_epoch
182    }
183
184    /// Number of epochs since last improvement.
185    pub fn epochs_since_improvement(&self) -> usize {
186        self.epochs_without_improvement
187    }
188
189    /// Full metric history.
190    pub fn history(&self) -> &[f64] {
191        &self.history
192    }
193
194    /// Reset the monitor to its initial state.
195    pub fn reset(&mut self) {
196        self.best_value = None;
197        self.best_epoch = 0;
198        self.epochs_without_improvement = 0;
199        self.current_epoch = 0;
200        self.history.clear();
201        self.stopped = false;
202    }
203
204    /// Whether the current value is an improvement over the best.
205    ///
206    /// For `Minimize` mode: `current < best - min_delta`.
207    /// For `Maximize` mode: `current > best + min_delta`.
208    fn is_improvement(&self, current: f64, best: f64) -> bool {
209        match self.config.mode {
210            MonitorMode::Minimize => current < best - self.config.min_delta,
211            MonitorMode::Maximize => current > best + self.config.min_delta,
212        }
213    }
214
215    /// Return a human-readable summary of the monitor's current state.
216    pub fn summary(&self) -> String {
217        let best_str = self
218            .best_value
219            .map(|v| format!("{v:.6}"))
220            .unwrap_or_else(|| "N/A".to_string());
221        let mode_str = match self.config.mode {
222            MonitorMode::Minimize => "minimize",
223            MonitorMode::Maximize => "maximize",
224        };
225        format!(
226            "EarlyStoppingMonitor(mode={}, epoch={}, best={} at epoch {}, \
227             patience={}/{}, stopped={})",
228            mode_str,
229            self.current_epoch,
230            best_str,
231            self.best_epoch,
232            self.epochs_without_improvement,
233            self.config.patience,
234            self.stopped,
235        )
236    }
237}
238
/// Policy for combining decisions from multiple metric monitors.
///
/// The enum is a field-less tag, so it is `Copy` and can be compared, hashed,
/// and passed by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiMetricPolicy {
    /// Stop when ALL monitored metrics signal stop.
    All,
    /// Stop when ANY monitored metric signals stop.
    Any,
}
247
248/// A multi-metric early stopping monitor.
249///
250/// Monitors multiple metrics simultaneously and applies a policy
251/// (`All` or `Any`) to decide when to stop training.
252#[derive(Debug, Clone)]
253pub struct MultiMetricMonitor {
254    monitors: Vec<(String, EarlyStoppingMonitor)>,
255    policy: MultiMetricPolicy,
256}
257
258impl MultiMetricMonitor {
259    /// Create a new multi-metric monitor with the given policy.
260    pub fn new(policy: MultiMetricPolicy) -> Self {
261        Self {
262            monitors: Vec::new(),
263            policy,
264        }
265    }
266
267    /// Register a new metric to monitor.
268    pub fn add_metric(&mut self, name: impl Into<String>, config: EarlyStoppingConfig) {
269        let name = name.into();
270        let monitor = EarlyStoppingMonitor::new(config);
271        self.monitors.push((name, monitor));
272    }
273
274    /// Report values for all metrics and return a combined decision.
275    ///
276    /// Keys in `values` must match registered metric names.
277    /// Metrics not present in `values` are skipped for that step.
278    pub fn step(&mut self, values: &[(String, f64)]) -> EarlyStoppingDecision {
279        let value_map: HashMap<&str, f64> = values.iter().map(|(k, v)| (k.as_str(), *v)).collect();
280
281        let mut decisions = Vec::new();
282
283        for (name, monitor) in &mut self.monitors {
284            if let Some(&val) = value_map.get(name.as_str()) {
285                let decision = monitor.step(val);
286                decisions.push(decision);
287            }
288        }
289
290        // Apply the policy to decide the combined outcome.
291        let any_stop = decisions
292            .iter()
293            .any(|d| matches!(d, EarlyStoppingDecision::Stop { .. }));
294        let all_stop = !decisions.is_empty()
295            && decisions
296                .iter()
297                .all(|d| matches!(d, EarlyStoppingDecision::Stop { .. }));
298
299        let should_stop = match self.policy {
300            MultiMetricPolicy::All => all_stop,
301            MultiMetricPolicy::Any => any_stop,
302        };
303
304        if should_stop {
305            // Collect reasons from all monitors that signaled stop.
306            let reasons: Vec<String> = decisions
307                .into_iter()
308                .filter_map(|d| {
309                    if let EarlyStoppingDecision::Stop { reason } = d {
310                        Some(reason)
311                    } else {
312                        None
313                    }
314                })
315                .collect();
316            EarlyStoppingDecision::Stop {
317                reason: reasons.join("; "),
318            }
319        } else {
320            // Check if any monitor found a new best.
321            let new_best = decisions
322                .iter()
323                .find(|d| matches!(d, EarlyStoppingDecision::NewBest { .. }));
324            match new_best {
325                Some(EarlyStoppingDecision::NewBest { value, epoch }) => {
326                    EarlyStoppingDecision::NewBest {
327                        value: *value,
328                        epoch: *epoch,
329                    }
330                }
331                _ => EarlyStoppingDecision::Continue,
332            }
333        }
334    }
335
336    /// Get an individual monitor by name.
337    pub fn get_monitor(&self, name: &str) -> Option<&EarlyStoppingMonitor> {
338        self.monitors
339            .iter()
340            .find(|(n, _)| n == name)
341            .map(|(_, m)| m)
342    }
343
344    /// Number of registered metrics.
345    pub fn num_metrics(&self) -> usize {
346        self.monitors.len()
347    }
348
349    /// Summary across all monitors.
350    pub fn summary(&self) -> String {
351        let mut parts = Vec::new();
352        parts.push(format!(
353            "MultiMetricMonitor(policy={:?}, metrics={})",
354            self.policy,
355            self.monitors.len()
356        ));
357        for (name, monitor) in &self.monitors {
358            parts.push(format!("  {}: {}", name, monitor.summary()));
359        }
360        parts.join("\n")
361    }
362}
363
/// Detects when a metric has stopped changing.
///
/// The detector looks at the most recent `window_size` values and reports a
/// plateau when their population variance drops below `variance_threshold`.
#[derive(Debug, Clone)]
pub struct PlateauDetector {
    /// Number of most recent values that make up the sliding window.
    pub window_size: usize,
    /// A plateau is reported when the window variance is strictly below this.
    pub variance_threshold: f64,
    /// Every value pushed so far, oldest first.
    history: Vec<f64>,
}

impl PlateauDetector {
    /// Build a detector with an empty history.
    pub fn new(window_size: usize, variance_threshold: f64) -> Self {
        Self {
            history: Vec::new(),
            window_size,
            variance_threshold,
        }
    }

    /// Append a metric observation.
    pub fn push(&mut self, value: f64) {
        self.history.push(value);
    }

    /// `true` when the window is full and its variance is below the threshold.
    pub fn is_plateau(&self) -> bool {
        // `current_variance` is `None` while the window is not yet full.
        self.current_variance()
            .map_or(false, |variance| variance < self.variance_threshold)
    }

    /// Population variance (divide by `n`) of the values in the window.
    ///
    /// Yields `None` while fewer than `window_size` values have been pushed,
    /// or when the window holds no values at all.
    pub fn current_variance(&self) -> Option<f64> {
        let window_filled = self.history.len() >= self.window_size;
        let window = self.values_in_window();
        if !window_filled || window.is_empty() {
            return None;
        }

        let count = window.len() as f64;
        let mean = window.iter().sum::<f64>() / count;
        let squared_deviations = window.iter().map(|x| (x - mean).powi(2)).sum::<f64>();
        Some(squared_deviations / count)
    }

    /// The last `window_size` values, or the whole history if it is shorter.
    pub fn values_in_window(&self) -> &[f64] {
        let start = self.history.len().saturating_sub(self.window_size);
        &self.history[start..]
    }
}
429
430/// Training progress tracker.
431///
432/// Combines epoch tracking with multi-metric recording and progress reporting.
433#[derive(Debug, Clone)]
434pub struct TrainingProgress {
435    /// Total number of planned epochs.
436    pub total_epochs: usize,
437    /// Current epoch (0-indexed, incremented by `advance_epoch`).
438    pub current_epoch: usize,
439    /// Recorded metrics keyed by name, each with a history of values.
440    pub metrics: HashMap<String, Vec<f64>>,
441}
442
443impl TrainingProgress {
444    /// Create a new training progress tracker.
445    pub fn new(total_epochs: usize) -> Self {
446        Self {
447            total_epochs,
448            current_epoch: 0,
449            metrics: HashMap::new(),
450        }
451    }
452
453    /// Record a metric value for the current epoch.
454    pub fn record(&mut self, metric_name: impl Into<String>, value: f64) {
455        self.metrics
456            .entry(metric_name.into())
457            .or_default()
458            .push(value);
459    }
460
461    /// Fraction of training completed (current / total).
462    pub fn progress_fraction(&self) -> f64 {
463        if self.total_epochs == 0 {
464            return 0.0;
465        }
466        self.current_epoch as f64 / self.total_epochs as f64
467    }
468
469    /// Advance to the next epoch.
470    pub fn advance_epoch(&mut self) {
471        self.current_epoch += 1;
472    }
473
474    /// Get the latest recorded value for a metric.
475    pub fn latest(&self, metric_name: &str) -> Option<f64> {
476        self.metrics
477            .get(metric_name)
478            .and_then(|v| v.last().copied())
479    }
480
481    /// Get the best (min or max) recorded value for a metric.
482    pub fn best(&self, metric_name: &str, mode: &MonitorMode) -> Option<f64> {
483        self.metrics.get(metric_name).and_then(|values| {
484            if values.is_empty() {
485                return None;
486            }
487            match mode {
488                MonitorMode::Minimize => values
489                    .iter()
490                    .copied()
491                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)),
492                MonitorMode::Maximize => values
493                    .iter()
494                    .copied()
495                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)),
496            }
497        })
498    }
499
500    /// Human-readable summary of training progress.
501    pub fn summary(&self) -> String {
502        let pct = self.progress_fraction() * 100.0;
503        let mut parts = vec![format!(
504            "TrainingProgress: epoch {}/{} ({:.1}%)",
505            self.current_epoch, self.total_epochs, pct
506        )];
507        for (name, values) in &self.metrics {
508            let latest = values.last().map(|v| format!("{v:.6}")).unwrap_or_default();
509            parts.push(format!(
510                "  {}: latest={}, entries={}",
511                name,
512                latest,
513                values.len()
514            ));
515        }
516        parts.join("\n")
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with only `patience` overridden.
    fn with_patience(patience: usize) -> EarlyStoppingConfig {
        EarlyStoppingConfig {
            patience,
            ..Default::default()
        }
    }

    /// Default configuration in `Maximize` mode with the given patience.
    fn maximizing(patience: usize) -> EarlyStoppingConfig {
        EarlyStoppingConfig {
            patience,
            mode: MonitorMode::Maximize,
            ..Default::default()
        }
    }

    /// Multi-metric monitor tracking "loss" and "accuracy".
    fn loss_accuracy_monitor(
        policy: MultiMetricPolicy,
        loss_config: EarlyStoppingConfig,
        accuracy_config: EarlyStoppingConfig,
    ) -> MultiMetricMonitor {
        let mut mm = MultiMetricMonitor::new(policy);
        mm.add_metric("loss", loss_config);
        mm.add_metric("accuracy", accuracy_config);
        mm
    }

    /// One step's worth of values for the "loss"/"accuracy" pair.
    fn loss_acc(loss: f64, accuracy: f64) -> Vec<(String, f64)> {
        vec![
            ("loss".to_string(), loss),
            ("accuracy".to_string(), accuracy),
        ]
    }

    fn is_stop(decision: &EarlyStoppingDecision) -> bool {
        matches!(decision, EarlyStoppingDecision::Stop { .. })
    }

    fn is_new_best(decision: &EarlyStoppingDecision) -> bool {
        matches!(decision, EarlyStoppingDecision::NewBest { .. })
    }

    /// Plateau detector (threshold 0.001) pre-loaded with `values`.
    fn detector_with(window_size: usize, values: &[f64]) -> PlateauDetector {
        let mut detector = PlateauDetector::new(window_size, 0.001);
        for &value in values {
            detector.push(value);
        }
        detector
    }

    /// Ten-epoch progress tracker with `values` recorded under `metric`.
    fn progress_with(metric: &str, values: &[f64]) -> TrainingProgress {
        let mut progress = TrainingProgress::new(10);
        for &value in values {
            progress.record(metric, value);
        }
        progress
    }

    #[test]
    fn test_early_stopping_config_default() {
        let EarlyStoppingConfig {
            patience,
            min_delta,
            mode,
            baseline,
            restore_best,
            min_epochs,
        } = EarlyStoppingConfig::default();
        assert_eq!(patience, 10);
        assert_eq!(min_delta, 0.0);
        assert_eq!(mode, MonitorMode::Minimize);
        assert!(baseline.is_none());
        assert!(restore_best);
        assert_eq!(min_epochs, 1);
    }

    #[test]
    fn test_monitor_new_best_on_first_step() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        let expected = EarlyStoppingDecision::NewBest {
            value: 1.0,
            epoch: 1,
        };
        assert_eq!(monitor.step(1.0), expected);
    }

    #[test]
    fn test_monitor_continue_while_improving() {
        let mut monitor = EarlyStoppingMonitor::new(with_patience(3));

        // A strictly decreasing loss improves on every step in Minimize mode.
        for loss in [1.0, 0.8, 0.6, 0.4] {
            assert!(is_new_best(&monitor.step(loss)));
        }
    }

    #[test]
    fn test_monitor_stop_after_patience() {
        let mut monitor = EarlyStoppingMonitor::new(with_patience(3));

        // The first value becomes the best.
        monitor.step(1.0);

        // Two stagnant epochs are still within patience...
        assert_eq!(monitor.step(1.5), EarlyStoppingDecision::Continue);
        assert_eq!(monitor.step(1.5), EarlyStoppingDecision::Continue);

        // ...the third one exhausts it.
        assert!(is_stop(&monitor.step(1.5)));
        assert!(monitor.should_stop());
    }

    #[test]
    fn test_monitor_min_delta_threshold() {
        let mut monitor = EarlyStoppingMonitor::new(EarlyStoppingConfig {
            patience: 2,
            min_delta: 0.1,
            ..Default::default()
        });

        // Best = 1.0, so an improvement must be below 1.0 - 0.1 = 0.9.
        monitor.step(1.0);

        // 0.95 is not below 0.9: no improvement.
        assert_eq!(monitor.step(0.95), EarlyStoppingDecision::Continue);

        // 0.89 is below 0.9: improvement.
        assert!(is_new_best(&monitor.step(0.89)));
    }

    #[test]
    fn test_monitor_maximize_mode() {
        let mut monitor = EarlyStoppingMonitor::new(maximizing(3));

        for accuracy in [0.5, 0.7, 0.9] {
            assert!(is_new_best(&monitor.step(accuracy)));
        }

        // A drop is not an improvement in Maximize mode.
        assert_eq!(monitor.step(0.8), EarlyStoppingDecision::Continue);
    }

    #[test]
    fn test_monitor_baseline_required() {
        let mut monitor = EarlyStoppingMonitor::new(EarlyStoppingConfig {
            patience: 5,
            baseline: Some(0.5),
            mode: MonitorMode::Minimize,
            ..Default::default()
        });

        // 0.8 fails to beat the 0.5 baseline (Minimize needs < 0.5).
        assert_eq!(monitor.step(0.8), EarlyStoppingDecision::Continue);
        assert!(monitor.best_value().is_none());

        // 0.4 beats the baseline.
        assert!(is_new_best(&monitor.step(0.4)));
    }

    #[test]
    fn test_monitor_min_epochs_prevents_early_stop() {
        let mut monitor = EarlyStoppingMonitor::new(EarlyStoppingConfig {
            patience: 1,
            min_epochs: 5,
            ..Default::default()
        });

        // Epoch 1 sets the best.
        monitor.step(1.0);

        // Epochs 2-4: stagnant, but still below min_epochs (5).
        for _ in 0..3 {
            assert_eq!(monitor.step(2.0), EarlyStoppingDecision::Continue);
        }

        // Epoch 5: min_epochs satisfied and patience (1) exceeded.
        assert!(is_stop(&monitor.step(2.0)));
    }

    #[test]
    fn test_monitor_best_value_tracked() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        for loss in [1.0, 0.5, 0.8] {
            monitor.step(loss);
        }
        assert_eq!(monitor.best_value(), Some(0.5));
        assert_eq!(monitor.best_epoch(), 2);
    }

    #[test]
    fn test_monitor_reset() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        monitor.step(1.0);
        monitor.step(0.5);
        assert!(monitor.best_value().is_some());

        monitor.reset();
        assert!(monitor.best_value().is_none());
        assert_eq!(monitor.current_epoch(), 0);
        assert!(monitor.history().is_empty());
        assert!(!monitor.should_stop());
    }

    #[test]
    fn test_monitor_history() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        for loss in [1.0, 0.8, 0.6] {
            monitor.step(loss);
        }
        assert_eq!(monitor.history().len(), 3);
        assert_eq!(monitor.history(), &[1.0, 0.8, 0.6]);
    }

    #[test]
    fn test_monitor_summary_nonempty() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        monitor.step(1.0);
        let summary = monitor.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("minimize"));
    }

    #[test]
    fn test_multi_metric_any_policy() {
        // Accuracy gets a very high patience so only loss can trigger the stop.
        let mut mm =
            loss_accuracy_monitor(MultiMetricPolicy::Any, with_patience(2), maximizing(100));

        // Step 1: both metrics set a new best.
        assert!(is_new_best(&mm.step(&loss_acc(1.0, 0.5))));

        // Step 2: neither metric improves.
        assert_eq!(mm.step(&loss_acc(1.5, 0.3)), EarlyStoppingDecision::Continue);

        // Step 3: loss has gone 2 epochs without improvement, so Any stops.
        assert!(is_stop(&mm.step(&loss_acc(1.5, 0.3))));
    }

    #[test]
    fn test_multi_metric_all_policy() {
        let mut mm = loss_accuracy_monitor(MultiMetricPolicy::All, with_patience(2), maximizing(2));

        // Step 1: both metrics set a new best.
        mm.step(&loss_acc(1.0, 0.5));

        // Step 2: no improvement in either metric.
        mm.step(&loss_acc(1.5, 0.3));

        // Step 3: both monitors have exhausted patience=2, so All stops.
        assert!(is_stop(&mm.step(&loss_acc(1.5, 0.3))));
    }

    #[test]
    fn test_multi_metric_all_policy_no_stop_when_one_improving() {
        let mut mm = loss_accuracy_monitor(MultiMetricPolicy::All, with_patience(2), maximizing(2));

        // Step 1: both metrics set a new best.
        mm.step(&loss_acc(1.0, 0.5));

        // Step 2: loss stagnates while accuracy improves.
        mm.step(&loss_acc(1.5, 0.7));

        // Step 3: loss is out of patience but accuracy still improves, so All
        // must not stop.
        assert!(!is_stop(&mm.step(&loss_acc(1.5, 0.9))));
    }

    #[test]
    fn test_multi_metric_get_monitor() {
        let mm = loss_accuracy_monitor(
            MultiMetricPolicy::Any,
            EarlyStoppingConfig::default(),
            EarlyStoppingConfig {
                mode: MonitorMode::Maximize,
                ..Default::default()
            },
        );

        for registered in ["loss", "accuracy"] {
            assert!(mm.get_monitor(registered).is_some());
        }
        assert!(mm.get_monitor("nonexistent").is_none());
        assert_eq!(mm.num_metrics(), 2);
    }

    #[test]
    fn test_multi_metric_summary() {
        let mut mm = MultiMetricMonitor::new(MultiMetricPolicy::Any);
        mm.add_metric("loss", EarlyStoppingConfig::default());
        let summary = mm.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("loss"));
    }

    #[test]
    fn test_plateau_detector_no_plateau() {
        let detector = detector_with(3, &[1.0, 2.0, 3.0]);
        assert!(!detector.is_plateau());
        assert!(detector.current_variance().is_some());
    }

    #[test]
    fn test_plateau_detector_plateau() {
        let detector = detector_with(3, &[1.0, 1.0, 1.0]);
        assert!(detector.is_plateau());
        assert_eq!(detector.current_variance(), Some(0.0));
    }

    #[test]
    fn test_plateau_detector_insufficient_data() {
        let detector = detector_with(5, &[1.0, 1.0]);
        assert!(!detector.is_plateau());
        assert!(detector.current_variance().is_none());
    }

    #[test]
    fn test_plateau_detector_window_slides() {
        // Start with widely varying values: high variance, no plateau.
        let mut detector = detector_with(3, &[1.0, 5.0, 10.0]);
        assert!(!detector.is_plateau());

        // After three constant values the window is [2.0, 2.0, 2.0].
        for _ in 0..3 {
            detector.push(2.0);
        }
        assert!(detector.is_plateau());
    }

    #[test]
    fn test_training_progress_advance() {
        let mut progress = TrainingProgress::new(100);
        assert_eq!(progress.current_epoch, 0);
        for expected_epoch in 1..=2 {
            progress.advance_epoch();
            assert_eq!(progress.current_epoch, expected_epoch);
        }
    }

    #[test]
    fn test_training_progress_best_minimize() {
        let progress = progress_with("loss", &[1.0, 0.5, 0.8]);
        assert_eq!(progress.best("loss", &MonitorMode::Minimize), Some(0.5));
    }

    #[test]
    fn test_training_progress_best_maximize() {
        let progress = progress_with("accuracy", &[0.6, 0.9, 0.7]);
        assert_eq!(progress.best("accuracy", &MonitorMode::Maximize), Some(0.9));
    }

    #[test]
    fn test_training_progress_latest() {
        let progress = progress_with("loss", &[1.0, 0.5]);
        assert_eq!(progress.latest("loss"), Some(0.5));
        assert_eq!(progress.latest("nonexistent"), None);
    }

    #[test]
    fn test_training_progress_fraction() {
        let mut progress = TrainingProgress::new(10);
        assert_eq!(progress.progress_fraction(), 0.0);
        for _ in 0..3 {
            progress.advance_epoch();
        }
        assert!((progress.progress_fraction() - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_training_progress_summary() {
        let mut progress = TrainingProgress::new(10);
        progress.advance_epoch();
        progress.record("loss", 0.5);
        let summary = progress.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("loss"));
    }

    #[test]
    fn test_training_progress_zero_total_epochs() {
        let progress = TrainingProgress::new(0);
        assert_eq!(progress.progress_fraction(), 0.0);
    }
}
956}