Skip to main content

tensorlogic_train/callbacks/
core.rs

//! Core callback infrastructure for training.
2
3use crate::{TrainResult, TrainingState};
4
5/// Trait for training callbacks.
6pub trait Callback {
7    /// Called at the beginning of training.
8    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
9        Ok(())
10    }
11
12    /// Called at the end of training.
13    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
14        Ok(())
15    }
16
17    /// Called at the beginning of an epoch.
18    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
19        Ok(())
20    }
21
22    /// Called at the end of an epoch.
23    fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
24        Ok(())
25    }
26
27    /// Called at the beginning of a batch.
28    fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
29        Ok(())
30    }
31
32    /// Called at the end of a batch.
33    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
34        Ok(())
35    }
36
37    /// Called after validation.
38    fn on_validation_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
39        Ok(())
40    }
41
42    /// Check if training should stop early.
43    fn should_stop(&self) -> bool {
44        false
45    }
46}
47
48/// List of callbacks to execute in order.
49pub struct CallbackList {
50    callbacks: Vec<Box<dyn Callback>>,
51}
52
53impl CallbackList {
54    /// Create a new callback list.
55    pub fn new() -> Self {
56        Self {
57            callbacks: Vec::new(),
58        }
59    }
60
61    /// Add a callback to the list.
62    pub fn add(&mut self, callback: Box<dyn Callback>) {
63        self.callbacks.push(callback);
64    }
65
66    /// Execute on_train_begin for all callbacks.
67    pub fn on_train_begin(&mut self, state: &TrainingState) -> TrainResult<()> {
68        for callback in &mut self.callbacks {
69            callback.on_train_begin(state)?;
70        }
71        Ok(())
72    }
73
74    /// Execute on_train_end for all callbacks.
75    pub fn on_train_end(&mut self, state: &TrainingState) -> TrainResult<()> {
76        for callback in &mut self.callbacks {
77            callback.on_train_end(state)?;
78        }
79        Ok(())
80    }
81
82    /// Execute on_epoch_begin for all callbacks.
83    pub fn on_epoch_begin(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
84        for callback in &mut self.callbacks {
85            callback.on_epoch_begin(epoch, state)?;
86        }
87        Ok(())
88    }
89
90    /// Execute on_epoch_end for all callbacks.
91    pub fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
92        for callback in &mut self.callbacks {
93            callback.on_epoch_end(epoch, state)?;
94        }
95        Ok(())
96    }
97
98    /// Execute on_batch_begin for all callbacks.
99    pub fn on_batch_begin(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
100        for callback in &mut self.callbacks {
101            callback.on_batch_begin(batch, state)?;
102        }
103        Ok(())
104    }
105
106    /// Execute on_batch_end for all callbacks.
107    pub fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
108        for callback in &mut self.callbacks {
109            callback.on_batch_end(batch, state)?;
110        }
111        Ok(())
112    }
113
114    /// Execute on_validation_end for all callbacks.
115    pub fn on_validation_end(&mut self, state: &TrainingState) -> TrainResult<()> {
116        for callback in &mut self.callbacks {
117            callback.on_validation_end(state)?;
118        }
119        Ok(())
120    }
121
122    /// Check if any callback requests early stopping.
123    pub fn should_stop(&self) -> bool {
124        self.callbacks.iter().any(|cb| cb.should_stop())
125    }
126}
127
128impl Default for CallbackList {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134/// Callback that logs training progress.
135pub struct EpochCallback {
136    /// Whether to print detailed information.
137    pub verbose: bool,
138}
139
140impl EpochCallback {
141    /// Create a new epoch callback.
142    pub fn new(verbose: bool) -> Self {
143        Self { verbose }
144    }
145}
146
147impl Callback for EpochCallback {
148    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
149        if self.verbose {
150            println!(
151                "Epoch {}: loss={:.6}, val_loss={:.6}",
152                epoch,
153                state.train_loss,
154                state.val_loss.unwrap_or(f64::NAN)
155            );
156        }
157        Ok(())
158    }
159}
160
161/// Callback that logs batch progress.
162pub struct BatchCallback {
163    /// Frequency of logging (every N batches).
164    pub log_frequency: usize,
165}
166
167impl BatchCallback {
168    /// Create a new batch callback.
169    pub fn new(log_frequency: usize) -> Self {
170        Self { log_frequency }
171    }
172}
173
174impl Callback for BatchCallback {
175    fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
176        if batch.is_multiple_of(self.log_frequency) {
177            println!("Batch {}: loss={:.6}", batch, state.batch_loss);
178        }
179        Ok(())
180    }
181}
182
183/// Callback for validation during training.
184pub struct ValidationCallback {
185    /// Frequency of validation (every N epochs).
186    pub validation_frequency: usize,
187}
188
189impl ValidationCallback {
190    /// Create a new validation callback.
191    pub fn new(validation_frequency: usize) -> Self {
192        Self {
193            validation_frequency,
194        }
195    }
196}
197
198impl Callback for ValidationCallback {
199    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
200        if epoch.is_multiple_of(self.validation_frequency) {
201            if let Some(val_loss) = state.val_loss {
202                println!("Validation at epoch {}: val_loss={:.6}", epoch, val_loss);
203            }
204        }
205        Ok(())
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::collections::HashMap;
    use std::rc::Rc;

    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    /// Test double that records each hook it receives and can request a stop.
    struct Recorder {
        name: &'static str,
        log: Rc<RefCell<Vec<String>>>,
        stop: bool,
    }

    impl Recorder {
        /// Boxes a recorder that appends `"<name>:<event>"` entries to `log`.
        fn boxed(
            name: &'static str,
            log: &Rc<RefCell<Vec<String>>>,
            stop: bool,
        ) -> Box<dyn Callback> {
            Box::new(Self {
                name,
                log: Rc::clone(log),
                stop,
            })
        }

        fn record(&self, event: String) {
            self.log.borrow_mut().push(format!("{}:{}", self.name, event));
        }
    }

    impl Callback for Recorder {
        fn on_epoch_begin(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
            self.record(format!("epoch_begin({epoch})"));
            Ok(())
        }

        fn on_batch_end(&mut self, batch: usize, _state: &TrainingState) -> TrainResult<()> {
            self.record(format!("batch_end({batch})"));
            Ok(())
        }

        fn on_validation_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
            self.record("validation_end".to_string());
            Ok(())
        }

        fn should_stop(&self) -> bool {
            self.stop
        }
    }

    #[test]
    fn test_callback_list() {
        let mut callbacks = CallbackList::new();
        callbacks.add(Box::new(EpochCallback::new(false)));

        let state = create_test_state();
        callbacks.on_train_begin(&state).unwrap();
        callbacks.on_epoch_begin(0, &state).unwrap();
        callbacks.on_epoch_end(0, &state).unwrap();
        callbacks.on_train_end(&state).unwrap();
    }

    #[test]
    fn test_callbacks_run_in_insertion_order() {
        let log = Rc::new(RefCell::new(Vec::new()));
        let mut callbacks = CallbackList::new();
        callbacks.add(Recorder::boxed("a", &log, false));
        callbacks.add(Recorder::boxed("b", &log, false));

        let state = create_test_state();
        callbacks.on_epoch_begin(2, &state).unwrap();
        callbacks.on_batch_end(5, &state).unwrap();
        callbacks.on_validation_end(&state).unwrap();

        assert_eq!(
            *log.borrow(),
            vec![
                "a:epoch_begin(2)",
                "b:epoch_begin(2)",
                "a:batch_end(5)",
                "b:batch_end(5)",
                "a:validation_end",
                "b:validation_end",
            ]
        );
    }

    #[test]
    fn test_should_stop_if_any_callback_requests_it() {
        let log = Rc::new(RefCell::new(Vec::new()));
        let mut callbacks = CallbackList::default();
        assert!(!callbacks.should_stop());

        callbacks.add(Recorder::boxed("keep_going", &log, false));
        assert!(!callbacks.should_stop());

        callbacks.add(Recorder::boxed("halt", &log, true));
        assert!(callbacks.should_stop());
    }

    #[test]
    fn test_builtin_callbacks_handle_every_hook() {
        let mut callbacks = CallbackList::new();
        callbacks.add(Box::new(EpochCallback::new(false)));
        // Index 1 is off-cadence for a frequency of 2, so these stay silent.
        callbacks.add(Box::new(BatchCallback::new(2)));
        callbacks.add(Box::new(ValidationCallback::new(2)));

        let state = create_test_state();
        callbacks.on_train_begin(&state).unwrap();
        callbacks.on_epoch_begin(1, &state).unwrap();
        callbacks.on_batch_begin(1, &state).unwrap();
        callbacks.on_batch_end(1, &state).unwrap();
        callbacks.on_validation_end(&state).unwrap();
        callbacks.on_epoch_end(1, &state).unwrap();
        callbacks.on_train_end(&state).unwrap();
        assert!(!callbacks.should_stop());
    }
}