// tensorlogic_train/callbacks/advanced.rs
use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;

/// Maintains an exponential moving average (EMA) of model parameters.
///
/// Shadow copies of the parameters are blended toward the live values on
/// every [`ModelEMACallback::update`] call; [`ModelEMACallback::apply_shadow`]
/// writes the averaged values back into a parameter map.
pub struct ModelEMACallback {
    /// Base EMA decay factor; larger values retain more history.
    decay: f64,
    /// EMA ("shadow") copy of each named parameter tensor.
    shadow_params: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    /// When true, the effective decay ramps up over the first updates.
    use_warmup: bool,
    /// Number of `update` calls so far; drives the warmup schedule.
    num_updates: usize,
    /// Set by `initialize`; `update` returns an error while this is false.
    initialized: bool,
}
29
30impl ModelEMACallback {
31 pub fn new(decay: f64, use_warmup: bool) -> Self {
37 Self {
38 decay,
39 shadow_params: HashMap::new(),
40 use_warmup,
41 num_updates: 0,
42 initialized: false,
43 }
44 }
45
46 pub fn initialize(
48 &mut self,
49 parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
50 ) {
51 self.shadow_params.clear();
52 for (name, param) in parameters {
53 self.shadow_params.insert(name.clone(), param.clone());
54 }
55 self.initialized = true;
56 }
57
58 pub fn update(
60 &mut self,
61 parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
62 ) -> TrainResult<()> {
63 if !self.initialized {
64 return Err(TrainError::CallbackError(
65 "ModelEMA not initialized. Call initialize() first.".to_string(),
66 ));
67 }
68
69 self.num_updates += 1;
70
71 let decay = if self.use_warmup {
73 let warmup_decay = (1.0 + self.num_updates as f64) / (10.0 + self.num_updates as f64);
76 warmup_decay.min(self.decay)
77 } else {
78 self.decay
79 };
80
81 for (name, param) in parameters {
83 if let Some(shadow) = self.shadow_params.get_mut(name) {
84 *shadow = &*shadow * decay + &(param * (1.0 - decay));
86 }
87 }
88
89 Ok(())
90 }
91
92 pub fn get_shadow_params(
94 &self,
95 ) -> &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
96 &self.shadow_params
97 }
98
99 pub fn apply_shadow(
101 &self,
102 parameters: &mut HashMap<
103 String,
104 scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>,
105 >,
106 ) {
107 for (name, shadow) in &self.shadow_params {
108 if let Some(param) = parameters.get_mut(name) {
109 *param = shadow.clone();
110 }
111 }
112 }
113}
114
impl Callback for ModelEMACallback {
    /// Intentional no-op: the EMA is seeded externally via `initialize`,
    /// not through this hook.
    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Intentional no-op: the caller is expected to invoke `update` with the
    /// live parameters after each batch, since this hook receives no
    /// parameter map.
    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }
}
126
/// Stochastic Weight Averaging (SWA): keeps a running average of model
/// parameters collected from `start_epoch` onward.
pub struct SWACallback {
    /// Epoch index at which averaging becomes active.
    start_epoch: usize,
    /// Number of epochs between averaging opportunities after activation.
    update_frequency: usize,
    /// Running average of each named parameter tensor.
    swa_params: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    /// Number of parameter snapshots folded into the average so far.
    num_averaged: usize,
    /// True once `start_epoch` has been reached.
    active: bool,
    /// True once the first snapshot has seeded `swa_params`.
    initialized: bool,
    /// Emit progress messages to stdout when true.
    verbose: bool,
}
149
150impl SWACallback {
151 pub fn new(start_epoch: usize, update_frequency: usize, verbose: bool) -> Self {
158 Self {
159 start_epoch,
160 update_frequency,
161 swa_params: HashMap::new(),
162 num_averaged: 0,
163 active: false,
164 initialized: false,
165 verbose,
166 }
167 }
168
169 pub fn update_average(
171 &mut self,
172 parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
173 ) -> TrainResult<()> {
174 if !self.active {
175 return Ok(());
176 }
177
178 if !self.initialized {
179 for (name, param) in parameters {
181 self.swa_params.insert(name.clone(), param.clone());
182 }
183 self.initialized = true;
184 self.num_averaged = 1;
185
186 if self.verbose {
187 println!("SWA: Initialized with model parameters");
188 }
189 } else {
190 let n = self.num_averaged as f64;
192 for (name, param) in parameters {
193 if let Some(swa_param) = self.swa_params.get_mut(name) {
194 *swa_param = &(&*swa_param * n + param) / (n + 1.0);
195 }
196 }
197 self.num_averaged += 1;
198
199 if self.verbose {
200 println!("SWA: Updated average (n={})", self.num_averaged);
201 }
202 }
203
204 Ok(())
205 }
206
207 pub fn get_swa_params(
209 &self,
210 ) -> &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
211 &self.swa_params
212 }
213
214 pub fn apply_swa(
216 &self,
217 parameters: &mut HashMap<
218 String,
219 scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>,
220 >,
221 ) {
222 if self.initialized {
223 for (name, swa_param) in &self.swa_params {
224 if let Some(param) = parameters.get_mut(name) {
225 *param = swa_param.clone();
226 }
227 }
228 }
229 }
230
231 pub fn is_ready(&self) -> bool {
233 self.initialized && self.num_averaged > 0
234 }
235}
236
237impl Callback for SWACallback {
238 fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
239 if epoch >= self.start_epoch && !self.active {
241 self.active = true;
242 if self.verbose {
243 println!("\nSWA: Activated at epoch {}", epoch + 1);
244 }
245 }
246
247 if self.active && epoch >= self.start_epoch {
249 let relative_epoch = epoch - self.start_epoch;
250 if relative_epoch.is_multiple_of(self.update_frequency) {
251 if self.verbose && self.initialized {
253 println!(
254 "SWA: Ready to update at epoch {} (call update_average with parameters)",
255 epoch + 1
256 );
257 }
258 }
259 }
260
261 Ok(())
262 }
263
264 fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
265 if self.verbose && self.initialized {
266 println!(
267 "\nSWA: Training complete. Averaged {} models.",
268 self.num_averaged
269 );
270 println!("SWA: Call apply_swa() to use averaged parameters.");
271 }
272 Ok(())
273 }
274}