// tensorlogic_train/callbacks/early_stopping.rs

use crate::callbacks::core::Callback;
use crate::{TrainResult, TrainingState};

/// Stops training when the validation loss has not improved by at least
/// `min_delta` for `patience` consecutive epochs.
pub struct EarlyStoppingCallback {
    /// Number of epochs without improvement to tolerate before stopping.
    pub patience: usize,
    /// Minimum decrease in validation loss that counts as an improvement.
    pub min_delta: f64,
    best_val_loss: Option<f64>,
    wait: usize,
    stop_training: bool,
}

impl EarlyStoppingCallback {
    /// Creates a new early-stopping callback with the given patience and
    /// minimum-improvement threshold.
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_val_loss: None,
            wait: 0,
            stop_training: false,
        }
    }
}

impl Callback for EarlyStoppingCallback {
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if let Some(val_loss) = state.val_loss {
            // An epoch counts as an improvement only if it beats the best
            // loss seen so far by more than `min_delta`. The first observed
            // loss always counts as an improvement.
            let improved = self
                .best_val_loss
                .map(|best| val_loss < best - self.min_delta)
                .unwrap_or(true);

            if improved {
                self.best_val_loss = Some(val_loss);
                self.wait = 0;
            } else {
                self.wait += 1;
                if self.wait >= self.patience {
                    println!(
                        "Early stopping at epoch {} (no improvement for {} epochs)",
                        epoch, self.patience
                    );
                    self.stop_training = true;
                }
            }
        }

        Ok(())
    }

    fn should_stop(&self) -> bool {
        self.stop_training
    }
}
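
// A minimal usage sketch, assuming a hypothetical training loop that yields a
// `TrainingState` each epoch (`run_epoch` below is not part of this crate):
//
//     let mut early_stop = EarlyStoppingCallback::new(5, 1e-4);
//     for epoch in 0..max_epochs {
//         let state = run_epoch()?; // hypothetical: produces a TrainingState
//         early_stop.on_epoch_end(epoch, &state)?;
//         if early_stop.should_stop() {
//             break;
//         }
//     }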

/// Reduces the learning rate by `factor` when the validation loss has not
/// improved by at least `min_delta` for `patience` consecutive epochs,
/// never going below `min_lr`.
#[allow(dead_code)]
pub struct ReduceLrOnPlateauCallback {
    /// Multiplicative factor applied to the learning rate on a plateau.
    pub factor: f64,
    /// Number of epochs without improvement to tolerate before reducing.
    pub patience: usize,
    /// Minimum decrease in validation loss that counts as an improvement.
    pub min_delta: f64,
    /// Lower bound on the learning rate.
    pub min_lr: f64,
    best_val_loss: Option<f64>,
    wait: usize,
}

impl ReduceLrOnPlateauCallback {
    /// Creates a new plateau-based learning-rate reduction callback.
    #[allow(dead_code)]
    pub fn new(factor: f64, patience: usize, min_delta: f64, min_lr: f64) -> Self {
        Self {
            factor,
            patience,
            min_delta,
            min_lr,
            best_val_loss: None,
            wait: 0,
        }
    }
}

impl Callback for ReduceLrOnPlateauCallback {
    fn on_epoch_end(&mut self, _epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if let Some(val_loss) = state.val_loss {
            let improved = self
                .best_val_loss
                .map(|best| val_loss < best - self.min_delta)
                .unwrap_or(true);

            if improved {
                self.best_val_loss = Some(val_loss);
                self.wait = 0;
            } else {
                self.wait += 1;
                if self.wait >= self.patience {
                    let new_lr = (state.learning_rate * self.factor).max(self.min_lr);
                    // The callback only receives `&TrainingState`, so it can
                    // log the reduced rate but cannot apply it here; actually
                    // updating the rate is left to the surrounding trainer.
                    if new_lr != state.learning_rate {
                        println!("Reducing learning rate to {:.6}", new_lr);
                    }
                    self.wait = 0;
                }
            }
        }

        Ok(())
    }
}
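
// Design note (a sketch of one option, not part of the current API): to make
// the reduction observable, the callback could store the suggested rate in a
// field such as `suggested_lr: Option<f64>`, set it alongside the `println!`
// above, and let the trainer poll it after each epoch, e.g.
//
//     if let Some(lr) = reduce_lr.suggested_lr.take() { // hypothetical field
//         optimizer.set_learning_rate(lr);              // hypothetical setter
//     }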

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_early_stopping() {
        let mut callback = EarlyStoppingCallback::new(2, 0.01);
        let mut state = create_test_state();

        // Epoch 0: the first observed loss always counts as an improvement.
        state.val_loss = Some(1.0);
        callback.on_epoch_end(0, &state).unwrap();
        assert!(!callback.should_stop());

        // Epoch 1: a clear improvement resets the wait counter.
        state.val_loss = Some(0.8);
        callback.on_epoch_end(1, &state).unwrap();
        assert!(!callback.should_stop());

        // Epoch 2: no improvement (wait = 1, still below patience).
        state.val_loss = Some(0.81);
        callback.on_epoch_end(2, &state).unwrap();
        assert!(!callback.should_stop());

        // Epoch 3: a second epoch without improvement reaches patience = 2.
        state.val_loss = Some(0.82);
        callback.on_epoch_end(3, &state).unwrap();
        assert!(callback.should_stop());
    }
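
    // Additional checks (a sketch, not in the original suite): improvements
    // smaller than `min_delta` should not reset the wait counter, and a
    // missing validation loss should leave the callback untouched.
    #[test]
    fn test_early_stopping_min_delta() {
        let mut callback = EarlyStoppingCallback::new(2, 0.01);
        let mut state = create_test_state();

        state.val_loss = Some(1.0);
        callback.on_epoch_end(0, &state).unwrap();

        // 0.995 is lower than 1.0, but not by more than min_delta = 0.01,
        // so it does not count as an improvement (wait = 1).
        state.val_loss = Some(0.995);
        callback.on_epoch_end(1, &state).unwrap();
        assert!(!callback.should_stop());

        // Still within min_delta of the best loss: wait = 2 hits patience.
        state.val_loss = Some(0.992);
        callback.on_epoch_end(2, &state).unwrap();
        assert!(callback.should_stop());
    }

    #[test]
    fn test_no_val_loss_is_a_noop() {
        let mut callback = EarlyStoppingCallback::new(1, 0.01);
        let mut state = create_test_state();

        state.val_loss = None;
        callback.on_epoch_end(0, &state).unwrap();
        callback.on_epoch_end(1, &state).unwrap();
        assert!(!callback.should_stop());
    }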
}