tensorlogic_train/callbacks/
core.rs1use crate::{TrainResult, TrainingState};
4
5pub trait Callback {
7 fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
9 Ok(())
10 }
11
12 fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
14 Ok(())
15 }
16
17 fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
19 Ok(())
20 }
21
22 fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
24 Ok(())
25 }
26
27 fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
29 Ok(())
30 }
31
32 fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
34 Ok(())
35 }
36
37 fn on_validation_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
39 Ok(())
40 }
41
42 fn should_stop(&self) -> bool {
44 false
45 }
46}
47
48pub struct CallbackList {
50 callbacks: Vec<Box<dyn Callback>>,
51}
52
53impl CallbackList {
54 pub fn new() -> Self {
56 Self {
57 callbacks: Vec::new(),
58 }
59 }
60
61 pub fn add(&mut self, callback: Box<dyn Callback>) {
63 self.callbacks.push(callback);
64 }
65
66 pub fn on_train_begin(&mut self, state: &TrainingState) -> TrainResult<()> {
68 for callback in &mut self.callbacks {
69 callback.on_train_begin(state)?;
70 }
71 Ok(())
72 }
73
74 pub fn on_train_end(&mut self, state: &TrainingState) -> TrainResult<()> {
76 for callback in &mut self.callbacks {
77 callback.on_train_end(state)?;
78 }
79 Ok(())
80 }
81
82 pub fn on_epoch_begin(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
84 for callback in &mut self.callbacks {
85 callback.on_epoch_begin(epoch, state)?;
86 }
87 Ok(())
88 }
89
90 pub fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
92 for callback in &mut self.callbacks {
93 callback.on_epoch_end(epoch, state)?;
94 }
95 Ok(())
96 }
97
98 pub fn on_batch_begin(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
100 for callback in &mut self.callbacks {
101 callback.on_batch_begin(batch, state)?;
102 }
103 Ok(())
104 }
105
106 pub fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
108 for callback in &mut self.callbacks {
109 callback.on_batch_end(batch, state)?;
110 }
111 Ok(())
112 }
113
114 pub fn on_validation_end(&mut self, state: &TrainingState) -> TrainResult<()> {
116 for callback in &mut self.callbacks {
117 callback.on_validation_end(state)?;
118 }
119 Ok(())
120 }
121
122 pub fn should_stop(&self) -> bool {
124 self.callbacks.iter().any(|cb| cb.should_stop())
125 }
126}
127
128impl Default for CallbackList {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
/// Callback that prints a loss summary at the end of an epoch.
///
/// The common traits are derived so the configuration can be inspected with
/// `{:?}`, duplicated, and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EpochCallback {
    /// When `false`, the callback prints nothing.
    pub verbose: bool,
}

impl EpochCallback {
    /// Creates an epoch logger; `verbose` controls whether it prints.
    pub fn new(verbose: bool) -> Self {
        Self { verbose }
    }
}
146
147impl Callback for EpochCallback {
148 fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
149 if self.verbose {
150 println!(
151 "Epoch {}: loss={:.6}, val_loss={:.6}",
152 epoch,
153 state.train_loss,
154 state.val_loss.unwrap_or(f64::NAN)
155 );
156 }
157 Ok(())
158 }
159}
160
/// Callback that prints the batch loss every `log_frequency` batches.
///
/// The common traits are derived so the configuration can be inspected with
/// `{:?}`, duplicated, and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchCallback {
    /// A batch is logged when its index is a multiple of this value.
    ///
    /// With `0`, only batch index `0` qualifies, because the check uses
    /// `usize::is_multiple_of`, which does not panic on a zero divisor.
    pub log_frequency: usize,
}

impl BatchCallback {
    /// Creates a batch logger that prints on every `log_frequency`-th batch.
    pub fn new(log_frequency: usize) -> Self {
        Self { log_frequency }
    }
}
173
174impl Callback for BatchCallback {
175 fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
176 if batch.is_multiple_of(self.log_frequency) {
177 println!("Batch {}: loss={:.6}", batch, state.batch_loss);
178 }
179 Ok(())
180 }
181}
182
/// Callback that prints the validation loss every `validation_frequency`
/// epochs.
///
/// It only reports a `val_loss` already present in the training state; it
/// does not run validation itself. The common traits are derived so the
/// configuration can be inspected with `{:?}`, duplicated, and compared in
/// tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationCallback {
    /// The validation loss is reported when the epoch index is a multiple of
    /// this value.
    ///
    /// With `0`, only epoch index `0` qualifies, because the check uses
    /// `usize::is_multiple_of`, which does not panic on a zero divisor.
    pub validation_frequency: usize,
}

impl ValidationCallback {
    /// Creates a reporter that prints on every `validation_frequency`-th epoch.
    pub fn new(validation_frequency: usize) -> Self {
        Self {
            validation_frequency,
        }
    }
}
197
198impl Callback for ValidationCallback {
199 fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
200 if epoch.is_multiple_of(self.validation_frequency) {
201 if let Some(val_loss) = state.val_loss {
202 println!("Validation at epoch {}: val_loss={:.6}", epoch, val_loss);
203 }
204 }
205 Ok(())
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a fixed state with both training and validation losses set.
    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    /// Callback whose only behavior is to request that training stop.
    struct AlwaysStop;

    impl Callback for AlwaysStop {
        fn should_stop(&self) -> bool {
            true
        }
    }

    /// Every forwarding method succeeds for a well-behaved callback, and the
    /// list does not ask to stop.
    #[test]
    fn test_callback_list() {
        let mut callbacks = CallbackList::new();
        callbacks.add(Box::new(EpochCallback::new(false)));

        let state = create_test_state();
        callbacks.on_train_begin(&state).unwrap();
        callbacks.on_epoch_begin(0, &state).unwrap();
        callbacks.on_batch_begin(0, &state).unwrap();
        callbacks.on_batch_end(0, &state).unwrap();
        callbacks.on_validation_end(&state).unwrap();
        callbacks.on_epoch_end(0, &state).unwrap();
        callbacks.on_train_end(&state).unwrap();
        assert!(!callbacks.should_stop());
    }

    /// `should_stop` is false for an empty list and becomes true once any
    /// member requests a stop.
    #[test]
    fn test_should_stop_aggregates_callbacks() {
        let mut callbacks = CallbackList::default();
        assert!(!callbacks.should_stop());

        callbacks.add(Box::new(EpochCallback::new(false)));
        assert!(!callbacks.should_stop());

        callbacks.add(Box::new(AlwaysStop));
        assert!(callbacks.should_stop());
    }

    /// A zero frequency must not panic in the frequency-based callbacks.
    #[test]
    fn test_zero_frequency_does_not_panic() {
        let state = create_test_state();
        let mut batch_logger = BatchCallback::new(0);
        let mut validation_logger = ValidationCallback::new(0);

        for step in 0..3 {
            batch_logger.on_batch_end(step, &state).unwrap();
            validation_logger.on_epoch_end(step, &state).unwrap();
        }
    }
}