tensorlogic_train/callbacks.rs

1//! Training callbacks for monitoring and controlling training.
2
3use crate::{TrainError, TrainResult, TrainingState};
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6
7/// Trait for training callbacks.
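///
/// Every hook has a default no-op implementation, so a custom callback only
/// overrides the events it cares about. A minimal sketch (the struct and its
/// fields are illustrative, not part of the crate):
///
/// ```rust,ignore
/// use tensorlogic_train::{Callback, TrainResult, TrainingState};
///
/// /// Stops training once the training loss drops below a threshold.
/// struct LossThresholdStop {
///     threshold: f64,
///     triggered: bool,
/// }
///
/// impl Callback for LossThresholdStop {
///     fn on_epoch_end(&mut self, _epoch: usize, state: &TrainingState) -> TrainResult<()> {
///         if state.train_loss < self.threshold {
///             self.triggered = true;
///         }
///         Ok(())
///     }
///
///     fn should_stop(&self) -> bool {
///         self.triggered
///     }
/// }
/// ```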
8pub trait Callback {
9    /// Called at the beginning of training.
10    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
11        Ok(())
12    }
13
14    /// Called at the end of training.
15    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
16        Ok(())
17    }
18
19    /// Called at the beginning of an epoch.
20    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
21        Ok(())
22    }
23
24    /// Called at the end of an epoch.
25    fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
26        Ok(())
27    }
28
29    /// Called at the beginning of a batch.
30    fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
31        Ok(())
32    }
33
34    /// Called at the end of a batch.
35    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
36        Ok(())
37    }
38
39    /// Called after validation.
40    fn on_validation_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
41        Ok(())
42    }
43
44    /// Check if training should stop early.
45    fn should_stop(&self) -> bool {
46        false
47    }
48}
49
50/// List of callbacks to execute in order.
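///
/// Callbacks run in insertion order. A usage sketch combining callbacks
/// defined in this module:
///
/// ```rust,ignore
/// use tensorlogic_train::{CallbackList, EpochCallback, EarlyStoppingCallback};
///
/// let mut callbacks = CallbackList::new();
/// callbacks.add(Box::new(EpochCallback::new(true)));
/// callbacks.add(Box::new(EarlyStoppingCallback::new(5, 1e-4)));
///
/// // Inside the training loop (normally driven by the trainer):
/// // callbacks.on_epoch_end(epoch, &state)?;
/// // if callbacks.should_stop() { /* break out of training */ }
/// ```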
51pub struct CallbackList {
52    callbacks: Vec<Box<dyn Callback>>,
53}
54
55impl CallbackList {
56    /// Create a new callback list.
57    pub fn new() -> Self {
58        Self {
59            callbacks: Vec::new(),
60        }
61    }
62
63    /// Add a callback to the list.
64    pub fn add(&mut self, callback: Box<dyn Callback>) {
65        self.callbacks.push(callback);
66    }
67
68    /// Execute on_train_begin for all callbacks.
69    pub fn on_train_begin(&mut self, state: &TrainingState) -> TrainResult<()> {
70        for callback in &mut self.callbacks {
71            callback.on_train_begin(state)?;
72        }
73        Ok(())
74    }
75
76    /// Execute on_train_end for all callbacks.
77    pub fn on_train_end(&mut self, state: &TrainingState) -> TrainResult<()> {
78        for callback in &mut self.callbacks {
79            callback.on_train_end(state)?;
80        }
81        Ok(())
82    }
83
84    /// Execute on_epoch_begin for all callbacks.
85    pub fn on_epoch_begin(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
86        for callback in &mut self.callbacks {
87            callback.on_epoch_begin(epoch, state)?;
88        }
89        Ok(())
90    }
91
92    /// Execute on_epoch_end for all callbacks.
93    pub fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
94        for callback in &mut self.callbacks {
95            callback.on_epoch_end(epoch, state)?;
96        }
97        Ok(())
98    }
99
100    /// Execute on_batch_begin for all callbacks.
101    pub fn on_batch_begin(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
102        for callback in &mut self.callbacks {
103            callback.on_batch_begin(batch, state)?;
104        }
105        Ok(())
106    }
107
108    /// Execute on_batch_end for all callbacks.
109    pub fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
110        for callback in &mut self.callbacks {
111            callback.on_batch_end(batch, state)?;
112        }
113        Ok(())
114    }
115
116    /// Execute on_validation_end for all callbacks.
117    pub fn on_validation_end(&mut self, state: &TrainingState) -> TrainResult<()> {
118        for callback in &mut self.callbacks {
119            callback.on_validation_end(state)?;
120        }
121        Ok(())
122    }
123
124    /// Check if any callback requests early stopping.
125    pub fn should_stop(&self) -> bool {
126        self.callbacks.iter().any(|cb| cb.should_stop())
127    }
128}
129
130impl Default for CallbackList {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136/// Callback that logs training progress.
137pub struct EpochCallback {
138    /// Whether to print detailed information.
139    pub verbose: bool,
140}
141
142impl EpochCallback {
143    /// Create a new epoch callback.
144    pub fn new(verbose: bool) -> Self {
145        Self { verbose }
146    }
147}
148
149impl Callback for EpochCallback {
150    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
151        if self.verbose {
152            println!(
153                "Epoch {}: loss={:.6}, val_loss={:.6}",
154                epoch,
155                state.train_loss,
156                state.val_loss.unwrap_or(f64::NAN)
157            );
158        }
159        Ok(())
160    }
161}
162
163/// Callback that logs batch progress.
164pub struct BatchCallback {
165    /// Frequency of logging (every N batches).
166    pub log_frequency: usize,
167}
168
169impl BatchCallback {
170    /// Create a new batch callback.
171    pub fn new(log_frequency: usize) -> Self {
172        Self { log_frequency }
173    }
174}
175
176impl Callback for BatchCallback {
177    fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
178        if batch.is_multiple_of(self.log_frequency) {
179            println!("Batch {}: loss={:.6}", batch, state.batch_loss);
180        }
181        Ok(())
182    }
183}
184
185/// Callback for validation during training.
186pub struct ValidationCallback {
187    /// Frequency of validation (every N epochs).
188    pub validation_frequency: usize,
189}
190
191impl ValidationCallback {
192    /// Create a new validation callback.
193    pub fn new(validation_frequency: usize) -> Self {
194        Self {
195            validation_frequency,
196        }
197    }
198}
199
200impl Callback for ValidationCallback {
201    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
202        if epoch.is_multiple_of(self.validation_frequency) {
203            if let Some(val_loss) = state.val_loss {
204                println!("Validation at epoch {}: val_loss={:.6}", epoch, val_loss);
205            }
206        }
207        Ok(())
208    }
209}
210
211/// Comprehensive checkpoint data structure.
212///
213/// This structure contains all the information needed to fully restore
214/// training state, including model parameters, optimizer state, and training history.
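///
/// A sketch of the save/load round-trip (the file path is illustrative; the
/// checkpoint itself would be built with [`TrainingCheckpoint::new`] during training):
///
/// ```rust,ignore
/// use std::path::PathBuf;
/// use tensorlogic_train::TrainingCheckpoint;
///
/// let path = PathBuf::from("checkpoints/epoch_010.json");
/// checkpoint.save(&path)?;
///
/// // Later, to resume training from the same file:
/// let restored = TrainingCheckpoint::load(&path)?;
/// assert_eq!(restored.epoch, checkpoint.epoch);
/// ```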
215#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
216pub struct TrainingCheckpoint {
217    /// Current epoch number.
218    pub epoch: usize,
219    /// Model parameters as flattened vectors.
220    pub parameters: HashMap<String, Vec<f64>>,
221    /// Optimizer state as flattened vectors.
222    pub optimizer_state: HashMap<String, Vec<f64>>,
223    /// Scheduler state (if present).
224    pub scheduler_state: Option<HashMap<String, f64>>,
225    /// Current training loss.
226    pub train_loss: f64,
227    /// Current validation loss (if available).
228    pub val_loss: Option<f64>,
229    /// Training loss history.
230    pub train_loss_history: Vec<f64>,
231    /// Validation loss history.
232    pub val_loss_history: Vec<f64>,
233    /// Metrics history.
234    pub metrics_history: HashMap<String, Vec<f64>>,
235    /// Current learning rate.
236    pub learning_rate: f64,
237    /// Best validation loss seen so far.
238    pub best_val_loss: Option<f64>,
239}
240
241impl TrainingCheckpoint {
242    /// Create a new checkpoint from current training state.
243    #[allow(clippy::too_many_arguments)]
244    pub fn new(
245        epoch: usize,
246        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
247        optimizer_state: &HashMap<String, Vec<f64>>,
248        scheduler_state: Option<HashMap<String, f64>>,
249        state: &TrainingState,
250        train_loss_history: &[f64],
251        val_loss_history: &[f64],
252        metrics_history: &HashMap<String, Vec<f64>>,
253        best_val_loss: Option<f64>,
254    ) -> Self {
255        // Convert parameters to flat vectors
256        let parameters = parameters
257            .iter()
258            .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
259            .collect();
260
261        Self {
262            epoch,
263            parameters,
264            optimizer_state: optimizer_state.clone(),
265            scheduler_state,
266            train_loss: state.train_loss,
267            val_loss: state.val_loss,
268            train_loss_history: train_loss_history.to_vec(),
269            val_loss_history: val_loss_history.to_vec(),
270            metrics_history: metrics_history.clone(),
271            learning_rate: state.learning_rate,
272            best_val_loss,
273        }
274    }
275
276    /// Save checkpoint to a file.
277    pub fn save(&self, path: &Path) -> TrainResult<()> {
278        let json = serde_json::to_string_pretty(self).map_err(|e| {
279            TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
280        })?;
281
282        if let Some(parent) = path.parent() {
283            std::fs::create_dir_all(parent).map_err(|e| {
284                TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
285            })?;
286        }
287
288        std::fs::write(path, json).map_err(|e| {
289            TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
290        })?;
291
292        Ok(())
293    }
294
295    /// Load checkpoint from a file.
296    pub fn load(path: &Path) -> TrainResult<Self> {
297        let json = std::fs::read_to_string(path).map_err(|e| {
298            TrainError::CheckpointError(format!("Failed to read checkpoint: {}", e))
299        })?;
300
301        let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
302            TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
303        })?;
304
305        Ok(checkpoint)
306    }
307}
308
309/// Callback for model checkpointing.
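///
/// A usage sketch (the checkpoint directory is illustrative):
///
/// ```rust,ignore
/// use std::path::PathBuf;
/// use tensorlogic_train::{CallbackList, CheckpointCallback};
///
/// let mut callbacks = CallbackList::new();
/// // Check every epoch, but only write a checkpoint when validation loss improves.
/// callbacks.add(Box::new(CheckpointCallback::new(
///     PathBuf::from("checkpoints"),
///     1,    // save_frequency
///     true, // save_best_only
/// )));
/// ```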
310pub struct CheckpointCallback {
311    /// Directory to save checkpoints.
312    pub checkpoint_dir: PathBuf,
313    /// Frequency of checkpointing (every N epochs).
314    pub save_frequency: usize,
315    /// Whether to save only the best model.
316    pub save_best_only: bool,
317    /// Best validation loss seen so far.
318    best_val_loss: Option<f64>,
319}
320
321impl CheckpointCallback {
322    /// Create a new checkpoint callback.
323    pub fn new(checkpoint_dir: PathBuf, save_frequency: usize, save_best_only: bool) -> Self {
324        Self {
325            checkpoint_dir,
326            save_frequency,
327            save_best_only,
328            best_val_loss: None,
329        }
330    }
331
332    /// Save checkpoint to disk (legacy simple format).
333    fn save_checkpoint(&self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
334        let checkpoint_path = self
335            .checkpoint_dir
336            .join(format!("checkpoint_epoch_{}.json", epoch));
337
338        // Create checkpoint data
339        let mut checkpoint = HashMap::new();
340        checkpoint.insert("epoch".to_string(), epoch as f64);
341        checkpoint.insert("train_loss".to_string(), state.train_loss);
342        if let Some(val_loss) = state.val_loss {
343            checkpoint.insert("val_loss".to_string(), val_loss);
344        }
345
346        // Save to JSON
347        let json = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
348            TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
349        })?;
350
351        std::fs::create_dir_all(&self.checkpoint_dir).map_err(|e| {
352            TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
353        })?;
354
355        std::fs::write(&checkpoint_path, json).map_err(|e| {
356            TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
357        })?;
358
359        println!("Checkpoint saved to {:?}", checkpoint_path);
360        Ok(())
361    }
362}
363
364impl Callback for CheckpointCallback {
365    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
366        if !epoch.is_multiple_of(self.save_frequency) {
367            return Ok(());
368        }
369
370        if self.save_best_only {
371            if let Some(val_loss) = state.val_loss {
372                let should_save = self
373                    .best_val_loss
374                    .map(|best| val_loss < best)
375                    .unwrap_or(true);
376
377                if should_save {
378                    self.best_val_loss = Some(val_loss);
379                    self.save_checkpoint(epoch, state)?;
380                }
381            }
382        } else {
383            self.save_checkpoint(epoch, state)?;
384        }
385
386        Ok(())
387    }
388}
389
390/// Callback for early stopping based on validation loss.
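///
/// A usage sketch: stop once the validation loss has not improved by at least
/// `min_delta` for `patience` consecutive epochs.
///
/// ```rust,ignore
/// use tensorlogic_train::{CallbackList, EarlyStoppingCallback};
///
/// let mut callbacks = CallbackList::new();
/// callbacks.add(Box::new(EarlyStoppingCallback::new(
///     5,    // patience: epochs without improvement before stopping
///     1e-4, // min_delta: minimum change counted as an improvement
/// )));
/// ```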
391pub struct EarlyStoppingCallback {
392    /// Number of epochs with no improvement after which training will be stopped.
393    pub patience: usize,
394    /// Minimum change to qualify as an improvement.
395    pub min_delta: f64,
396    /// Best validation loss seen so far.
397    best_val_loss: Option<f64>,
398    /// Counter for epochs without improvement.
399    wait: usize,
400    /// Whether to stop training.
401    stop_training: bool,
402}
403
404impl EarlyStoppingCallback {
405    /// Create a new early stopping callback.
406    pub fn new(patience: usize, min_delta: f64) -> Self {
407        Self {
408            patience,
409            min_delta,
410            best_val_loss: None,
411            wait: 0,
412            stop_training: false,
413        }
414    }
415}
416
417impl Callback for EarlyStoppingCallback {
418    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
419        if let Some(val_loss) = state.val_loss {
420            let improved = self
421                .best_val_loss
422                .map(|best| val_loss < best - self.min_delta)
423                .unwrap_or(true);
424
425            if improved {
426                self.best_val_loss = Some(val_loss);
427                self.wait = 0;
428            } else {
429                self.wait += 1;
430                if self.wait >= self.patience {
431                    println!(
432                        "Early stopping at epoch {} (no improvement for {} epochs)",
433                        epoch, self.patience
434                    );
435                    self.stop_training = true;
436                }
437            }
438        }
439
440        Ok(())
441    }
442
443    fn should_stop(&self) -> bool {
444        self.stop_training
445    }
446}
447
448/// Callback for learning rate reduction on plateau.
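///
/// Note that this callback only computes and logs the reduced learning rate; the
/// actual optimizer update has to be applied by the trainer. A usage sketch:
///
/// ```rust,ignore
/// use tensorlogic_train::{CallbackList, ReduceLrOnPlateauCallback};
///
/// let mut callbacks = CallbackList::new();
/// callbacks.add(Box::new(ReduceLrOnPlateauCallback::new(
///     0.5,  // factor: halve the learning rate
///     3,    // patience: epochs without improvement
///     1e-4, // min_delta
///     1e-6, // min_lr: lower bound on the learning rate
/// )));
/// ```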
449#[allow(dead_code)]
450pub struct ReduceLrOnPlateauCallback {
451    /// Factor by which to reduce learning rate.
452    pub factor: f64,
453    /// Number of epochs with no improvement after which learning rate will be reduced.
454    pub patience: usize,
455    /// Minimum change to qualify as an improvement.
456    pub min_delta: f64,
457    /// Lower bound on the learning rate.
458    pub min_lr: f64,
459    /// Best validation loss seen so far.
460    best_val_loss: Option<f64>,
461    /// Counter for epochs without improvement.
462    wait: usize,
463}
464
465impl ReduceLrOnPlateauCallback {
466    /// Create a new reduce LR on plateau callback.
467    #[allow(dead_code)]
468    pub fn new(factor: f64, patience: usize, min_delta: f64, min_lr: f64) -> Self {
469        Self {
470            factor,
471            patience,
472            min_delta,
473            min_lr,
474            best_val_loss: None,
475            wait: 0,
476        }
477    }
478}
479
480impl Callback for ReduceLrOnPlateauCallback {
481    fn on_epoch_end(&mut self, _epoch: usize, state: &TrainingState) -> TrainResult<()> {
482        if let Some(val_loss) = state.val_loss {
483            let improved = self
484                .best_val_loss
485                .map(|best| val_loss < best - self.min_delta)
486                .unwrap_or(true);
487
488            if improved {
489                self.best_val_loss = Some(val_loss);
490                self.wait = 0;
491            } else {
492                self.wait += 1;
493                if self.wait >= self.patience {
494                    // Note: We can't actually modify the optimizer here since we don't have a reference
495                    // This would need to be handled by the Trainer
496                    let new_lr = (state.learning_rate * self.factor).max(self.min_lr);
497                    if new_lr != state.learning_rate {
498                        println!("Reducing learning rate to {:.6}", new_lr);
499                    }
500                    self.wait = 0;
501                }
502            }
503        }
504
505        Ok(())
506    }
507}
508
509/// Learning rate finder callback using the LR range test.
510///
511/// This callback implements the learning rate range test proposed by Leslie N. Smith.
512/// It gradually increases the learning rate from a minimum to a maximum value over
513/// a specified number of iterations/epochs and tracks the loss at each step.
514///
515/// The optimal learning rate is typically found just before the loss starts to increase.
516///
517/// # Example
518/// ```rust,ignore
519/// use tensorlogic_train::{LearningRateFinder, CallbackList};
520///
521/// let mut callbacks = CallbackList::new();
522/// callbacks.add(Box::new(LearningRateFinder::new(
523///     1e-7,   // start_lr
524///     10.0,   // end_lr
525///     100,    // num_steps
526/// )));
527/// ```
528pub struct LearningRateFinder {
529    /// Starting learning rate.
530    start_lr: f64,
531    /// Ending learning rate.
532    end_lr: f64,
533    /// Number of steps to test.
534    num_steps: usize,
535    /// Current step.
536    current_step: usize,
537    /// History of (lr, loss) pairs.
538    pub history: Vec<(f64, f64)>,
539    /// Whether to use exponential or linear scaling.
540    exponential: bool,
541    /// Smoothing factor for loss (0.0 = no smoothing, 0.9 = heavy smoothing).
542    smoothing: f64,
543    /// Smoothed loss.
544    smoothed_loss: Option<f64>,
545}
546
547impl LearningRateFinder {
548    /// Create a new learning rate finder.
549    ///
550    /// # Arguments
551    /// * `start_lr` - Starting learning rate (e.g., 1e-7)
552    /// * `end_lr` - Ending learning rate (e.g., 10.0)
553    /// * `num_steps` - Number of steps to test
554    pub fn new(start_lr: f64, end_lr: f64, num_steps: usize) -> Self {
555        Self {
556            start_lr,
557            end_lr,
558            num_steps,
559            current_step: 0,
560            history: Vec::with_capacity(num_steps),
561            exponential: true, // Exponential scaling is recommended
562            smoothing: 0.0,    // No smoothing by default
563            smoothed_loss: None,
564        }
565    }
566
567    /// Enable exponential scaling (recommended, default).
568    pub fn with_exponential_scaling(mut self) -> Self {
569        self.exponential = true;
570        self
571    }
572
573    /// Enable linear scaling.
574    pub fn with_linear_scaling(mut self) -> Self {
575        self.exponential = false;
576        self
577    }
578
579    /// Set loss smoothing factor (0.0-1.0).
580    ///
581    /// Recommended: 0.9 for noisy losses, 0.0 for smooth losses.
582    pub fn with_smoothing(mut self, smoothing: f64) -> Self {
583        self.smoothing = smoothing.clamp(0.0, 1.0);
584        self
585    }
586
587    /// Compute the current learning rate based on step.
588    fn compute_lr(&self) -> f64 {
589        if self.num_steps <= 1 {
590            return self.start_lr;
591        }
592
593        let step_ratio = self.current_step as f64 / (self.num_steps - 1) as f64;
594
595        if self.exponential {
596            // Exponential scaling: lr = start_lr * (end_lr/start_lr)^step_ratio
597            self.start_lr * (self.end_lr / self.start_lr).powf(step_ratio)
598        } else {
599            // Linear scaling: lr = start_lr + (end_lr - start_lr) * step_ratio
600            self.start_lr + (self.end_lr - self.start_lr) * step_ratio
601        }
602    }
603
604    /// Get the smoothed loss.
605    fn smooth_loss(&mut self, loss: f64) -> f64 {
606        if self.smoothing == 0.0 {
607            return loss;
608        }
609
610        match self.smoothed_loss {
611            None => {
612                self.smoothed_loss = Some(loss);
613                loss
614            }
615            Some(prev) => {
616                let smoothed = self.smoothing * prev + (1.0 - self.smoothing) * loss;
617                self.smoothed_loss = Some(smoothed);
618                smoothed
619            }
620        }
621    }
622
623    /// Find the suggested optimal learning rate.
624    ///
625    /// Returns the learning rate with the steepest negative gradient (fastest decrease in loss).
626    pub fn suggest_lr(&self) -> Option<f64> {
627        if self.history.len() < 3 {
628            return None;
629        }
630
631        let mut best_lr = None;
632        let mut best_gradient = f64::INFINITY;
633
634        // Compute gradients and find steepest descent
635        for i in 1..self.history.len() {
636            let (lr1, loss1) = self.history[i - 1];
637            let (lr2, loss2) = self.history[i];
638
639            let gradient = (loss2 - loss1) / (lr2 - lr1);
640
641            if gradient < best_gradient {
642                best_gradient = gradient;
643                best_lr = Some(lr2);
644            }
645        }
646
647        best_lr
648    }
649
650    /// Print the LR finder results.
651    pub fn print_results(&self) {
652        println!("\n=== Learning Rate Finder Results ===");
653        println!(
654            "Tested {} learning rates from {:.2e} to {:.2e}",
655            self.history.len(),
656            self.start_lr,
657            self.end_lr
658        );
659
660        if let Some(suggested_lr) = self.suggest_lr() {
661            println!("Suggested optimal LR: {:.2e}", suggested_lr);
662            println!(
663                "Consider using LR between {:.2e} and {:.2e}",
664                suggested_lr / 10.0,
665                suggested_lr
666            );
667        }
668
669        println!("\nLR, Loss:");
670        for (lr, loss) in &self.history {
671            println!("{:.6e}, {:.6}", lr, loss);
672        }
673        println!("===================================\n");
674    }
675}
676
677impl Callback for LearningRateFinder {
678    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
679        if self.current_step >= self.num_steps {
680            return Ok(());
681        }
682
683        // Get current loss and smooth it
684        let loss = self.smooth_loss(state.batch_loss);
685
686        // Record (lr, loss) pair
687        let lr = self.compute_lr();
688        self.history.push((lr, loss));
689
690        self.current_step += 1;
691
692        // Note: The actual LR update happens via the trainer's optimizer
693        // This callback just tracks the relationship
694
695        Ok(())
696    }
697
698    fn should_stop(&self) -> bool {
699        // Stop after testing all LR values
700        self.current_step >= self.num_steps
701    }
702}
703
704/// Gradient flow monitor for tracking gradient statistics during training.
705///
706/// This callback tracks gradient norms, mean, std, and identifies vanishing/exploding gradients.
707/// Useful for debugging training issues and understanding gradient flow through the network.
708///
709/// # Example
710/// ```rust,ignore
711/// use tensorlogic_train::{GradientMonitor, CallbackList};
712///
713/// let mut callbacks = CallbackList::new();
714/// callbacks.add(Box::new(GradientMonitor::new(
715///     10,      // log_frequency
716///     1e-7,    // vanishing_threshold
717///     100.0,   // exploding_threshold
718/// )));
719/// ```
720pub struct GradientMonitor {
721    /// Frequency of logging (every N batches).
722    log_frequency: usize,
723    /// Threshold for detecting vanishing gradients.
724    vanishing_threshold: f64,
725    /// Threshold for detecting exploding gradients.
726    exploding_threshold: f64,
727    /// History of gradient norms.
728    pub gradient_norms: Vec<f64>,
729    /// History of gradient means.
730    pub gradient_means: Vec<f64>,
731    /// History of gradient stds.
732    pub gradient_stds: Vec<f64>,
733    /// Count of vanishing gradient warnings.
734    pub vanishing_count: usize,
735    /// Count of exploding gradient warnings.
736    pub exploding_count: usize,
737    /// Current batch counter.
738    batch_counter: usize,
739}
740
741impl GradientMonitor {
742    /// Create a new gradient monitor.
743    ///
744    /// # Arguments
745    /// * `log_frequency` - Log statistics every N batches
746    /// * `vanishing_threshold` - Threshold below which gradients are considered vanishing
747    /// * `exploding_threshold` - Threshold above which gradients are considered exploding
748    pub fn new(log_frequency: usize, vanishing_threshold: f64, exploding_threshold: f64) -> Self {
749        Self {
750            log_frequency,
751            vanishing_threshold,
752            exploding_threshold,
753            gradient_norms: Vec::new(),
754            gradient_means: Vec::new(),
755            gradient_stds: Vec::new(),
756            vanishing_count: 0,
757            exploding_count: 0,
758            batch_counter: 0,
759        }
760    }
761
762    /// Compute gradient statistics (placeholder - actual implementation needs gradient access).
763    fn compute_gradient_stats(&mut self, _state: &TrainingState) -> (f64, f64, f64) {
764        // In a real implementation, this would access actual gradients
765        // For now, return placeholder values
766        // (norm, mean, std)
767        (1.0, 0.0, 0.1)
768    }
769
770    /// Check for vanishing gradients.
771    fn check_vanishing(&mut self, norm: f64) -> bool {
772        if norm < self.vanishing_threshold {
773            self.vanishing_count += 1;
774            return true;
775        }
776        false
777    }
778
779    /// Check for exploding gradients.
780    fn check_exploding(&mut self, norm: f64) -> bool {
781        if norm > self.exploding_threshold {
782            self.exploding_count += 1;
783            return true;
784        }
785        false
786    }
787
788    /// Print gradient statistics.
789    fn print_stats(&self, norm: f64, mean: f64, std: f64) {
790        println!("Gradient Stats [Batch {}]:", self.batch_counter);
791        println!("  Norm: {:.6e}, Mean: {:.6e}, Std: {:.6e}", norm, mean, std);
792
793        if self.vanishing_count > 0 {
794            println!(
795                "  ⚠️  Vanishing gradient warnings: {}",
796                self.vanishing_count
797            );
798        }
799
800        if self.exploding_count > 0 {
801            println!(
802                "  ⚠️  Exploding gradient warnings: {}",
803                self.exploding_count
804            );
805        }
806    }
807
808    /// Get summary statistics.
809    pub fn summary(&self) -> GradientSummary {
810        let avg_norm = if !self.gradient_norms.is_empty() {
811            self.gradient_norms.iter().sum::<f64>() / self.gradient_norms.len() as f64
812        } else {
813            0.0
814        };
815
816        GradientSummary {
817            total_batches: self.batch_counter,
818            average_norm: avg_norm,
819            vanishing_count: self.vanishing_count,
820            exploding_count: self.exploding_count,
821        }
822    }
823}
824
825/// Summary of gradient statistics.
826#[derive(Debug, Clone)]
827pub struct GradientSummary {
828    /// Total number of batches monitored.
829    pub total_batches: usize,
830    /// Average gradient norm.
831    pub average_norm: f64,
832    /// Number of vanishing gradient warnings.
833    pub vanishing_count: usize,
834    /// Number of exploding gradient warnings.
835    pub exploding_count: usize,
836}
837
838impl Callback for GradientMonitor {
839    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
840        self.batch_counter += 1;
841
842        // Compute gradient statistics
843        let (norm, mean, std) = self.compute_gradient_stats(state);
844
845        // Record statistics
846        self.gradient_norms.push(norm);
847        self.gradient_means.push(mean);
848        self.gradient_stds.push(std);
849
850        // Check for issues
851        let vanishing = self.check_vanishing(norm);
852        let exploding = self.check_exploding(norm);
853
854        // Log if needed
855        if self.batch_counter.is_multiple_of(self.log_frequency) {
856            self.print_stats(norm, mean, std);
857        } else if vanishing || exploding {
858            // Always log warnings immediately
859            self.print_stats(norm, mean, std);
860        }
861
862        Ok(())
863    }
864
865    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
866        let summary = self.summary();
867        println!("\n=== Gradient Monitoring Summary ===");
868        println!("Total batches: {}", summary.total_batches);
869        println!("Average gradient norm: {:.6e}", summary.average_norm);
870        println!("Vanishing gradient warnings: {}", summary.vanishing_count);
871        println!("Exploding gradient warnings: {}", summary.exploding_count);
872        println!("====================================\n");
873        Ok(())
874    }
875}
876
877#[cfg(test)]
878mod tests {
879    use super::*;
880
881    fn create_test_state() -> TrainingState {
882        TrainingState {
883            epoch: 0,
884            batch: 0,
885            train_loss: 1.0,
886            val_loss: Some(0.8),
887            batch_loss: 0.5,
888            learning_rate: 0.001,
889            metrics: HashMap::new(),
890        }
891    }
892
893    #[test]
894    fn test_callback_list() {
895        let mut callbacks = CallbackList::new();
896        callbacks.add(Box::new(EpochCallback::new(false)));
897
898        let state = create_test_state();
899        callbacks.on_train_begin(&state).unwrap();
900        callbacks.on_epoch_begin(0, &state).unwrap();
901        callbacks.on_epoch_end(0, &state).unwrap();
902        callbacks.on_train_end(&state).unwrap();
903    }
904
905    #[test]
906    fn test_early_stopping() {
907        let mut callback = EarlyStoppingCallback::new(2, 0.01);
908        let mut state = create_test_state();
909
910        // First epoch - improvement
911        state.val_loss = Some(1.0);
912        callback.on_epoch_end(0, &state).unwrap();
913        assert!(!callback.should_stop());
914
915        // Second epoch - improvement
916        state.val_loss = Some(0.8);
917        callback.on_epoch_end(1, &state).unwrap();
918        assert!(!callback.should_stop());
919
920        // Third epoch - no improvement
921        state.val_loss = Some(0.81);
922        callback.on_epoch_end(2, &state).unwrap();
923        assert!(!callback.should_stop());
924
925        // Fourth epoch - no improvement (patience reached, training stops)
926        state.val_loss = Some(0.82);
927        callback.on_epoch_end(3, &state).unwrap();
928        assert!(callback.should_stop());
929    }
930
931    #[test]
932    fn test_checkpoint_callback() {
933        use std::env::temp_dir;
934
935        let checkpoint_dir = temp_dir().join("tensorlogic_test_checkpoints");
936        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);
937        let state = create_test_state();
938
939        callback.on_epoch_end(0, &state).unwrap();
940
941        // Verify checkpoint was created
942        let checkpoint_path = checkpoint_dir.join("checkpoint_epoch_0.json");
943        assert!(checkpoint_path.exists());
944
945        // Clean up
946        std::fs::remove_dir_all(checkpoint_dir).ok();
947    }
948
949    #[test]
950    fn test_training_checkpoint_save_load() {
951        use scirs2_core::ndarray::Array2;
952        use std::env::temp_dir;
953
954        // Create test parameters
955        let mut parameters = HashMap::new();
956        parameters.insert("weight".to_string(), Array2::from_elem((2, 3), 1.5));
957        parameters.insert("bias".to_string(), Array2::from_elem((1, 3), 0.5));
958
959        // Create test state
960        let state = TrainingState {
961            epoch: 5,
962            batch: 100,
963            train_loss: 0.75,
964            val_loss: Some(0.85),
965            batch_loss: 0.72,
966            learning_rate: 0.001,
967            metrics: HashMap::new(),
968        };
969
970        // Create optimizer state (mock)
971        let optimizer_state = {
972            let mut state = HashMap::new();
973            state.insert("momentum_weight".to_string(), vec![0.1, 0.2, 0.3]);
974            state.insert("momentum_bias".to_string(), vec![0.05]);
975            state
976        };
977
978        // Create checkpoint
979        let checkpoint = TrainingCheckpoint::new(
980            5,
981            &parameters,
982            &optimizer_state,
983            None,
984            &state,
985            &[1.0, 0.9, 0.8, 0.77, 0.75],
986            &[1.1, 0.95, 0.88, 0.87, 0.85],
987            &HashMap::new(),
988            Some(0.85),
989        );
990
991        // Save checkpoint
992        let checkpoint_path = temp_dir().join("test_training_checkpoint.json");
993        checkpoint.save(&checkpoint_path).unwrap();
994
995        // Verify file exists
996        assert!(checkpoint_path.exists());
997
998        // Load checkpoint
999        let loaded = TrainingCheckpoint::load(&checkpoint_path).unwrap();
1000
1001        // Verify data
1002        assert_eq!(loaded.epoch, 5);
1003        assert_eq!(loaded.train_loss, 0.75);
1004        assert_eq!(loaded.val_loss, Some(0.85));
1005        assert_eq!(loaded.learning_rate, 0.001);
1006        assert_eq!(loaded.train_loss_history.len(), 5);
1007        assert_eq!(loaded.val_loss_history.len(), 5);
1008        assert_eq!(loaded.best_val_loss, Some(0.85));
1009
1010        // Verify parameters
1011        assert_eq!(loaded.parameters.len(), 2);
1012        assert!(loaded.parameters.contains_key("weight"));
1013        assert!(loaded.parameters.contains_key("bias"));
1014
1015        // Verify optimizer state
1016        assert_eq!(loaded.optimizer_state.len(), 2);
1017        assert!(loaded.optimizer_state.contains_key("momentum_weight"));
1018
1019        // Clean up
1020        std::fs::remove_file(checkpoint_path).ok();
1021    }
1022
1023    #[test]
1024    fn test_training_checkpoint_with_metrics() {
1025        use scirs2_core::ndarray::Array2;
1026        use std::env::temp_dir;
1027
1028        let mut parameters = HashMap::new();
1029        parameters.insert("w".to_string(), Array2::zeros((2, 2)));
1030
1031        let state = create_test_state();
1032        let optimizer_state = HashMap::new();
1033
1034        // Add metrics history
1035        let mut metrics_history = HashMap::new();
1036        metrics_history.insert("accuracy".to_string(), vec![0.5, 0.6, 0.7]);
1037        metrics_history.insert("f1_score".to_string(), vec![0.45, 0.55, 0.65]);
1038
1039        let checkpoint = TrainingCheckpoint::new(
1040            2,
1041            &parameters,
1042            &optimizer_state,
1043            None,
1044            &state,
1045            &[1.0, 0.8, 0.6],
1046            &[1.1, 0.9, 0.7],
1047            &metrics_history,
1048            Some(0.7),
1049        );
1050
1051        let checkpoint_path = temp_dir().join("test_checkpoint_with_metrics.json");
1052        checkpoint.save(&checkpoint_path).unwrap();
1053
1054        let loaded = TrainingCheckpoint::load(&checkpoint_path).unwrap();
1055
1056        // Verify metrics
1057        assert_eq!(loaded.metrics_history.len(), 2);
1058        assert!(loaded.metrics_history.contains_key("accuracy"));
1059        assert!(loaded.metrics_history.contains_key("f1_score"));
1060        assert_eq!(loaded.metrics_history["accuracy"].len(), 3);
1061
1062        std::fs::remove_file(checkpoint_path).ok();
1063    }
1064}
1065
1066/// Weight histogram statistics for debugging and monitoring.
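///
/// A small sketch of computing and printing a histogram from a flat slice of
/// parameter values (assuming the type is re-exported at the crate root like
/// the callbacks above):
///
/// ```rust,ignore
/// use tensorlogic_train::HistogramStats;
///
/// let values = vec![0.12, -0.03, 0.27, -0.18, 0.05, 0.09];
/// let stats = HistogramStats::compute("layer1.weight", &values, 5);
/// stats.display(40); // 40-character-wide ASCII bars
/// ```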
1067#[derive(Debug, Clone)]
1068pub struct HistogramStats {
1069    /// Parameter name.
1070    pub name: String,
1071    /// Minimum value.
1072    pub min: f64,
1073    /// Maximum value.
1074    pub max: f64,
1075    /// Mean value.
1076    pub mean: f64,
1077    /// Standard deviation.
1078    pub std: f64,
1079    /// Histogram bins (boundaries).
1080    pub bins: Vec<f64>,
1081    /// Histogram counts per bin.
1082    pub counts: Vec<usize>,
1083}
1084
1085impl HistogramStats {
1086    /// Compute histogram statistics from parameter values.
1087    pub fn compute(name: &str, values: &[f64], num_bins: usize) -> Self {
1088        if values.is_empty() {
1089            return Self {
1090                name: name.to_string(),
1091                min: 0.0,
1092                max: 0.0,
1093                mean: 0.0,
1094                std: 0.0,
1095                bins: vec![],
1096                counts: vec![],
1097            };
1098        }
1099
1100        // Basic statistics
1101        let min = values.iter().copied().fold(f64::INFINITY, f64::min);
1102        let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
1103        let sum: f64 = values.iter().sum();
1104        let mean = sum / values.len() as f64;
1105
1106        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
1107        let std = variance.sqrt();
1108
1109        // Create histogram bins
1110        let mut bins = Vec::with_capacity(num_bins + 1);
1111        let mut counts = vec![0; num_bins];
1112
1113        let range = max - min;
1114        let bin_width = if range > 0.0 {
1115            range / num_bins as f64
1116        } else {
1117            1.0
1118        };
1119
1120        for i in 0..=num_bins {
1121            bins.push(min + i as f64 * bin_width);
1122        }
1123
1124        // Count values in each bin
1125        for &value in values {
1126            let bin_idx = if range > 0.0 {
1127                ((value - min) / bin_width).floor() as usize
1128            } else {
1129                0
1130            };
1131            let bin_idx = bin_idx.min(num_bins - 1);
1132            counts[bin_idx] += 1;
1133        }
1134
1135        Self {
1136            name: name.to_string(),
1137            min,
1138            max,
1139            mean,
1140            std,
1141            bins,
1142            counts,
1143        }
1144    }
1145
1146    /// Pretty print histogram as ASCII art.
1147    pub fn display(&self, width: usize) {
1148        println!("\n=== Histogram: {} ===", self.name);
1149        println!("  Min: {:.6}, Max: {:.6}", self.min, self.max);
1150        println!("  Mean: {:.6}, Std: {:.6}", self.mean, self.std);
1151        println!("\n  Distribution:");
1152
1153        if self.counts.is_empty() {
1154            println!("    (empty)");
1155            return;
1156        }
1157
1158        let max_count = *self.counts.iter().max().unwrap_or(&1);
1159
1160        for (i, &count) in self.counts.iter().enumerate() {
1161            let bar_len = if max_count > 0 {
1162                (count as f64 / max_count as f64 * width as f64) as usize
1163            } else {
1164                0
1165            };
1166
1167            let bar = "█".repeat(bar_len);
1168            let left = if i < self.bins.len() - 1 {
1169                self.bins[i]
1170            } else {
1171                self.bins[i - 1]
1172            };
1173            let right = if i < self.bins.len() - 1 {
1174                self.bins[i + 1]
1175            } else {
1176                self.bins[i]
1177            };
1178
1179            println!("  [{:>8.3}, {:>8.3}): {:>6} {}", left, right, count, bar);
1180        }
1181    }
1182}
1183
1184/// Callback for tracking weight histograms during training.
1185///
1186/// This callback computes and logs histogram statistics of model parameters
1187/// at regular intervals. Useful for:
1188/// - Detecting vanishing/exploding weights
1189/// - Monitoring weight distribution changes
1190/// - Debugging initialization issues
1191/// - Understanding parameter evolution
1192///
1193/// # Example
1194///
1195/// ```no_run
1196/// use tensorlogic_train::{CallbackList, HistogramCallback};
1197///
1198/// let mut callbacks = CallbackList::new();
1199/// callbacks.add(Box::new(HistogramCallback::new(
1200///     5,   // log_frequency: Every 5 epochs
1201///     10,  // num_bins: 10 histogram bins
1202///     true, // verbose: Print detailed histograms
1203/// )));
1204/// ```
1205pub struct HistogramCallback {
1206    /// Frequency of logging (every N epochs).
1207    log_frequency: usize,
1208    /// Number of histogram bins.
1209    #[allow(dead_code)]
1210    // Used in compute_histograms - will be active when parameters are accessible
1211    num_bins: usize,
1212    /// Whether to print detailed histograms.
1213    verbose: bool,
1214    /// History of histogram statistics.
1215    pub history: Vec<HashMap<String, HistogramStats>>,
1216}
1217
1218impl HistogramCallback {
1219    /// Create a new histogram callback.
1220    ///
1221    /// # Arguments
1222    /// * `log_frequency` - Log histograms every N epochs
1223    /// * `num_bins` - Number of bins in each histogram
1224    /// * `verbose` - Print detailed ASCII histograms
1225    pub fn new(log_frequency: usize, num_bins: usize, verbose: bool) -> Self {
1226        Self {
1227            log_frequency,
1228            num_bins,
1229            verbose,
1230            history: Vec::new(),
1231        }
1232    }
1233
1234    /// Compute histograms for all parameters in state.
1235    #[allow(dead_code)] // Placeholder - will be used when TrainingState includes parameters
1236    fn compute_histograms(&self, _state: &TrainingState) -> HashMap<String, HistogramStats> {
1237        // In a real implementation, we would access parameters from state
1238        // For now, this is a placeholder that would be populated when
1239        // TrainingState includes parameter access
1240
1241        // Example of what this would look like with actual parameters:
1242        // let mut histograms = HashMap::new();
1243        // for (name, param) in state.parameters.iter() {
1244        //     let values: Vec<f64> = param.iter().copied().collect();
1245        //     let stats = HistogramStats::compute(name, &values, self.num_bins);
1246        //     histograms.insert(name.clone(), stats);
1247        // }
1248        // histograms
1249
1250        HashMap::new()
1251    }
1252}
1253
1254impl Callback for HistogramCallback {
1255    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
1256        if (epoch + 1).is_multiple_of(self.log_frequency) {
1257            let histograms = self.compute_histograms(state);
1258
1259            if self.verbose {
1260                println!("\n--- Weight Histograms (Epoch {}) ---", epoch + 1);
1261                for (_name, stats) in histograms.iter() {
1262                    stats.display(40); // 40 character width for ASCII bars
1263                }
1264            } else {
1265                println!(
1266                    "Epoch {}: Computed histograms for {} parameters",
1267                    epoch + 1,
1268                    histograms.len()
1269                );
1270            }
1271
1272            self.history.push(histograms);
1273        }
1274
1275        Ok(())
1276    }
1277}
1278
1279/// Performance profiling statistics.
1280#[derive(Debug, Clone, Default)]
1281pub struct ProfilingStats {
1282    /// Total training time (seconds).
1283    pub total_time: f64,
1284    /// Time per epoch (seconds).
1285    pub epoch_times: Vec<f64>,
1286    /// Samples per second.
1287    pub samples_per_sec: f64,
1288    /// Batches per second.
1289    pub batches_per_sec: f64,
1290    /// Average batch time (seconds).
1291    pub avg_batch_time: f64,
1292    /// Peak memory usage (MB) - placeholder.
1293    pub peak_memory_mb: f64,
1294}
1295
1296impl ProfilingStats {
1297    /// Pretty print profiling statistics.
1298    pub fn display(&self) {
1299        println!("\n=== Profiling Statistics ===");
1300        println!("Total time: {:.2}s", self.total_time);
1301        println!("Samples/sec: {:.2}", self.samples_per_sec);
1302        println!("Batches/sec: {:.2}", self.batches_per_sec);
1303        println!("Avg batch time: {:.4}s", self.avg_batch_time);
1304
1305        if !self.epoch_times.is_empty() {
1306            let avg_epoch = self.epoch_times.iter().sum::<f64>() / self.epoch_times.len() as f64;
1307            let min_epoch = self
1308                .epoch_times
1309                .iter()
1310                .copied()
1311                .fold(f64::INFINITY, f64::min);
1312            let max_epoch = self
1313                .epoch_times
1314                .iter()
1315                .copied()
1316                .fold(f64::NEG_INFINITY, f64::max);
1317
1318            println!("\nEpoch times:");
1319            println!("  Average: {:.2}s", avg_epoch);
1320            println!("  Min: {:.2}s", min_epoch);
1321            println!("  Max: {:.2}s", max_epoch);
1322        }
1323    }
1324}
1325
1326/// Callback for profiling training performance.
1327///
1328/// Tracks timing information and throughput metrics during training.
1329/// Useful for:
1330/// - Identifying performance bottlenecks
1331/// - Comparing different configurations
1332/// - Monitoring training speed
1333/// - Resource utilization tracking
1334///
1335/// # Example
1336///
1337/// ```no_run
1338/// use tensorlogic_train::{CallbackList, ProfilingCallback};
1339///
1340/// let mut callbacks = CallbackList::new();
1341/// callbacks.add(Box::new(ProfilingCallback::new(
1342///     true,  // verbose: Print detailed stats
1343///     5,     // log_frequency: Every 5 epochs
1344/// )));
1345/// ```
1346pub struct ProfilingCallback {
1347    /// Whether to print detailed profiling info.
1348    verbose: bool,
1349    /// Frequency of logging (every N epochs).
1350    log_frequency: usize,
1351    /// Training start time.
1352    start_time: Option<std::time::Instant>,
1353    /// Last epoch start time.
1354    epoch_start_time: Option<std::time::Instant>,
1355    /// Batch start time.
1356    batch_start_time: Option<std::time::Instant>,
1357    /// Accumulated statistics.
1358    pub stats: ProfilingStats,
1359    /// Batch times for current epoch.
1360    current_epoch_batch_times: Vec<f64>,
1361    /// Total batches processed.
1362    total_batches: usize,
1363}
1364
1365impl ProfilingCallback {
1366    /// Create a new profiling callback.
1367    ///
1368    /// # Arguments
1369    /// * `verbose` - Print detailed profiling information
1370    /// * `log_frequency` - Log stats every N epochs
1371    pub fn new(verbose: bool, log_frequency: usize) -> Self {
1372        Self {
1373            verbose,
1374            log_frequency,
1375            start_time: None,
1376            epoch_start_time: None,
1377            batch_start_time: None,
1378            stats: ProfilingStats::default(),
1379            current_epoch_batch_times: Vec::new(),
1380            total_batches: 0,
1381        }
1382    }
1383
1384    /// Get profiling statistics.
1385    pub fn get_stats(&self) -> &ProfilingStats {
1386        &self.stats
1387    }
1388}
1389
1390impl Callback for ProfilingCallback {
1391    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
1392        self.start_time = Some(std::time::Instant::now());
1393        if self.verbose {
1394            println!("⏱️  Profiling started");
1395        }
1396        Ok(())
1397    }
1398
1399    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
1400        if let Some(start) = self.start_time {
1401            self.stats.total_time = start.elapsed().as_secs_f64();
1402
1403            // Compute aggregate statistics
1404            if self.total_batches > 0 {
1405                self.stats.avg_batch_time = self.stats.total_time / self.total_batches as f64;
1406                self.stats.batches_per_sec = self.total_batches as f64 / self.stats.total_time;
1407            }
1408
1409            if self.verbose {
1410                println!("\n⏱️  Profiling completed");
1411                self.stats.display();
1412            }
1413        }
1414        Ok(())
1415    }
1416
1417    fn on_epoch_begin(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
1418        self.epoch_start_time = Some(std::time::Instant::now());
1419        self.current_epoch_batch_times.clear();
1420
1421        if self.verbose && (epoch + 1).is_multiple_of(self.log_frequency) {
1422            println!("\n⏱️  Epoch {} profiling started", epoch + 1);
1423        }
1424        Ok(())
1425    }
1426
1427    fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
1428        if let Some(epoch_start) = self.epoch_start_time {
1429            let epoch_time = epoch_start.elapsed().as_secs_f64();
1430            self.stats.epoch_times.push(epoch_time);
1431
1432            if self.verbose && (epoch + 1).is_multiple_of(self.log_frequency) {
1433                let avg_batch = if !self.current_epoch_batch_times.is_empty() {
1434                    self.current_epoch_batch_times.iter().sum::<f64>()
1435                        / self.current_epoch_batch_times.len() as f64
1436                } else {
1437                    0.0
1438                };
1439
1440                println!("⏱️  Epoch {} completed:", epoch + 1);
1441                println!("    Time: {:.2}s", epoch_time);
1442                println!(
1443                    "    Batches: {} ({:.4}s avg)",
1444                    self.current_epoch_batch_times.len(),
1445                    avg_batch
1446                );
1447            }
1448        }
1449        Ok(())
1450    }
1451
1452    fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
1453        self.batch_start_time = Some(std::time::Instant::now());
1454        Ok(())
1455    }
1456
1457    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
1458        if let Some(batch_start) = self.batch_start_time {
1459            let batch_time = batch_start.elapsed().as_secs_f64();
1460            self.current_epoch_batch_times.push(batch_time);
1461            self.total_batches += 1;
1462        }
1463        Ok(())
1464    }
1465}
1466
1467/// Model EMA (Exponential Moving Average) callback.
1468///
1469/// Maintains an exponential moving average of model parameters during training.
1470/// This often leads to better generalization and more stable predictions.
1471///
1472/// The shadow parameters are updated as:
1473/// shadow_param = decay * shadow_param + (1 - decay) * param
1474///
1475/// Reference: Common practice in modern deep learning, popularized by Mean Teacher
1476/// and other semi-supervised learning methods.
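///
/// Because callbacks do not receive parameter references, `initialize` and
/// `update` are driven by the training loop. A usage sketch (the `parameters`
/// map is illustrative):
///
/// ```rust,ignore
/// use tensorlogic_train::ModelEMACallback;
///
/// let mut ema = ModelEMACallback::new(0.999, true);
/// ema.initialize(&parameters);
///
/// // After every optimizer step:
/// ema.update(&parameters)?;
///
/// // For evaluation, swap in the averaged weights:
/// let mut eval_params = parameters.clone();
/// ema.apply_shadow(&mut eval_params);
/// ```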
1477pub struct ModelEMACallback {
1478    /// Decay rate for EMA (typically 0.999 or 0.9999).
1479    decay: f64,
1480    /// Shadow parameters (EMA of model parameters).
1481    shadow_params: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
1482    /// Whether to use warmup for the decay (start with smaller decay).
1483    use_warmup: bool,
1484    /// Current update step (for warmup).
1485    num_updates: usize,
1486    /// Whether callback is initialized.
1487    initialized: bool,
1488}
1489
1490impl ModelEMACallback {
1491    /// Create a new Model EMA callback.
1492    ///
1493    /// # Arguments
1494    /// * `decay` - EMA decay rate (e.g., 0.999, 0.9999)
1495    /// * `use_warmup` - Whether to use decay warmup (recommended)
1496    pub fn new(decay: f64, use_warmup: bool) -> Self {
1497        Self {
1498            decay,
1499            shadow_params: HashMap::new(),
1500            use_warmup,
1501            num_updates: 0,
1502            initialized: false,
1503        }
1504    }
1505
1506    /// Initialize shadow parameters from current model parameters.
1507    pub fn initialize(
1508        &mut self,
1509        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
1510    ) {
1511        self.shadow_params.clear();
1512        for (name, param) in parameters {
1513            self.shadow_params.insert(name.clone(), param.clone());
1514        }
1515        self.initialized = true;
1516    }
1517
1518    /// Update EMA parameters.
1519    pub fn update(
1520        &mut self,
1521        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
1522    ) -> TrainResult<()> {
1523        if !self.initialized {
1524            return Err(TrainError::CallbackError(
1525                "ModelEMA not initialized. Call initialize() first.".to_string(),
1526            ));
1527        }
1528
1529        self.num_updates += 1;
1530
1531        // Compute effective decay with warmup
1532        let decay = if self.use_warmup {
1533            // Gradual warmup: start with (1 + num_updates) / (10 + num_updates)
1534            // and approach self.decay
1535            let warmup_decay = (1.0 + self.num_updates as f64) / (10.0 + self.num_updates as f64);
1536            warmup_decay.min(self.decay)
1537        } else {
1538            self.decay
1539        };
1540
1541        // Update shadow parameters
1542        for (name, param) in parameters {
1543            if let Some(shadow) = self.shadow_params.get_mut(name) {
1544                // shadow = decay * shadow + (1 - decay) * param
1545                *shadow = &*shadow * decay + &(param * (1.0 - decay));
1546            }
1547        }
1548
1549        Ok(())
1550    }
1551
1552    /// Get the EMA parameters.
1553    pub fn get_shadow_params(
1554        &self,
1555    ) -> &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
1556        &self.shadow_params
1557    }
1558
1559    /// Apply EMA parameters to the model (for evaluation).
1560    pub fn apply_shadow(
1561        &self,
1562        parameters: &mut HashMap<
1563            String,
1564            scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>,
1565        >,
1566    ) {
1567        for (name, shadow) in &self.shadow_params {
1568            if let Some(param) = parameters.get_mut(name) {
1569                *param = shadow.clone();
1570            }
1571        }
1572    }
1573}
1574
1575impl Callback for ModelEMACallback {
1576    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
1577        // Note: Initialization must be done externally since we don't have access to parameters here
1578        Ok(())
1579    }
1580
1581    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
1582        // Note: Update must be called externally since we don't have access to parameters here
1583        Ok(())
1584    }
1585}
1586
1587/// Gradient Accumulation callback.
1588///
1589/// Simulates larger batch sizes by accumulating gradients over multiple
1590/// mini-batches before updating parameters. This is useful when GPU memory
1591/// is limited but you want to train with effectively larger batches.
1592///
1593/// Effective batch size = mini_batch_size * accumulation_steps
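///
/// Accumulation is driven explicitly by the training loop; a sketch (the
/// `gradients` map is illustrative):
///
/// ```rust,ignore
/// use tensorlogic_train::GradientAccumulationCallback;
///
/// let mut accum = GradientAccumulationCallback::new(4)?; // effective batch = 4 x mini-batch
///
/// // For every mini-batch:
/// accum.accumulate(&gradients)?;
/// if accum.should_update() {
///     let averaged = accum.get_and_reset();
///     // apply `averaged` through the optimizer here
/// }
/// ```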
1594pub struct GradientAccumulationCallback {
1595    /// Number of steps to accumulate gradients before updating.
1596    accumulation_steps: usize,
1597    /// Current accumulation counter.
1598    current_step: usize,
1599    /// Accumulated gradients.
1600    accumulated_grads: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
1601    /// Whether gradients are initialized.
1602    initialized: bool,
1603}
1604
1605impl GradientAccumulationCallback {
1606    /// Create a new Gradient Accumulation callback.
1607    ///
1608    /// # Arguments
1609    /// * `accumulation_steps` - Number of mini-batches to accumulate (e.g., 4, 8, 16)
1610    pub fn new(accumulation_steps: usize) -> TrainResult<Self> {
1611        if accumulation_steps == 0 {
1612            return Err(TrainError::CallbackError(
1613                "Accumulation steps must be greater than 0".to_string(),
1614            ));
1615        }
1616
1617        Ok(Self {
1618            accumulation_steps,
1619            current_step: 0,
1620            accumulated_grads: HashMap::new(),
1621            initialized: false,
1622        })
1623    }
1624
1625    /// Accumulate gradients.
1626    pub fn accumulate(
1627        &mut self,
1628        gradients: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
1629    ) -> TrainResult<()> {
1630        if !self.initialized {
1631            // Initialize on first call
1632            for (name, grad) in gradients {
1633                self.accumulated_grads.insert(name.clone(), grad.clone());
1634            }
1635            self.initialized = true;
1636        } else {
1637            // Accumulate
1638            for (name, grad) in gradients {
1639                if let Some(acc_grad) = self.accumulated_grads.get_mut(name) {
1640                    *acc_grad = &*acc_grad + grad;
1641                }
1642            }
1643        }
1644
1645        self.current_step += 1;
1646        Ok(())
1647    }
1648
1649    /// Check if we should perform an optimizer step.
1650    pub fn should_update(&self) -> bool {
1651        self.current_step >= self.accumulation_steps
1652    }
1653
1654    /// Get the averaged accumulated gradients and reset the accumulator.
1655    pub fn get_and_reset(
1656        &mut self,
1657    ) -> HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
1658        let scale = 1.0 / self.accumulation_steps as f64;
1659
1660        let mut averaged_grads = HashMap::new();
1661        for (name, grad) in &self.accumulated_grads {
1662            averaged_grads.insert(name.clone(), grad * scale);
1663        }
1664
1665        // Reset
1666        self.current_step = 0;
1667        self.initialized = false;
1668        self.accumulated_grads.clear();
1669
1670        averaged_grads
1671    }
1672}
1673
1674impl Callback for GradientAccumulationCallback {
1675    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
1676        // Reset at the beginning of each epoch
1677        self.current_step = 0;
1678        self.initialized = false;
1679        self.accumulated_grads.clear();
1680        Ok(())
1681    }
1682}
1683
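// A minimal sketch (test-only) of the accumulate / should_update / get_and_reset
// cycle described above, using only the methods defined in this file. In a real
// training loop the caller would compute per-mini-batch gradients, feed them to
// `accumulate`, and perform an optimizer step with the averaged gradients
// whenever `should_update` returns true. Array construction uses `from_elem`
// from the same `scirs2_core::ndarray` re-export used throughout this module.
#[cfg(test)]
mod gradient_accumulation_tests {
    use super::*;

    #[test]
    fn test_accumulate_and_average() {
        use scirs2_core::ndarray::{Array, Ix2};
        use std::collections::HashMap;

        let mut callback = GradientAccumulationCallback::new(2).unwrap();

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array::<f64, Ix2>::from_elem((2, 2), 1.0));

        // First mini-batch: initializes the accumulator, no update yet.
        callback.accumulate(&grads).unwrap();
        assert!(!callback.should_update());

        // Second mini-batch: accumulated sum is 2.0 per element and an
        // optimizer step is due.
        callback.accumulate(&grads).unwrap();
        assert!(callback.should_update());

        // Averaged gradient is sum / accumulation_steps = 1.0 per element,
        // and the accumulator is reset for the next cycle.
        let averaged = callback.get_and_reset();
        assert!((averaged["w"][[0, 0]] - 1.0).abs() < 1e-12);
        assert!(!callback.should_update());
    }
}
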
1684/// SWA (Stochastic Weight Averaging) callback.
1685///
1686/// Averages model parameters over the course of training, typically starting
1687/// from a later epoch. This often leads to better generalization and wider optima.
1688///
1689/// Reference: Izmailov et al. "Averaging Weights Leads to Wider Optima and Better Generalization" (UAI 2018)
1690pub struct SWACallback {
1691    /// Epoch to start SWA (e.g., 75% through training).
1692    start_epoch: usize,
1693    /// Frequency of parameter averaging (every N epochs).
1694    update_frequency: usize,
1695    /// Running average of parameters.
1696    swa_params: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
1697    /// Number of models averaged so far.
1698    num_averaged: usize,
1699    /// Whether SWA is active.
1700    active: bool,
1701    /// Whether SWA parameters are initialized.
1702    initialized: bool,
1703    /// Verbose output.
1704    verbose: bool,
1705}
1706
1707impl SWACallback {
1708    /// Create a new SWA callback.
1709    ///
1710    /// # Arguments
1711    /// * `start_epoch` - Epoch to start averaging (e.g., 0.75 * total_epochs)
1712    /// * `update_frequency` - Average parameters every N epochs (typically 1)
1713    /// * `verbose` - Whether to print progress
1714    pub fn new(start_epoch: usize, update_frequency: usize, verbose: bool) -> Self {
1715        Self {
1716            start_epoch,
1717            update_frequency,
1718            swa_params: HashMap::new(),
1719            num_averaged: 0,
1720            active: false,
1721            initialized: false,
1722            verbose,
1723        }
1724    }
1725
1726    /// Update SWA parameters with current model parameters.
1727    pub fn update_average(
1728        &mut self,
1729        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
1730    ) -> TrainResult<()> {
1731        if !self.active {
1732            return Ok(());
1733        }
1734
1735        if !self.initialized {
1736            // Initialize with first model
1737            for (name, param) in parameters {
1738                self.swa_params.insert(name.clone(), param.clone());
1739            }
1740            self.initialized = true;
1741            self.num_averaged = 1;
1742
1743            if self.verbose {
1744                println!("📊 SWA: Initialized with model parameters");
1745            }
1746        } else {
1747            // Running average: swa = (swa * n + param) / (n + 1)
1748            let n = self.num_averaged as f64;
1749            for (name, param) in parameters {
1750                if let Some(swa_param) = self.swa_params.get_mut(name) {
1751                    *swa_param = &(&*swa_param * n + param) / (n + 1.0);
1752                }
1753            }
1754            self.num_averaged += 1;
1755
1756            if self.verbose {
1757                println!("📊 SWA: Updated average (n={})", self.num_averaged);
1758            }
1759        }
1760
1761        Ok(())
1762    }
1763
1764    /// Get the SWA parameters.
1765    pub fn get_swa_params(
1766        &self,
1767    ) -> &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
1768        &self.swa_params
1769    }
1770
1771    /// Apply SWA parameters to the model.
1772    pub fn apply_swa(
1773        &self,
1774        parameters: &mut HashMap<
1775            String,
1776            scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>,
1777        >,
1778    ) {
1779        if self.initialized {
1780            for (name, swa_param) in &self.swa_params {
1781                if let Some(param) = parameters.get_mut(name) {
1782                    *param = swa_param.clone();
1783                }
1784            }
1785        }
1786    }
1787
1788    /// Check if SWA has collected any averages.
1789    pub fn is_ready(&self) -> bool {
1790        self.initialized && self.num_averaged > 0
1791    }
1792}
1793
1794impl Callback for SWACallback {
1795    fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
1796        // Activate SWA at start_epoch
1797        if epoch >= self.start_epoch && !self.active {
1798            self.active = true;
1799            if self.verbose {
1800                println!("\n📊 SWA: Activated at epoch {}", epoch + 1);
1801            }
1802        }
1803
1804        // Check if we should update average
1805        if self.active && epoch >= self.start_epoch {
1806            let relative_epoch = epoch - self.start_epoch;
1807            if relative_epoch.is_multiple_of(self.update_frequency) {
1808                // Note: Actual update must be called externally with parameters
1809                if self.verbose && self.initialized {
1810                    println!(
1811                        "📊 SWA: Ready to update at epoch {} (call update_average with parameters)",
1812                        epoch + 1
1813                    );
1814                }
1815            }
1816        }
1817
1818        Ok(())
1819    }
1820
1821    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
1822        if self.verbose && self.initialized {
1823            println!(
1824                "\n📊 SWA: Training complete. Averaged {} models.",
1825                self.num_averaged
1826            );
1827            println!("📊 SWA: Call apply_swa() to use averaged parameters.");
1828        }
1829        Ok(())
1830    }
1831}
1832
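// A minimal sketch (test-only) of the SWA running average, using only the
// methods defined above. SWA is activated through `on_epoch_end` (with
// start_epoch = 0 here) before any averaging takes place; the TrainingState
// literal mirrors the one used in the profiling tests below.
#[cfg(test)]
mod swa_tests {
    use super::*;

    #[test]
    fn test_swa_running_average() {
        use scirs2_core::ndarray::{Array, Ix2};
        use std::collections::HashMap;

        let mut callback = SWACallback::new(0, 1, false);
        let state = TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 0.5,
            batch_loss: 0.5,
            val_loss: None,
            learning_rate: 0.01,
            metrics: HashMap::new(),
        };

        // Before activation, update_average is a no-op.
        let mut params = HashMap::new();
        params.insert("w".to_string(), Array::<f64, Ix2>::from_elem((2, 2), 2.0));
        callback.update_average(&params).unwrap();
        assert!(!callback.is_ready());

        // Epoch 0 >= start_epoch, so SWA becomes active here.
        callback.on_epoch_end(0, &state).unwrap();

        // First update initializes the average with the current parameters.
        callback.update_average(&params).unwrap();
        assert!(callback.is_ready());

        // Second update with different parameters: (2.0 * 1 + 4.0) / 2 = 3.0.
        params.insert("w".to_string(), Array::<f64, Ix2>::from_elem((2, 2), 4.0));
        callback.update_average(&params).unwrap();
        assert!((callback.get_swa_params()["w"][[0, 0]] - 3.0).abs() < 1e-12);

        // apply_swa overwrites the live parameters with the averaged ones.
        callback.apply_swa(&mut params);
        assert!((params["w"][[0, 0]] - 3.0).abs() < 1e-12);
    }
}
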
1833#[cfg(test)]
1834mod profiling_tests {
1835    use super::*;
1836
1837    #[test]
1838    fn test_histogram_stats() {
1839        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
1840        let stats = HistogramStats::compute("test", &values, 5);
1841
1842        assert_eq!(stats.name, "test");
1843        assert_eq!(stats.min, 1.0);
1844        assert_eq!(stats.max, 10.0);
1845        assert!((stats.mean - 5.5).abs() < 1e-6);
1846        assert_eq!(stats.bins.len(), 6);
1847        assert_eq!(stats.counts.len(), 5);
1848        assert_eq!(stats.counts.iter().sum::<usize>(), 10);
1849    }
1850
1851    #[test]
1852    fn test_histogram_callback() {
1853        use std::collections::HashMap;
1854        let mut callback = HistogramCallback::new(2, 10, false);
1855        let state = TrainingState {
1856            epoch: 0,
1857            batch: 0,
1858            train_loss: 0.5,
1859            batch_loss: 0.5,
1860            val_loss: Some(0.6),
1861            learning_rate: 0.01,
1862            metrics: HashMap::new(),
1863        };
1864
1865        // Should not log on epoch 0
1866        callback.on_epoch_end(0, &state).unwrap();
1867        assert_eq!(callback.history.len(), 0);
1868
1869        // Should log on epoch 1, the 2nd epoch (frequency=2 means every 2 epochs)
1870        callback.on_epoch_end(1, &state).unwrap();
1871        assert_eq!(callback.history.len(), 1);
1872    }
1873
1874    #[test]
1875    fn test_profiling_callback() {
1876        use std::collections::HashMap;
1877        let mut callback = ProfilingCallback::new(false, 1);
1878        let state = TrainingState {
1879            epoch: 0,
1880            batch: 0,
1881            train_loss: 0.5,
1882            batch_loss: 0.5,
1883            val_loss: Some(0.6),
1884            learning_rate: 0.01,
1885            metrics: HashMap::new(),
1886        };
1887
1888        callback.on_train_begin(&state).unwrap();
1889        assert!(callback.start_time.is_some());
1890
1891        callback.on_epoch_begin(0, &state).unwrap();
1892        assert!(callback.epoch_start_time.is_some());
1893
1894        callback.on_batch_begin(0, &state).unwrap();
1895        std::thread::sleep(std::time::Duration::from_millis(10));
1896        callback.on_batch_end(0, &state).unwrap();
1897
1898        assert_eq!(callback.total_batches, 1);
1899        assert_eq!(callback.current_epoch_batch_times.len(), 1);
1900
1901        callback.on_epoch_end(0, &state).unwrap();
1902        assert_eq!(callback.stats.epoch_times.len(), 1);
1903
1904        callback.on_train_end(&state).unwrap();
1905        assert!(callback.stats.total_time > 0.0);
1906    }
1907}