
tensorlogic_train/callbacks/advanced.rs

//! Advanced training callbacks: EMA (Exponential Moving Average) and SWA
//! (Stochastic Weight Averaging).

use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Named model parameters: 2-D `f64` arrays keyed by parameter name.
type ParamMap = HashMap<String, Array<f64, Ix2>>;

/// Model EMA (Exponential Moving Average) callback.
///
/// Maintains an exponential moving average of model parameters during training.
/// This often improves generalization and gives more stable predictions.
///
/// The shadow parameters are updated as
/// `shadow = decay * shadow + (1 - decay) * param`.
///
/// Reference: common practice in modern deep learning, popularized by Mean
/// Teacher and other semi-supervised learning methods.
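///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` so it is not run as a doctest; the
/// `"w"` parameter name and the 2x2 shape are illustrative only):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::{Array, Ix2};
///
/// let mut params: HashMap<String, Array<f64, Ix2>> = HashMap::new();
/// params.insert("w".to_string(), Array::zeros((2, 2)));
///
/// let mut ema = ModelEMACallback::new(0.999, true);
/// ema.initialize(&params);
///
/// // After each optimizer step, fold the new parameters into the average.
/// ema.update(&params).unwrap();
///
/// // For evaluation, copy the averaged weights into a parameter map.
/// let mut eval_params = params.clone();
/// ema.apply_shadow(&mut eval_params);
/// ```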
pub struct ModelEMACallback {
    /// Decay rate for EMA (typically 0.999 or 0.9999).
    decay: f64,
    /// Shadow parameters (EMA of model parameters).
    shadow_params: ParamMap,
    /// Whether to warm up the decay (start with a smaller effective decay).
    use_warmup: bool,
    /// Current update step (for warmup).
    num_updates: usize,
    /// Whether the callback is initialized.
    initialized: bool,
}

impl ModelEMACallback {
    /// Create a new Model EMA callback.
    ///
    /// # Arguments
    /// * `decay` - EMA decay rate (e.g., 0.999 or 0.9999)
    /// * `use_warmup` - Whether to use decay warmup (recommended)
    pub fn new(decay: f64, use_warmup: bool) -> Self {
        Self {
            decay,
            shadow_params: HashMap::new(),
            use_warmup,
            num_updates: 0,
            initialized: false,
        }
    }

    /// Initialize the shadow parameters from the current model parameters.
    pub fn initialize(&mut self, parameters: &ParamMap) {
        self.shadow_params = parameters.clone();
        self.initialized = true;
    }

    /// Update the EMA parameters from the current model parameters.
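    ///
    /// With warmup enabled, the effective decay is
    /// `min(decay, (1 + n) / (10 + n))` for update step `n`, so early updates
    /// track the raw parameters closely (n = 1 gives ~0.18) and the schedule
    /// approaches `decay` as training progresses (n = 1000 gives ~0.99).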
    pub fn update(&mut self, parameters: &ParamMap) -> TrainResult<()> {
        if !self.initialized {
            return Err(TrainError::CallbackError(
                "ModelEMA not initialized. Call initialize() first.".to_string(),
            ));
        }

        self.num_updates += 1;

        // Compute the effective decay. With warmup, ramp up via
        // (1 + n) / (10 + n), capped at the configured decay.
        let decay = if self.use_warmup {
            let warmup_decay = (1.0 + self.num_updates as f64) / (10.0 + self.num_updates as f64);
            warmup_decay.min(self.decay)
        } else {
            self.decay
        };

        // In-place EMA update: shadow = decay * shadow + (1 - decay) * param.
        for (name, param) in parameters {
            if let Some(shadow) = self.shadow_params.get_mut(name) {
                shadow.zip_mut_with(param, |s, &p| *s = decay * *s + (1.0 - decay) * p);
            }
        }

        Ok(())
    }

    /// Get the EMA (shadow) parameters.
    pub fn get_shadow_params(&self) -> &ParamMap {
        &self.shadow_params
    }

    /// Apply the EMA parameters to the model (e.g., for evaluation).
    pub fn apply_shadow(&self, parameters: &mut ParamMap) {
        for (name, shadow) in &self.shadow_params {
            if let Some(param) = parameters.get_mut(name) {
                *param = shadow.clone();
            }
        }
    }
}

impl Callback for ModelEMACallback {
    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
        // Initialization must be driven externally via `initialize()`, since
        // the callback has no access to model parameters here.
        Ok(())
    }

    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        // The EMA update must be driven externally via `update()`, since the
        // callback has no access to model parameters here.
        Ok(())
    }
}

/// SWA (Stochastic Weight Averaging) callback.
///
/// Averages model parameters over the course of training, typically starting
/// from a later epoch. This often improves generalization by converging to
/// wider optima.
///
/// Reference: Izmailov et al., "Averaging Weights Leads to Wider Optima and
/// Better Generalization" (UAI 2018).
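///
/// # Examples
///
/// A minimal usage sketch (marked `ignore`; the epoch loop, `state`, and
/// `params` are stand-ins for the trainer's own state):
///
/// ```ignore
/// let mut swa = SWACallback::new(75, 1, false);
///
/// for epoch in 0..100 {
///     // ... train one epoch ...
///     swa.on_epoch_end(epoch, &state).unwrap();
///     // The trainer passes parameters in explicitly; this is a no-op
///     // until epoch 75, when SWA activates.
///     swa.update_average(&params).unwrap();
/// }
///
/// if swa.is_ready() {
///     swa.apply_swa(&mut params);
/// }
/// ```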
pub struct SWACallback {
    /// Epoch at which to start SWA (e.g., 75% of the way through training).
    start_epoch: usize,
    /// Frequency of parameter averaging (every N epochs).
    update_frequency: usize,
    /// Running average of parameters.
    swa_params: ParamMap,
    /// Number of models averaged so far.
    num_averaged: usize,
    /// Whether SWA is active.
    active: bool,
    /// Whether the SWA parameters are initialized.
    initialized: bool,
    /// Verbose output.
    verbose: bool,
}

impl SWACallback {
    /// Create a new SWA callback.
    ///
    /// # Arguments
    /// * `start_epoch` - Epoch at which to start averaging (e.g., `0.75 * total_epochs`)
    /// * `update_frequency` - Average parameters every N epochs (typically 1)
    /// * `verbose` - Whether to print progress
    pub fn new(start_epoch: usize, update_frequency: usize, verbose: bool) -> Self {
        Self {
            start_epoch,
            update_frequency,
            swa_params: HashMap::new(),
            num_averaged: 0,
            active: false,
            initialized: false,
            verbose,
        }
    }

    /// Update the SWA running average with the current model parameters.
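    ///
    /// Maintains the running mean `swa = (swa * n + param) / (n + 1)`, which
    /// after `k` updates equals the plain mean of the `k` snapshots seen so
    /// far (e.g., snapshots 1.0, 2.0, 3.0 average to 2.0).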
    pub fn update_average(&mut self, parameters: &ParamMap) -> TrainResult<()> {
        if !self.active {
            return Ok(());
        }

        if !self.initialized {
            // Initialize the average with the first model snapshot.
            self.swa_params = parameters.clone();
            self.initialized = true;
            self.num_averaged = 1;

            if self.verbose {
                println!("SWA: Initialized with model parameters");
            }
        } else {
            // Running average: swa = (swa * n + param) / (n + 1).
            let n = self.num_averaged as f64;
            for (name, param) in parameters {
                if let Some(swa_param) = self.swa_params.get_mut(name) {
                    swa_param.zip_mut_with(param, |s, &p| *s = (*s * n + p) / (n + 1.0));
                }
            }
            self.num_averaged += 1;

            if self.verbose {
                println!("SWA: Updated average (n={})", self.num_averaged);
            }
        }

        Ok(())
    }

    /// Get the SWA parameters.
    pub fn get_swa_params(&self) -> &ParamMap {
        &self.swa_params
    }

    /// Apply the SWA parameters to the model.
    pub fn apply_swa(&self, parameters: &mut ParamMap) {
        if self.initialized {
            for (name, swa_param) in &self.swa_params {
                if let Some(param) = parameters.get_mut(name) {
                    *param = swa_param.clone();
                }
            }
        }
    }

    /// Check whether SWA has collected at least one average.
    pub fn is_ready(&self) -> bool {
        self.initialized && self.num_averaged > 0
    }
}

impl Callback for SWACallback {
    fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        // Activate SWA once start_epoch is reached.
        if epoch >= self.start_epoch && !self.active {
            self.active = true;
            if self.verbose {
                println!("\nSWA: Activated at epoch {}", epoch + 1);
            }
        }

        // Check whether this epoch is an averaging epoch. The actual update
        // must be driven externally via `update_average()`, since the
        // callback has no access to model parameters here.
        if self.active {
            let relative_epoch = epoch - self.start_epoch;
            if relative_epoch.is_multiple_of(self.update_frequency)
                && self.verbose
                && self.initialized
            {
                println!(
                    "SWA: Ready to update at epoch {} (call update_average with parameters)",
                    epoch + 1
                );
            }
        }

        Ok(())
    }

    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        if self.verbose && self.initialized {
            println!(
                "\nSWA: Training complete. Averaged {} models.",
                self.num_averaged
            );
            println!("SWA: Call apply_swa() to use averaged parameters.");
        }
        Ok(())
    }
}