1use std::collections::HashMap;
7
8#[derive(Debug, Clone)]
10pub struct EarlyStoppingConfig {
11 pub patience: usize,
13 pub min_delta: f64,
15 pub mode: MonitorMode,
17 pub baseline: Option<f64>,
19 pub restore_best: bool,
21 pub min_epochs: usize,
23}
24
25impl Default for EarlyStoppingConfig {
26 fn default() -> Self {
27 Self {
28 patience: 10,
29 min_delta: 0.0,
30 mode: MonitorMode::Minimize,
31 baseline: None,
32 restore_best: true,
33 min_epochs: 1,
34 }
35 }
36}
37
/// Direction in which a monitored metric improves.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MonitorMode {
    /// Lower values are better (e.g. a loss).
    Minimize,
    /// Higher values are better (e.g. an accuracy).
    Maximize,
}
46
/// Result of feeding one epoch's metric value(s) to a monitor.
#[derive(Clone, Debug, PartialEq)]
pub enum EarlyStoppingDecision {
    /// Keep training; this epoch did not produce a new best value.
    Continue,
    /// Training should halt.
    Stop {
        /// Human-readable explanation of why training should stop.
        reason: String,
    },
    /// The metric reached a new best value.
    NewBest {
        /// The new best metric value.
        value: f64,
        /// 1-based epoch at which the value was observed.
        epoch: usize,
    },
}
65
66#[derive(Debug, Clone)]
71pub struct EarlyStoppingMonitor {
72 config: EarlyStoppingConfig,
73 best_value: Option<f64>,
74 best_epoch: usize,
75 epochs_without_improvement: usize,
76 current_epoch: usize,
77 history: Vec<f64>,
78 stopped: bool,
79}
80
81impl EarlyStoppingMonitor {
82 pub fn new(config: EarlyStoppingConfig) -> Self {
84 Self {
85 config,
86 best_value: None,
87 best_epoch: 0,
88 epochs_without_improvement: 0,
89 current_epoch: 0,
90 history: Vec::new(),
91 stopped: false,
92 }
93 }
94
95 pub fn with_default() -> Self {
97 Self::new(EarlyStoppingConfig::default())
98 }
99
100 pub fn step(&mut self, metric_value: f64) -> EarlyStoppingDecision {
105 self.current_epoch += 1;
106 self.history.push(metric_value);
107
108 if let Some(baseline) = self.config.baseline {
111 let beats_baseline = match self.config.mode {
112 MonitorMode::Minimize => metric_value < baseline,
113 MonitorMode::Maximize => metric_value > baseline,
114 };
115 if !beats_baseline {
116 self.epochs_without_improvement += 1;
117 return self.evaluate_stop();
118 }
119 }
120
121 let is_new_best = match self.best_value {
123 None => true,
124 Some(best) => self.is_improvement(metric_value, best),
125 };
126
127 if is_new_best {
128 self.best_value = Some(metric_value);
129 self.best_epoch = self.current_epoch;
130 self.epochs_without_improvement = 0;
131 EarlyStoppingDecision::NewBest {
132 value: metric_value,
133 epoch: self.current_epoch,
134 }
135 } else {
136 self.epochs_without_improvement += 1;
137 self.evaluate_stop()
138 }
139 }
140
141 fn evaluate_stop(&mut self) -> EarlyStoppingDecision {
143 if self.current_epoch < self.config.min_epochs {
144 return EarlyStoppingDecision::Continue;
145 }
146
147 if self.epochs_without_improvement >= self.config.patience {
148 self.stopped = true;
149 let best_str = self
150 .best_value
151 .map(|v| format!("{v:.6}"))
152 .unwrap_or_else(|| "N/A".to_string());
153 EarlyStoppingDecision::Stop {
154 reason: format!(
155 "No improvement for {} epochs. Best value: {} at epoch {}.",
156 self.config.patience, best_str, self.best_epoch
157 ),
158 }
159 } else {
160 EarlyStoppingDecision::Continue
161 }
162 }
163
164 pub fn should_stop(&self) -> bool {
166 self.stopped
167 }
168
169 pub fn best_value(&self) -> Option<f64> {
171 self.best_value
172 }
173
174 pub fn best_epoch(&self) -> usize {
176 self.best_epoch
177 }
178
179 pub fn current_epoch(&self) -> usize {
181 self.current_epoch
182 }
183
184 pub fn epochs_since_improvement(&self) -> usize {
186 self.epochs_without_improvement
187 }
188
189 pub fn history(&self) -> &[f64] {
191 &self.history
192 }
193
194 pub fn reset(&mut self) {
196 self.best_value = None;
197 self.best_epoch = 0;
198 self.epochs_without_improvement = 0;
199 self.current_epoch = 0;
200 self.history.clear();
201 self.stopped = false;
202 }
203
204 fn is_improvement(&self, current: f64, best: f64) -> bool {
209 match self.config.mode {
210 MonitorMode::Minimize => current < best - self.config.min_delta,
211 MonitorMode::Maximize => current > best + self.config.min_delta,
212 }
213 }
214
215 pub fn summary(&self) -> String {
217 let best_str = self
218 .best_value
219 .map(|v| format!("{v:.6}"))
220 .unwrap_or_else(|| "N/A".to_string());
221 let mode_str = match self.config.mode {
222 MonitorMode::Minimize => "minimize",
223 MonitorMode::Maximize => "maximize",
224 };
225 format!(
226 "EarlyStoppingMonitor(mode={}, epoch={}, best={} at epoch {}, \
227 patience={}/{}, stopped={})",
228 mode_str,
229 self.current_epoch,
230 best_str,
231 self.best_epoch,
232 self.epochs_without_improvement,
233 self.config.patience,
234 self.stopped,
235 )
236 }
237}
238
/// How per-metric stop decisions are combined by a multi-metric monitor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiMetricPolicy {
    /// Stop only when every metric stepped in this epoch reports `Stop`.
    All,
    /// Stop as soon as any metric stepped in this epoch reports `Stop`.
    Any,
}
247
248#[derive(Debug, Clone)]
253pub struct MultiMetricMonitor {
254 monitors: Vec<(String, EarlyStoppingMonitor)>,
255 policy: MultiMetricPolicy,
256}
257
258impl MultiMetricMonitor {
259 pub fn new(policy: MultiMetricPolicy) -> Self {
261 Self {
262 monitors: Vec::new(),
263 policy,
264 }
265 }
266
267 pub fn add_metric(&mut self, name: impl Into<String>, config: EarlyStoppingConfig) {
269 let name = name.into();
270 let monitor = EarlyStoppingMonitor::new(config);
271 self.monitors.push((name, monitor));
272 }
273
274 pub fn step(&mut self, values: &[(String, f64)]) -> EarlyStoppingDecision {
279 let value_map: HashMap<&str, f64> = values.iter().map(|(k, v)| (k.as_str(), *v)).collect();
280
281 let mut decisions = Vec::new();
282
283 for (name, monitor) in &mut self.monitors {
284 if let Some(&val) = value_map.get(name.as_str()) {
285 let decision = monitor.step(val);
286 decisions.push(decision);
287 }
288 }
289
290 let any_stop = decisions
292 .iter()
293 .any(|d| matches!(d, EarlyStoppingDecision::Stop { .. }));
294 let all_stop = !decisions.is_empty()
295 && decisions
296 .iter()
297 .all(|d| matches!(d, EarlyStoppingDecision::Stop { .. }));
298
299 let should_stop = match self.policy {
300 MultiMetricPolicy::All => all_stop,
301 MultiMetricPolicy::Any => any_stop,
302 };
303
304 if should_stop {
305 let reasons: Vec<String> = decisions
307 .into_iter()
308 .filter_map(|d| {
309 if let EarlyStoppingDecision::Stop { reason } = d {
310 Some(reason)
311 } else {
312 None
313 }
314 })
315 .collect();
316 EarlyStoppingDecision::Stop {
317 reason: reasons.join("; "),
318 }
319 } else {
320 let new_best = decisions
322 .iter()
323 .find(|d| matches!(d, EarlyStoppingDecision::NewBest { .. }));
324 match new_best {
325 Some(EarlyStoppingDecision::NewBest { value, epoch }) => {
326 EarlyStoppingDecision::NewBest {
327 value: *value,
328 epoch: *epoch,
329 }
330 }
331 _ => EarlyStoppingDecision::Continue,
332 }
333 }
334 }
335
336 pub fn get_monitor(&self, name: &str) -> Option<&EarlyStoppingMonitor> {
338 self.monitors
339 .iter()
340 .find(|(n, _)| n == name)
341 .map(|(_, m)| m)
342 }
343
344 pub fn num_metrics(&self) -> usize {
346 self.monitors.len()
347 }
348
349 pub fn summary(&self) -> String {
351 let mut parts = Vec::new();
352 parts.push(format!(
353 "MultiMetricMonitor(policy={:?}, metrics={})",
354 self.policy,
355 self.monitors.len()
356 ));
357 for (name, monitor) in &self.monitors {
358 parts.push(format!(" {}: {}", name, monitor.summary()));
359 }
360 parts.join("\n")
361 }
362}
363
/// Detects when a metric has flattened out by checking the variance of its
/// most recent values.
#[derive(Clone, Debug)]
pub struct PlateauDetector {
    /// Number of most recent values examined.
    pub window_size: usize,
    /// A full window whose population variance is strictly below this value
    /// counts as a plateau.
    pub variance_threshold: f64,
    /// Every pushed value, oldest first.
    history: Vec<f64>,
}

impl PlateauDetector {
    /// Creates a detector with an empty history.
    pub fn new(window_size: usize, variance_threshold: f64) -> Self {
        Self {
            history: Vec::new(),
            window_size,
            variance_threshold,
        }
    }

    /// Appends one observation.
    pub fn push(&mut self, value: f64) {
        self.history.push(value);
    }

    /// True once the window is full and its variance is below the threshold.
    /// A NaN variance never counts as a plateau.
    pub fn is_plateau(&self) -> bool {
        matches!(self.current_variance(), Some(var) if var < self.variance_threshold)
    }

    /// Population variance (divide by `n`) of the current window.
    ///
    /// Returns `None` until `window_size` values have been pushed. It also
    /// returns `None` when the window is empty (`window_size == 0`).
    pub fn current_variance(&self) -> Option<f64> {
        if self.history.len() < self.window_size {
            return None;
        }
        let window = self.values_in_window();
        if window.is_empty() {
            return None;
        }
        let count = window.len() as f64;
        let mean = window.iter().sum::<f64>() / count;
        let squared_deviations = window.iter().map(|x| (x - mean).powi(2)).sum::<f64>();
        Some(squared_deviations / count)
    }

    /// The last `window_size` values, or the whole history if fewer exist.
    pub fn values_in_window(&self) -> &[f64] {
        let start = self.history.len().saturating_sub(self.window_size);
        &self.history[start..]
    }
}
429
430#[derive(Debug, Clone)]
434pub struct TrainingProgress {
435 pub total_epochs: usize,
437 pub current_epoch: usize,
439 pub metrics: HashMap<String, Vec<f64>>,
441}
442
443impl TrainingProgress {
444 pub fn new(total_epochs: usize) -> Self {
446 Self {
447 total_epochs,
448 current_epoch: 0,
449 metrics: HashMap::new(),
450 }
451 }
452
453 pub fn record(&mut self, metric_name: impl Into<String>, value: f64) {
455 self.metrics
456 .entry(metric_name.into())
457 .or_default()
458 .push(value);
459 }
460
461 pub fn progress_fraction(&self) -> f64 {
463 if self.total_epochs == 0 {
464 return 0.0;
465 }
466 self.current_epoch as f64 / self.total_epochs as f64
467 }
468
469 pub fn advance_epoch(&mut self) {
471 self.current_epoch += 1;
472 }
473
474 pub fn latest(&self, metric_name: &str) -> Option<f64> {
476 self.metrics
477 .get(metric_name)
478 .and_then(|v| v.last().copied())
479 }
480
481 pub fn best(&self, metric_name: &str, mode: &MonitorMode) -> Option<f64> {
483 self.metrics.get(metric_name).and_then(|values| {
484 if values.is_empty() {
485 return None;
486 }
487 match mode {
488 MonitorMode::Minimize => values
489 .iter()
490 .copied()
491 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)),
492 MonitorMode::Maximize => values
493 .iter()
494 .copied()
495 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)),
496 }
497 })
498 }
499
500 pub fn summary(&self) -> String {
502 let pct = self.progress_fraction() * 100.0;
503 let mut parts = vec![format!(
504 "TrainingProgress: epoch {}/{} ({:.1}%)",
505 self.current_epoch, self.total_epochs, pct
506 )];
507 for (name, values) in &self.metrics {
508 let latest = values.last().map(|v| format!("{v:.6}")).unwrap_or_default();
509 parts.push(format!(
510 " {}: latest={}, entries={}",
511 name,
512 latest,
513 values.len()
514 ));
515 }
516 parts.join("\n")
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    // ---- EarlyStoppingConfig ----------------------------------------------

    #[test]
    fn test_early_stopping_config_default() {
        let config = EarlyStoppingConfig::default();
        assert_eq!(config.patience, 10);
        assert_eq!(config.min_delta, 0.0);
        assert_eq!(config.mode, MonitorMode::Minimize);
        assert!(config.baseline.is_none());
        assert!(config.restore_best);
        assert_eq!(config.min_epochs, 1);
    }

    // ---- EarlyStoppingMonitor ---------------------------------------------

    #[test]
    fn test_monitor_new_best_on_first_step() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        let decision = monitor.step(1.0);
        assert_eq!(
            decision,
            EarlyStoppingDecision::NewBest {
                value: 1.0,
                epoch: 1
            }
        );
    }

    #[test]
    fn test_monitor_continue_while_improving() {
        let config = EarlyStoppingConfig {
            patience: 3,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);

        // Strictly decreasing loss: every epoch is a new best.
        let d1 = monitor.step(1.0);
        assert!(matches!(d1, EarlyStoppingDecision::NewBest { .. }));

        let d2 = monitor.step(0.8);
        assert!(matches!(d2, EarlyStoppingDecision::NewBest { .. }));

        let d3 = monitor.step(0.6);
        assert!(matches!(d3, EarlyStoppingDecision::NewBest { .. }));

        let d4 = monitor.step(0.4);
        assert!(matches!(d4, EarlyStoppingDecision::NewBest { .. }));
    }

    #[test]
    fn test_monitor_stop_after_patience() {
        let config = EarlyStoppingConfig {
            patience: 3,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);

        // 1.0 becomes the best; each following 1.5 is a non-improving epoch.
        monitor.step(1.0);

        let d1 = monitor.step(1.5);
        assert_eq!(d1, EarlyStoppingDecision::Continue);

        let d2 = monitor.step(1.5);
        assert_eq!(d2, EarlyStoppingDecision::Continue);

        let d3 = monitor.step(1.5);
        assert!(matches!(d3, EarlyStoppingDecision::Stop { .. }));
        assert!(monitor.should_stop());
    }

    #[test]
    fn test_monitor_min_delta_threshold() {
        let config = EarlyStoppingConfig {
            patience: 2,
            min_delta: 0.1,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);

        monitor.step(1.0);

        // 0.95 improves by only 0.05, which is below min_delta.
        let d = monitor.step(0.95);
        assert_eq!(d, EarlyStoppingDecision::Continue);

        // 0.89 improves by 0.11, which exceeds min_delta.
        let d = monitor.step(0.89);
        assert!(matches!(d, EarlyStoppingDecision::NewBest { .. }));
    }

    #[test]
    fn test_monitor_maximize_mode() {
        let config = EarlyStoppingConfig {
            patience: 3,
            mode: MonitorMode::Maximize,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);

        let d1 = monitor.step(0.5);
        assert!(matches!(d1, EarlyStoppingDecision::NewBest { .. }));

        let d2 = monitor.step(0.7);
        assert!(matches!(d2, EarlyStoppingDecision::NewBest { .. }));

        let d3 = monitor.step(0.9);
        assert!(matches!(d3, EarlyStoppingDecision::NewBest { .. }));

        // A drop from 0.9 is not an improvement when maximizing.
        let d4 = monitor.step(0.8);
        assert_eq!(d4, EarlyStoppingDecision::Continue);
    }

    #[test]
    fn test_monitor_baseline_required() {
        let config = EarlyStoppingConfig {
            patience: 5,
            baseline: Some(0.5),
            mode: MonitorMode::Minimize,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);

        // 0.8 does not beat the baseline, so no best value is recorded.
        let d = monitor.step(0.8);
        assert_eq!(d, EarlyStoppingDecision::Continue);
        assert!(monitor.best_value().is_none());

        let d = monitor.step(0.4);
        assert!(matches!(d, EarlyStoppingDecision::NewBest { .. }));
    }

    #[test]
    fn test_monitor_maximize_with_baseline() {
        let config = EarlyStoppingConfig {
            patience: 5,
            mode: MonitorMode::Maximize,
            baseline: Some(0.5),
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);

        assert_eq!(monitor.step(0.4), EarlyStoppingDecision::Continue);
        assert_eq!(monitor.epochs_since_improvement(), 1);
        assert_eq!(
            monitor.step(0.6),
            EarlyStoppingDecision::NewBest {
                value: 0.6,
                epoch: 2
            }
        );
        assert_eq!(monitor.epochs_since_improvement(), 0);
    }

    #[test]
    fn test_monitor_min_epochs_prevents_early_stop() {
        let config = EarlyStoppingConfig {
            patience: 1,
            min_epochs: 5,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);

        monitor.step(1.0);

        // Patience 1 alone would stop at epoch 2; min_epochs delays it to 5.
        let d = monitor.step(2.0);
        assert_eq!(d, EarlyStoppingDecision::Continue);

        let d = monitor.step(2.0);
        assert_eq!(d, EarlyStoppingDecision::Continue);

        let d = monitor.step(2.0);
        assert_eq!(d, EarlyStoppingDecision::Continue);

        let d = monitor.step(2.0);
        assert!(matches!(d, EarlyStoppingDecision::Stop { .. }));
    }

    #[test]
    fn test_monitor_keeps_stopping_after_patience_exhausted() {
        let config = EarlyStoppingConfig {
            patience: 1,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);

        monitor.step(1.0);
        assert!(matches!(monitor.step(2.0), EarlyStoppingDecision::Stop { .. }));
        // Further non-improving epochs keep reporting Stop.
        assert!(matches!(monitor.step(2.0), EarlyStoppingDecision::Stop { .. }));
        assert!(monitor.should_stop());
    }

    #[test]
    fn test_monitor_best_value_tracked() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        monitor.step(1.0);
        monitor.step(0.5);
        monitor.step(0.8);
        assert_eq!(monitor.best_value(), Some(0.5));
        assert_eq!(monitor.best_epoch(), 2);
    }

    #[test]
    fn test_monitor_reset() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        monitor.step(1.0);
        monitor.step(0.5);
        assert!(monitor.best_value().is_some());

        monitor.reset();
        assert!(monitor.best_value().is_none());
        assert_eq!(monitor.current_epoch(), 0);
        assert!(monitor.history().is_empty());
        assert!(!monitor.should_stop());
    }

    #[test]
    fn test_monitor_reset_allows_reuse() {
        let config = EarlyStoppingConfig {
            patience: 1,
            ..Default::default()
        };
        let mut monitor = EarlyStoppingMonitor::new(config);
        monitor.step(1.0);
        monitor.step(2.0);
        assert!(monitor.should_stop());

        monitor.reset();
        assert_eq!(
            monitor.step(3.0),
            EarlyStoppingDecision::NewBest {
                value: 3.0,
                epoch: 1
            }
        );
        // The configuration (patience 1) survives the reset.
        assert!(matches!(monitor.step(4.0), EarlyStoppingDecision::Stop { .. }));
    }

    #[test]
    fn test_monitor_history() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        monitor.step(1.0);
        monitor.step(0.8);
        monitor.step(0.6);
        assert_eq!(monitor.history().len(), 3);
        assert_eq!(monitor.history(), &[1.0, 0.8, 0.6]);
    }

    #[test]
    fn test_monitor_summary_nonempty() {
        let mut monitor = EarlyStoppingMonitor::with_default();
        monitor.step(1.0);
        let summary = monitor.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("minimize"));
    }

    // ---- MultiMetricMonitor -----------------------------------------------

    #[test]
    fn test_multi_metric_any_policy() {
        let mut mm = MultiMetricMonitor::new(MultiMetricPolicy::Any);
        mm.add_metric(
            "loss",
            EarlyStoppingConfig {
                patience: 2,
                ..Default::default()
            },
        );
        // Accuracy is effectively unlimited, so only loss can trigger Stop.
        mm.add_metric(
            "accuracy",
            EarlyStoppingConfig {
                patience: 100,
                mode: MonitorMode::Maximize,
                ..Default::default()
            },
        );

        let d = mm.step(&[("loss".to_string(), 1.0), ("accuracy".to_string(), 0.5)]);
        assert!(matches!(d, EarlyStoppingDecision::NewBest { .. }));

        let d = mm.step(&[("loss".to_string(), 1.5), ("accuracy".to_string(), 0.3)]);
        assert_eq!(d, EarlyStoppingDecision::Continue);

        let d = mm.step(&[("loss".to_string(), 1.5), ("accuracy".to_string(), 0.3)]);
        assert!(matches!(d, EarlyStoppingDecision::Stop { .. }));
    }

    #[test]
    fn test_multi_metric_all_policy() {
        let mut mm = MultiMetricMonitor::new(MultiMetricPolicy::All);
        mm.add_metric(
            "loss",
            EarlyStoppingConfig {
                patience: 2,
                ..Default::default()
            },
        );
        mm.add_metric(
            "accuracy",
            EarlyStoppingConfig {
                patience: 2,
                mode: MonitorMode::Maximize,
                ..Default::default()
            },
        );

        mm.step(&[("loss".to_string(), 1.0), ("accuracy".to_string(), 0.5)]);

        // Both metrics stall for two epochs, so both report Stop together.
        mm.step(&[("loss".to_string(), 1.5), ("accuracy".to_string(), 0.3)]);

        let d = mm.step(&[("loss".to_string(), 1.5), ("accuracy".to_string(), 0.3)]);
        assert!(matches!(d, EarlyStoppingDecision::Stop { .. }));
    }

    #[test]
    fn test_multi_metric_all_policy_no_stop_when_one_improving() {
        let mut mm = MultiMetricMonitor::new(MultiMetricPolicy::All);
        mm.add_metric(
            "loss",
            EarlyStoppingConfig {
                patience: 2,
                ..Default::default()
            },
        );
        mm.add_metric(
            "accuracy",
            EarlyStoppingConfig {
                patience: 2,
                mode: MonitorMode::Maximize,
                ..Default::default()
            },
        );

        mm.step(&[("loss".to_string(), 1.0), ("accuracy".to_string(), 0.5)]);

        // Loss stalls but accuracy keeps improving.
        mm.step(&[("loss".to_string(), 1.5), ("accuracy".to_string(), 0.7)]);

        let d = mm.step(&[("loss".to_string(), 1.5), ("accuracy".to_string(), 0.9)]);
        assert!(!matches!(d, EarlyStoppingDecision::Stop { .. }));
    }

    #[test]
    fn test_multi_metric_ignores_unknown_metrics() {
        let mut mm = MultiMetricMonitor::new(MultiMetricPolicy::All);
        mm.add_metric("loss", EarlyStoppingConfig::default());

        let d = mm.step(&[("other".to_string(), 1.0)]);
        assert_eq!(d, EarlyStoppingDecision::Continue);
        // The registered monitor was not stepped.
        assert_eq!(mm.get_monitor("loss").map(|m| m.current_epoch()), Some(0));
    }

    #[test]
    fn test_multi_metric_get_monitor() {
        let mut mm = MultiMetricMonitor::new(MultiMetricPolicy::Any);
        mm.add_metric("loss", EarlyStoppingConfig::default());
        mm.add_metric(
            "accuracy",
            EarlyStoppingConfig {
                mode: MonitorMode::Maximize,
                ..Default::default()
            },
        );

        assert!(mm.get_monitor("loss").is_some());
        assert!(mm.get_monitor("accuracy").is_some());
        assert!(mm.get_monitor("nonexistent").is_none());
        assert_eq!(mm.num_metrics(), 2);
    }

    #[test]
    fn test_multi_metric_summary() {
        let mut mm = MultiMetricMonitor::new(MultiMetricPolicy::Any);
        mm.add_metric("loss", EarlyStoppingConfig::default());
        let summary = mm.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("loss"));
    }

    // ---- PlateauDetector --------------------------------------------------

    #[test]
    fn test_plateau_detector_no_plateau() {
        let mut detector = PlateauDetector::new(3, 0.001);
        detector.push(1.0);
        detector.push(2.0);
        detector.push(3.0);
        assert!(!detector.is_plateau());
        assert!(detector.current_variance().is_some());
    }

    #[test]
    fn test_plateau_detector_plateau() {
        let mut detector = PlateauDetector::new(3, 0.001);
        detector.push(1.0);
        detector.push(1.0);
        detector.push(1.0);
        assert!(detector.is_plateau());
        assert_eq!(detector.current_variance(), Some(0.0));
    }

    #[test]
    fn test_plateau_detector_insufficient_data() {
        let mut detector = PlateauDetector::new(5, 0.001);
        detector.push(1.0);
        detector.push(1.0);
        assert!(!detector.is_plateau());
        assert!(detector.current_variance().is_none());
    }

    #[test]
    fn test_plateau_detector_window_slides() {
        let mut detector = PlateauDetector::new(3, 0.001);
        detector.push(1.0);
        detector.push(5.0);
        detector.push(10.0);
        assert!(!detector.is_plateau());

        // Three identical values push the noisy ones out of the window.
        detector.push(2.0);
        detector.push(2.0);
        detector.push(2.0);
        assert!(detector.is_plateau());
    }

    #[test]
    fn test_plateau_detector_values_in_window_partial() {
        let mut detector = PlateauDetector::new(4, 0.001);
        detector.push(1.0);
        detector.push(2.0);
        assert_eq!(detector.values_in_window(), &[1.0, 2.0]);
    }

    #[test]
    fn test_plateau_detector_zero_window_never_plateaus() {
        let mut detector = PlateauDetector::new(0, 1.0);
        detector.push(1.0);
        assert!(detector.values_in_window().is_empty());
        assert!(detector.current_variance().is_none());
        assert!(!detector.is_plateau());
    }

    // ---- TrainingProgress -------------------------------------------------

    #[test]
    fn test_training_progress_advance() {
        let mut progress = TrainingProgress::new(100);
        assert_eq!(progress.current_epoch, 0);
        progress.advance_epoch();
        assert_eq!(progress.current_epoch, 1);
        progress.advance_epoch();
        assert_eq!(progress.current_epoch, 2);
    }

    #[test]
    fn test_training_progress_best_minimize() {
        let mut progress = TrainingProgress::new(10);
        progress.record("loss", 1.0);
        progress.record("loss", 0.5);
        progress.record("loss", 0.8);

        let best = progress.best("loss", &MonitorMode::Minimize);
        assert_eq!(best, Some(0.5));
    }

    #[test]
    fn test_training_progress_best_maximize() {
        let mut progress = TrainingProgress::new(10);
        progress.record("accuracy", 0.6);
        progress.record("accuracy", 0.9);
        progress.record("accuracy", 0.7);

        let best = progress.best("accuracy", &MonitorMode::Maximize);
        assert_eq!(best, Some(0.9));
    }

    #[test]
    fn test_training_progress_best_missing_metric() {
        let progress = TrainingProgress::new(10);
        assert_eq!(progress.best("loss", &MonitorMode::Minimize), None);
    }

    #[test]
    fn test_training_progress_latest() {
        let mut progress = TrainingProgress::new(10);
        progress.record("loss", 1.0);
        progress.record("loss", 0.5);
        assert_eq!(progress.latest("loss"), Some(0.5));
        assert_eq!(progress.latest("nonexistent"), None);
    }

    #[test]
    fn test_training_progress_fraction() {
        let mut progress = TrainingProgress::new(10);
        assert_eq!(progress.progress_fraction(), 0.0);
        progress.advance_epoch();
        progress.advance_epoch();
        progress.advance_epoch();
        assert!((progress.progress_fraction() - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_training_progress_summary() {
        let mut progress = TrainingProgress::new(10);
        progress.advance_epoch();
        progress.record("loss", 0.5);
        let summary = progress.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("loss"));
    }

    #[test]
    fn test_training_progress_zero_total_epochs() {
        let progress = TrainingProgress::new(0);
        assert_eq!(progress.progress_fraction(), 0.0);
    }
}
956}