Skip to main content

tensorlogic_train/callbacks/
early_stopping.rs

1//! Early stopping and learning rate reduction callbacks.
2
3use crate::callbacks::core::Callback;
4use crate::{TrainResult, TrainingState};
5
/// Callback for early stopping based on validation loss.
///
/// Tracks the best validation loss observed so far and counts consecutive
/// epochs that fail to beat it by more than `min_delta`. Once that count
/// reaches `patience`, the callback flags that training should stop.
#[derive(Debug, Clone)]
pub struct EarlyStoppingCallback {
    /// Number of epochs with no improvement after which training will be stopped.
    pub patience: usize,
    /// Minimum change to qualify as an improvement.
    pub min_delta: f64,
    /// Best validation loss seen so far (`None` until one has been recorded).
    best_val_loss: Option<f64>,
    /// Counter for epochs without improvement.
    wait: usize,
    /// Whether to stop training.
    stop_training: bool,
}
19
20impl EarlyStoppingCallback {
21    /// Create a new early stopping callback.
22    pub fn new(patience: usize, min_delta: f64) -> Self {
23        Self {
24            patience,
25            min_delta,
26            best_val_loss: None,
27            wait: 0,
28            stop_training: false,
29        }
30    }
31}
32
33impl Callback for EarlyStoppingCallback {
34    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
35        if let Some(val_loss) = state.val_loss {
36            let improved = self
37                .best_val_loss
38                .map(|best| val_loss < best - self.min_delta)
39                .unwrap_or(true);
40
41            if improved {
42                self.best_val_loss = Some(val_loss);
43                self.wait = 0;
44            } else {
45                self.wait += 1;
46                if self.wait >= self.patience {
47                    println!(
48                        "Early stopping at epoch {} (no improvement for {} epochs)",
49                        epoch, self.patience
50                    );
51                    self.stop_training = true;
52                }
53            }
54        }
55
56        Ok(())
57    }
58
59    fn should_stop(&self) -> bool {
60        self.stop_training
61    }
62}
63
/// Callback for learning rate reduction on plateau.
///
/// Tracks the best validation loss observed so far and counts consecutive
/// epochs that fail to beat it by more than `min_delta`. Once that count
/// reaches `patience`, a reduced learning rate (`factor` times the current
/// one, bounded below by `min_lr`) is computed and the counter restarts.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file;
// presumably the type is not yet wired into the crate. Confirm and remove if unneeded.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ReduceLrOnPlateauCallback {
    /// Factor by which to reduce learning rate.
    pub factor: f64,
    /// Number of epochs with no improvement after which learning rate will be reduced.
    pub patience: usize,
    /// Minimum change to qualify as an improvement.
    pub min_delta: f64,
    /// Lower bound on the learning rate.
    pub min_lr: f64,
    /// Best validation loss seen so far (`None` until one has been recorded).
    best_val_loss: Option<f64>,
    /// Counter for epochs without improvement.
    wait: usize,
}
80
81impl ReduceLrOnPlateauCallback {
82    /// Create a new reduce LR on plateau callback.
83    #[allow(dead_code)]
84    pub fn new(factor: f64, patience: usize, min_delta: f64, min_lr: f64) -> Self {
85        Self {
86            factor,
87            patience,
88            min_delta,
89            min_lr,
90            best_val_loss: None,
91            wait: 0,
92        }
93    }
94}
95
96impl Callback for ReduceLrOnPlateauCallback {
97    fn on_epoch_end(&mut self, _epoch: usize, state: &TrainingState) -> TrainResult<()> {
98        if let Some(val_loss) = state.val_loss {
99            let improved = self
100                .best_val_loss
101                .map(|best| val_loss < best - self.min_delta)
102                .unwrap_or(true);
103
104            if improved {
105                self.best_val_loss = Some(val_loss);
106                self.wait = 0;
107            } else {
108                self.wait += 1;
109                if self.wait >= self.patience {
110                    // Note: We can't actually modify the optimizer here since we don't have a reference
111                    // This would need to be handled by the Trainer
112                    let new_lr = (state.learning_rate * self.factor).max(self.min_lr);
113                    if new_lr != state.learning_rate {
114                        println!("Reducing learning rate to {:.6}", new_lr);
115                    }
116                    self.wait = 0;
117                }
118            }
119        }
120
121        Ok(())
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a baseline training state that individual tests adjust as needed.
    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_early_stopping() {
        let mut stopper = EarlyStoppingCallback::new(2, 0.01);
        let mut snapshot = create_test_state();

        // Each entry is (validation loss, whether a stop is expected afterwards).
        let schedule = [
            (1.0, false),  // first epoch - improvement
            (0.8, false),  // second epoch - improvement
            (0.81, false), // third epoch - no improvement
            (0.82, true),  // fourth epoch - no improvement, patience exhausted
        ];

        for (epoch, &(loss, expect_stop)) in schedule.iter().enumerate() {
            snapshot.val_loss = Some(loss);
            stopper.on_epoch_end(epoch, &snapshot).unwrap();
            assert_eq!(stopper.should_stop(), expect_stop);
        }
    }
}