use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;
use std::path::PathBuf;

/// Hooks invoked by the training loop at well-defined points.
///
/// Every method has a no-op default, so implementors only override the events
/// they care about; `should_stop` lets a callback request early termination.
pub trait Callback {
    /// Called once before the first epoch.
    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Called once after the last epoch.
    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Called at the start of each epoch.
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Called at the end of each epoch.
    fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Called before each training batch.
    fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Called after each training batch.
    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Called after a validation pass completes.
    fn on_validation_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Returns `true` when the callback wants training to stop early.
    fn should_stop(&self) -> bool {
        false
    }
}

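/// Ordered collection of boxed [`Callback`]s; each hook fans out to every
/// registered callback in insertion order and stops at the first error.
///
/// Hedged usage sketch (the surrounding training loop, `num_epochs`, and
/// `state` are assumptions about the caller, not part of this module):
///
/// ```ignore
/// let mut callbacks = CallbackList::new();
/// callbacks.add(Box::new(EpochCallback::new(true)));
/// callbacks.add(Box::new(EarlyStoppingCallback::new(5, 1e-4)));
///
/// callbacks.on_train_begin(&state)?;
/// for epoch in 0..num_epochs {
///     callbacks.on_epoch_begin(epoch, &state)?;
///     // ... run batches, calling on_batch_begin / on_batch_end ...
///     callbacks.on_epoch_end(epoch, &state)?;
///     if callbacks.should_stop() {
///         break;
///     }
/// }
/// callbacks.on_train_end(&state)?;
/// ```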
pub struct CallbackList {
    callbacks: Vec<Box<dyn Callback>>,
}

impl CallbackList {
    pub fn new() -> Self {
        Self {
            callbacks: Vec::new(),
        }
    }

    pub fn add(&mut self, callback: Box<dyn Callback>) {
        self.callbacks.push(callback);
    }

    pub fn on_train_begin(&mut self, state: &TrainingState) -> TrainResult<()> {
        for callback in &mut self.callbacks {
            callback.on_train_begin(state)?;
        }
        Ok(())
    }

    pub fn on_train_end(&mut self, state: &TrainingState) -> TrainResult<()> {
        for callback in &mut self.callbacks {
            callback.on_train_end(state)?;
        }
        Ok(())
    }

    pub fn on_epoch_begin(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        for callback in &mut self.callbacks {
            callback.on_epoch_begin(epoch, state)?;
        }
        Ok(())
    }

    pub fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        for callback in &mut self.callbacks {
            callback.on_epoch_end(epoch, state)?;
        }
        Ok(())
    }

    pub fn on_batch_begin(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
        for callback in &mut self.callbacks {
            callback.on_batch_begin(batch, state)?;
        }
        Ok(())
    }

    pub fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
        for callback in &mut self.callbacks {
            callback.on_batch_end(batch, state)?;
        }
        Ok(())
    }

    pub fn on_validation_end(&mut self, state: &TrainingState) -> TrainResult<()> {
        for callback in &mut self.callbacks {
            callback.on_validation_end(state)?;
        }
        Ok(())
    }

    /// Returns `true` if any registered callback requests that training stop.
    pub fn should_stop(&self) -> bool {
        self.callbacks.iter().any(|cb| cb.should_stop())
    }
}

impl Default for CallbackList {
    fn default() -> Self {
        Self::new()
    }
}

/// Logs training and validation loss at the end of each epoch when `verbose` is set.
pub struct EpochCallback {
    pub verbose: bool,
}

impl EpochCallback {
    pub fn new(verbose: bool) -> Self {
        Self { verbose }
    }
}

impl Callback for EpochCallback {
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if self.verbose {
            println!(
                "Epoch {}: loss={:.6}, val_loss={:.6}",
                epoch,
                state.train_loss,
                state.val_loss.unwrap_or(f64::NAN)
            );
        }
        Ok(())
    }
}

/// Logs the batch loss every `log_frequency` batches.
pub struct BatchCallback {
    pub log_frequency: usize,
}

impl BatchCallback {
    pub fn new(log_frequency: usize) -> Self {
        Self { log_frequency }
    }
}

impl Callback for BatchCallback {
    fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
        if batch.is_multiple_of(self.log_frequency) {
            println!("Batch {}: loss={:.6}", batch, state.batch_loss);
        }
        Ok(())
    }
}

/// Reports the validation loss every `validation_frequency` epochs.
pub struct ValidationCallback {
    pub validation_frequency: usize,
}

impl ValidationCallback {
    pub fn new(validation_frequency: usize) -> Self {
        Self {
            validation_frequency,
        }
    }
}

impl Callback for ValidationCallback {
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if epoch.is_multiple_of(self.validation_frequency) {
            if let Some(val_loss) = state.val_loss {
                println!("Validation at epoch {}: val_loss={:.6}", epoch, val_loss);
            }
        }
        Ok(())
    }
}

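/// Serializable snapshot of training progress. Parameter arrays are stored as
/// flattened (row-major) `Vec<f64>`s, so the caller is responsible for
/// remembering their shapes when restoring.
///
/// Hedged usage sketch (the parameter/optimizer maps, histories, and `state`
/// are assumed to come from the caller's trainer):
///
/// ```ignore
/// let checkpoint = TrainingCheckpoint::new(
///     epoch,
///     &parameters,       // HashMap<String, Array2<f64>>
///     &optimizer_state,  // HashMap<String, Vec<f64>>
///     None,              // scheduler state
///     &state,
///     &train_loss_history,
///     &val_loss_history,
///     &metrics_history,
///     best_val_loss,
/// );
/// checkpoint.save(&PathBuf::from("checkpoints/epoch_10.json"))?;
/// let restored = TrainingCheckpoint::load(&PathBuf::from("checkpoints/epoch_10.json"))?;
/// ```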
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingCheckpoint {
    pub epoch: usize,
    pub parameters: HashMap<String, Vec<f64>>,
    pub optimizer_state: HashMap<String, Vec<f64>>,
    pub scheduler_state: Option<HashMap<String, f64>>,
    pub train_loss: f64,
    pub val_loss: Option<f64>,
    pub train_loss_history: Vec<f64>,
    pub val_loss_history: Vec<f64>,
    pub metrics_history: HashMap<String, Vec<f64>>,
    pub learning_rate: f64,
    pub best_val_loss: Option<f64>,
}

impl TrainingCheckpoint {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        epoch: usize,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
        optimizer_state: &HashMap<String, Vec<f64>>,
        scheduler_state: Option<HashMap<String, f64>>,
        state: &TrainingState,
        train_loss_history: &[f64],
        val_loss_history: &[f64],
        metrics_history: &HashMap<String, Vec<f64>>,
        best_val_loss: Option<f64>,
    ) -> Self {
        // Flatten each 2-D parameter array into a serializable Vec<f64>.
        let parameters = parameters
            .iter()
            .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
            .collect();

        Self {
            epoch,
            parameters,
            optimizer_state: optimizer_state.clone(),
            scheduler_state,
            train_loss: state.train_loss,
            val_loss: state.val_loss,
            train_loss_history: train_loss_history.to_vec(),
            val_loss_history: val_loss_history.to_vec(),
            metrics_history: metrics_history.clone(),
            learning_rate: state.learning_rate,
            best_val_loss,
        }
    }

    pub fn save(&self, path: &PathBuf) -> TrainResult<()> {
        let json = serde_json::to_string_pretty(self).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
        })?;

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
            })?;
        }

        std::fs::write(path, json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
        })?;

        Ok(())
    }

    pub fn load(path: &PathBuf) -> TrainResult<Self> {
        let json = std::fs::read_to_string(path).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to read checkpoint: {}", e))
        })?;

        let checkpoint: Self = serde_json::from_str(&json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to deserialize checkpoint: {}", e))
        })?;

        Ok(checkpoint)
    }
}

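/// Writes a small JSON file with the epoch and loss values every
/// `save_frequency` epochs, optionally only when the validation loss improves.
/// Files are named `checkpoint_epoch_<N>.json` inside `checkpoint_dir`.
///
/// Sketch of registration (assumes a `CallbackList` as sketched above):
///
/// ```ignore
/// callbacks.add(Box::new(CheckpointCallback::new(
///     PathBuf::from("checkpoints"),
///     5,    // save every 5 epochs
///     true, // only when val_loss improves
/// )));
/// ```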
pub struct CheckpointCallback {
    pub checkpoint_dir: PathBuf,
    pub save_frequency: usize,
    pub save_best_only: bool,
    best_val_loss: Option<f64>,
}

impl CheckpointCallback {
    pub fn new(checkpoint_dir: PathBuf, save_frequency: usize, save_best_only: bool) -> Self {
        Self {
            checkpoint_dir,
            save_frequency,
            save_best_only,
            best_val_loss: None,
        }
    }

    fn save_checkpoint(&self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        let checkpoint_path = self
            .checkpoint_dir
            .join(format!("checkpoint_epoch_{}.json", epoch));

        // Only scalar metadata (epoch and losses) is written here; full model
        // state can be captured separately with `TrainingCheckpoint`.
        let mut checkpoint = HashMap::new();
        checkpoint.insert("epoch".to_string(), epoch as f64);
        checkpoint.insert("train_loss".to_string(), state.train_loss);
        if let Some(val_loss) = state.val_loss {
            checkpoint.insert("val_loss".to_string(), val_loss);
        }

        let json = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to serialize checkpoint: {}", e))
        })?;

        std::fs::create_dir_all(&self.checkpoint_dir).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to create checkpoint directory: {}", e))
        })?;

        std::fs::write(&checkpoint_path, json).map_err(|e| {
            TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
        })?;

        println!("Checkpoint saved to {:?}", checkpoint_path);
        Ok(())
    }
}

impl Callback for CheckpointCallback {
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if !epoch.is_multiple_of(self.save_frequency) {
            return Ok(());
        }

        if self.save_best_only {
            if let Some(val_loss) = state.val_loss {
                let should_save = self
                    .best_val_loss
                    .map(|best| val_loss < best)
                    .unwrap_or(true);

                if should_save {
                    self.best_val_loss = Some(val_loss);
                    self.save_checkpoint(epoch, state)?;
                }
            }
        } else {
            self.save_checkpoint(epoch, state)?;
        }

        Ok(())
    }
}

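/// Requests an early stop when the validation loss has not improved by at
/// least `min_delta` for `patience` consecutive epochs.
///
/// Sketch (mirrors the `test_early_stopping` unit test below):
///
/// ```ignore
/// let mut early_stop = EarlyStoppingCallback::new(2, 0.01);
/// // after each epoch:
/// early_stop.on_epoch_end(epoch, &state)?;
/// if early_stop.should_stop() {
///     break;
/// }
/// ```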
pub struct EarlyStoppingCallback {
    pub patience: usize,
    pub min_delta: f64,
    best_val_loss: Option<f64>,
    wait: usize,
    stop_training: bool,
}

impl EarlyStoppingCallback {
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_val_loss: None,
            wait: 0,
            stop_training: false,
        }
    }
}

impl Callback for EarlyStoppingCallback {
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if let Some(val_loss) = state.val_loss {
            // An epoch counts as an improvement only if it beats the best
            // validation loss by more than `min_delta`.
            let improved = self
                .best_val_loss
                .map(|best| val_loss < best - self.min_delta)
                .unwrap_or(true);

            if improved {
                self.best_val_loss = Some(val_loss);
                self.wait = 0;
            } else {
                self.wait += 1;
                if self.wait >= self.patience {
                    println!(
                        "Early stopping at epoch {} (no improvement for {} epochs)",
                        epoch, self.patience
                    );
                    self.stop_training = true;
                }
            }
        }

        Ok(())
    }

    fn should_stop(&self) -> bool {
        self.stop_training
    }
}

#[allow(dead_code)]
pub struct ReduceLrOnPlateauCallback {
    pub factor: f64,
    pub patience: usize,
    pub min_delta: f64,
    pub min_lr: f64,
    best_val_loss: Option<f64>,
    wait: usize,
}

impl ReduceLrOnPlateauCallback {
    #[allow(dead_code)]
    pub fn new(factor: f64, patience: usize, min_delta: f64, min_lr: f64) -> Self {
        Self {
            factor,
            patience,
            min_delta,
            min_lr,
            best_val_loss: None,
            wait: 0,
        }
    }
}

impl Callback for ReduceLrOnPlateauCallback {
    fn on_epoch_end(&mut self, _epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if let Some(val_loss) = state.val_loss {
            let improved = self
                .best_val_loss
                .map(|best| val_loss < best - self.min_delta)
                .unwrap_or(true);

            if improved {
                self.best_val_loss = Some(val_loss);
                self.wait = 0;
            } else {
                self.wait += 1;
                if self.wait >= self.patience {
                    // The callback only computes and reports the reduced rate;
                    // it receives `&TrainingState`, so the trainer has to apply
                    // the new learning rate to the optimizer itself.
                    let new_lr = (state.learning_rate * self.factor).max(self.min_lr);
                    if new_lr != state.learning_rate {
                        println!("Reducing learning rate to {:.6}", new_lr);
                    }
                    self.wait = 0;
                }
            }
        }

        Ok(())
    }
}

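/// Sweeps the learning rate from `start_lr` to `end_lr` over `num_steps`
/// batches, exponentially by default
/// (`lr_i = start_lr * (end_lr / start_lr)^(i / (n - 1))`), recording the batch
/// loss at each step so a reasonable learning rate can be read off the curve.
///
/// Sketch of intended use; applying each proposed rate to the optimizer is the
/// trainer's job and is an assumption here:
///
/// ```ignore
/// let mut finder = LearningRateFinder::new(1e-6, 1.0, 100).with_smoothing(0.9);
/// // run up to 100 training batches with `finder` driven as a callback,
/// // then inspect the results:
/// finder.print_results();
/// if let Some(lr) = finder.suggest_lr() {
///     println!("Using learning rate {:.2e}", lr);
/// }
/// ```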
pub struct LearningRateFinder {
    start_lr: f64,
    end_lr: f64,
    num_steps: usize,
    current_step: usize,
    pub history: Vec<(f64, f64)>,
    exponential: bool,
    smoothing: f64,
    smoothed_loss: Option<f64>,
}

impl LearningRateFinder {
    pub fn new(start_lr: f64, end_lr: f64, num_steps: usize) -> Self {
        Self {
            start_lr,
            end_lr,
            num_steps,
            current_step: 0,
            history: Vec::with_capacity(num_steps),
            exponential: true,
            smoothing: 0.0,
            smoothed_loss: None,
        }
    }

    pub fn with_exponential_scaling(mut self) -> Self {
        self.exponential = true;
        self
    }

    pub fn with_linear_scaling(mut self) -> Self {
        self.exponential = false;
        self
    }

    pub fn with_smoothing(mut self, smoothing: f64) -> Self {
        self.smoothing = smoothing.clamp(0.0, 1.0);
        self
    }

    fn compute_lr(&self) -> f64 {
        if self.num_steps <= 1 {
            return self.start_lr;
        }

        let step_ratio = self.current_step as f64 / (self.num_steps - 1) as f64;

        if self.exponential {
            self.start_lr * (self.end_lr / self.start_lr).powf(step_ratio)
        } else {
            self.start_lr + (self.end_lr - self.start_lr) * step_ratio
        }
    }

    fn smooth_loss(&mut self, loss: f64) -> f64 {
        if self.smoothing == 0.0 {
            return loss;
        }

        match self.smoothed_loss {
            None => {
                self.smoothed_loss = Some(loss);
                loss
            }
            Some(prev) => {
                let smoothed = self.smoothing * prev + (1.0 - self.smoothing) * loss;
                self.smoothed_loss = Some(smoothed);
                smoothed
            }
        }
    }

    pub fn suggest_lr(&self) -> Option<f64> {
        if self.history.len() < 3 {
            return None;
        }

        // Pick the learning rate where the loss is falling fastest
        // (most negative finite-difference slope of loss vs. learning rate).
        let mut best_lr = None;
        let mut best_gradient = f64::INFINITY;

        for i in 1..self.history.len() {
            let (lr1, loss1) = self.history[i - 1];
            let (lr2, loss2) = self.history[i];

            let gradient = (loss2 - loss1) / (lr2 - lr1);

            if gradient < best_gradient {
                best_gradient = gradient;
                best_lr = Some(lr2);
            }
        }

        best_lr
    }

    pub fn print_results(&self) {
        println!("\n=== Learning Rate Finder Results ===");
        println!(
            "Tested {} learning rates from {:.2e} to {:.2e}",
            self.history.len(),
            self.start_lr,
            self.end_lr
        );

        if let Some(suggested_lr) = self.suggest_lr() {
            println!("Suggested optimal LR: {:.2e}", suggested_lr);
            println!(
                "Consider using LR between {:.2e} and {:.2e}",
                suggested_lr / 10.0,
                suggested_lr
            );
        }

        println!("\nLR, Loss:");
        for (lr, loss) in &self.history {
            println!("{:.6e}, {:.6}", lr, loss);
        }
        println!("===================================\n");
    }
}

impl Callback for LearningRateFinder {
    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
        if self.current_step >= self.num_steps {
            return Ok(());
        }

        let loss = self.smooth_loss(state.batch_loss);

        // Record the (learning rate, loss) pair for this step; the trainer is
        // expected to apply the corresponding learning rate to the optimizer.
        let lr = self.compute_lr();
        self.history.push((lr, loss));

        self.current_step += 1;

        Ok(())
    }

    fn should_stop(&self) -> bool {
        self.current_step >= self.num_steps
    }
}

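/// Tracks per-batch gradient statistics and warns when the gradient norm drops
/// below `vanishing_threshold` or exceeds `exploding_threshold`. The statistics
/// are currently placeholders (see `compute_gradient_stats`) until gradient
/// access is exposed through `TrainingState`.
///
/// Sketch:
///
/// ```ignore
/// let mut monitor = GradientMonitor::new(100, 1e-7, 1e3);
/// // call monitor.on_batch_end(...) from the training loop; afterwards:
/// let summary = monitor.summary();
/// println!("{} exploding-gradient warnings", summary.exploding_count);
/// ```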
pub struct GradientMonitor {
    log_frequency: usize,
    vanishing_threshold: f64,
    exploding_threshold: f64,
    pub gradient_norms: Vec<f64>,
    pub gradient_means: Vec<f64>,
    pub gradient_stds: Vec<f64>,
    pub vanishing_count: usize,
    pub exploding_count: usize,
    batch_counter: usize,
}

impl GradientMonitor {
    pub fn new(log_frequency: usize, vanishing_threshold: f64, exploding_threshold: f64) -> Self {
        Self {
            log_frequency,
            vanishing_threshold,
            exploding_threshold,
            gradient_norms: Vec::new(),
            gradient_means: Vec::new(),
            gradient_stds: Vec::new(),
            vanishing_count: 0,
            exploding_count: 0,
            batch_counter: 0,
        }
    }

    fn compute_gradient_stats(&mut self, _state: &TrainingState) -> (f64, f64, f64) {
        // Placeholder: `TrainingState` does not currently expose gradients, so
        // fixed (norm, mean, std) values are returned.
        (1.0, 0.0, 0.1)
    }

    fn check_vanishing(&mut self, norm: f64) -> bool {
        if norm < self.vanishing_threshold {
            self.vanishing_count += 1;
            return true;
        }
        false
    }

    fn check_exploding(&mut self, norm: f64) -> bool {
        if norm > self.exploding_threshold {
            self.exploding_count += 1;
            return true;
        }
        false
    }

    fn print_stats(&self, norm: f64, mean: f64, std: f64) {
        println!("Gradient Stats [Batch {}]:", self.batch_counter);
        println!("  Norm: {:.6e}, Mean: {:.6e}, Std: {:.6e}", norm, mean, std);

        if self.vanishing_count > 0 {
            println!(
                "  ⚠️ Vanishing gradient warnings: {}",
                self.vanishing_count
            );
        }

        if self.exploding_count > 0 {
            println!(
                "  ⚠️ Exploding gradient warnings: {}",
                self.exploding_count
            );
        }
    }

    pub fn summary(&self) -> GradientSummary {
        let avg_norm = if !self.gradient_norms.is_empty() {
            self.gradient_norms.iter().sum::<f64>() / self.gradient_norms.len() as f64
        } else {
            0.0
        };

        GradientSummary {
            total_batches: self.batch_counter,
            average_norm: avg_norm,
            vanishing_count: self.vanishing_count,
            exploding_count: self.exploding_count,
        }
    }
}

#[derive(Debug, Clone)]
pub struct GradientSummary {
    pub total_batches: usize,
    pub average_norm: f64,
    pub vanishing_count: usize,
    pub exploding_count: usize,
}

impl Callback for GradientMonitor {
    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
        self.batch_counter += 1;

        let (norm, mean, std) = self.compute_gradient_stats(state);

        self.gradient_norms.push(norm);
        self.gradient_means.push(mean);
        self.gradient_stds.push(std);

        let vanishing = self.check_vanishing(norm);
        let exploding = self.check_exploding(norm);

        // Print on the regular schedule, or immediately when a warning fires.
        if self.batch_counter.is_multiple_of(self.log_frequency) || vanishing || exploding {
            self.print_stats(norm, mean, std);
        }

        Ok(())
    }

    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        let summary = self.summary();
        println!("\n=== Gradient Monitoring Summary ===");
        println!("Total batches: {}", summary.total_batches);
        println!("Average gradient norm: {:.6e}", summary.average_norm);
        println!("Vanishing gradient warnings: {}", summary.vanishing_count);
        println!("Exploding gradient warnings: {}", summary.exploding_count);
        println!("====================================\n");
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_callback_list() {
        let mut callbacks = CallbackList::new();
        callbacks.add(Box::new(EpochCallback::new(false)));

        let state = create_test_state();
        callbacks.on_train_begin(&state).unwrap();
        callbacks.on_epoch_begin(0, &state).unwrap();
        callbacks.on_epoch_end(0, &state).unwrap();
        callbacks.on_train_end(&state).unwrap();
    }

    #[test]
    fn test_early_stopping() {
        let mut callback = EarlyStoppingCallback::new(2, 0.01);
        let mut state = create_test_state();

        state.val_loss = Some(1.0);
        callback.on_epoch_end(0, &state).unwrap();
        assert!(!callback.should_stop());

        state.val_loss = Some(0.8);
        callback.on_epoch_end(1, &state).unwrap();
        assert!(!callback.should_stop());

        state.val_loss = Some(0.81);
        callback.on_epoch_end(2, &state).unwrap();
        assert!(!callback.should_stop());

        state.val_loss = Some(0.82);
        callback.on_epoch_end(3, &state).unwrap();
        assert!(callback.should_stop());
    }

    #[test]
    fn test_checkpoint_callback() {
        use std::env::temp_dir;

        let checkpoint_dir = temp_dir().join("tensorlogic_test_checkpoints");
        let mut callback = CheckpointCallback::new(checkpoint_dir.clone(), 1, false);
        let state = create_test_state();

        callback.on_epoch_end(0, &state).unwrap();

        let checkpoint_path = checkpoint_dir.join("checkpoint_epoch_0.json");
        assert!(checkpoint_path.exists());

        std::fs::remove_dir_all(checkpoint_dir).ok();
    }

    #[test]
    fn test_training_checkpoint_save_load() {
        use scirs2_core::ndarray::Array2;
        use std::env::temp_dir;

        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Array2::from_elem((2, 3), 1.5));
        parameters.insert("bias".to_string(), Array2::from_elem((1, 3), 0.5));

        let state = TrainingState {
            epoch: 5,
            batch: 100,
            train_loss: 0.75,
            val_loss: Some(0.85),
            batch_loss: 0.72,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        };

        let optimizer_state = {
            let mut state = HashMap::new();
            state.insert("momentum_weight".to_string(), vec![0.1, 0.2, 0.3]);
            state.insert("momentum_bias".to_string(), vec![0.05]);
            state
        };

        let checkpoint = TrainingCheckpoint::new(
            5,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.9, 0.8, 0.77, 0.75],
            &[1.1, 0.95, 0.88, 0.87, 0.85],
            &HashMap::new(),
            Some(0.85),
        );

        let checkpoint_path = temp_dir().join("test_training_checkpoint.json");
        checkpoint.save(&checkpoint_path).unwrap();

        assert!(checkpoint_path.exists());

        let loaded = TrainingCheckpoint::load(&checkpoint_path).unwrap();

        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.train_loss, 0.75);
        assert_eq!(loaded.val_loss, Some(0.85));
        assert_eq!(loaded.learning_rate, 0.001);
        assert_eq!(loaded.train_loss_history.len(), 5);
        assert_eq!(loaded.val_loss_history.len(), 5);
        assert_eq!(loaded.best_val_loss, Some(0.85));

        assert_eq!(loaded.parameters.len(), 2);
        assert!(loaded.parameters.contains_key("weight"));
        assert!(loaded.parameters.contains_key("bias"));

        assert_eq!(loaded.optimizer_state.len(), 2);
        assert!(loaded.optimizer_state.contains_key("momentum_weight"));

        std::fs::remove_file(checkpoint_path).ok();
    }

    #[test]
    fn test_training_checkpoint_with_metrics() {
        use scirs2_core::ndarray::Array2;
        use std::env::temp_dir;

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Array2::zeros((2, 2)));

        let state = create_test_state();
        let optimizer_state = HashMap::new();

        let mut metrics_history = HashMap::new();
        metrics_history.insert("accuracy".to_string(), vec![0.5, 0.6, 0.7]);
        metrics_history.insert("f1_score".to_string(), vec![0.45, 0.55, 0.65]);

        let checkpoint = TrainingCheckpoint::new(
            2,
            &parameters,
            &optimizer_state,
            None,
            &state,
            &[1.0, 0.8, 0.6],
            &[1.1, 0.9, 0.7],
            &metrics_history,
            Some(0.7),
        );

        let checkpoint_path = temp_dir().join("test_checkpoint_with_metrics.json");
        checkpoint.save(&checkpoint_path).unwrap();

        let loaded = TrainingCheckpoint::load(&checkpoint_path).unwrap();

        assert_eq!(loaded.metrics_history.len(), 2);
        assert!(loaded.metrics_history.contains_key("accuracy"));
        assert!(loaded.metrics_history.contains_key("f1_score"));
        assert_eq!(loaded.metrics_history["accuracy"].len(), 3);

        std::fs::remove_file(checkpoint_path).ok();
    }
}

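/// Summary statistics plus a fixed-width histogram over a slice of values
/// (typically flattened weights or gradients).
///
/// Small illustrative sketch (the values are made up for the example):
///
/// ```ignore
/// let values = vec![0.1, 0.2, 0.2, 0.3, 0.9];
/// let stats = HistogramStats::compute("layer1/weight", &values, 5);
/// stats.display(40); // 40-character-wide bars
/// ```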
#[derive(Debug, Clone)]
pub struct HistogramStats {
    pub name: String,
    pub min: f64,
    pub max: f64,
    pub mean: f64,
    pub std: f64,
    pub bins: Vec<f64>,
    pub counts: Vec<usize>,
}

impl HistogramStats {
    pub fn compute(name: &str, values: &[f64], num_bins: usize) -> Self {
        if values.is_empty() {
            return Self {
                name: name.to_string(),
                min: 0.0,
                max: 0.0,
                mean: 0.0,
                std: 0.0,
                bins: vec![],
                counts: vec![],
            };
        }

        // Guard against a zero-bin request, which would otherwise underflow below.
        let num_bins = num_bins.max(1);

        let min = values.iter().copied().fold(f64::INFINITY, f64::min);
        let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let sum: f64 = values.iter().sum();
        let mean = sum / values.len() as f64;

        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
        let std = variance.sqrt();

        let mut bins = Vec::with_capacity(num_bins + 1);
        let mut counts = vec![0; num_bins];

        let range = max - min;
        let bin_width = if range > 0.0 {
            range / num_bins as f64
        } else {
            1.0
        };

        for i in 0..=num_bins {
            bins.push(min + i as f64 * bin_width);
        }

        for &value in values {
            let bin_idx = if range > 0.0 {
                ((value - min) / bin_width).floor() as usize
            } else {
                0
            };
            let bin_idx = bin_idx.min(num_bins - 1);
            counts[bin_idx] += 1;
        }

        Self {
            name: name.to_string(),
            min,
            max,
            mean,
            std,
            bins,
            counts,
        }
    }

    pub fn display(&self, width: usize) {
        println!("\n=== Histogram: {} ===", self.name);
        println!("  Min: {:.6}, Max: {:.6}", self.min, self.max);
        println!("  Mean: {:.6}, Std: {:.6}", self.mean, self.std);
        println!("\n  Distribution:");

        if self.counts.is_empty() {
            println!("  (empty)");
            return;
        }

        let max_count = *self.counts.iter().max().unwrap_or(&1);

        for (i, &count) in self.counts.iter().enumerate() {
            let bar_len = if max_count > 0 {
                (count as f64 / max_count as f64 * width as f64) as usize
            } else {
                0
            };

            let bar = "█".repeat(bar_len);
            // `bins` always has one more entry than `counts`, so `i` and `i + 1`
            // are the left and right edges of this bin.
            let left = self.bins[i];
            let right = self.bins[i + 1];

            println!("  [{:>8.3}, {:>8.3}): {:>6} {}", left, right, count, bar);
        }
    }
}

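/// Records weight histograms every `log_frequency` epochs and optionally prints
/// them. Histogram computation is currently a placeholder that returns an empty
/// map, since `TrainingState` does not expose parameter tensors.
///
/// Sketch:
///
/// ```ignore
/// let mut hist = HistogramCallback::new(5, 20, true);
/// // drive on_epoch_end from the training loop; `hist.history` then holds one
/// // name -> HistogramStats map per logged epoch.
/// ```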
pub struct HistogramCallback {
    log_frequency: usize,
    #[allow(dead_code)]
    num_bins: usize,
    verbose: bool,
    pub history: Vec<HashMap<String, HistogramStats>>,
}

impl HistogramCallback {
    pub fn new(log_frequency: usize, num_bins: usize, verbose: bool) -> Self {
        Self {
            log_frequency,
            num_bins,
            verbose,
            history: Vec::new(),
        }
    }

    #[allow(dead_code)]
    fn compute_histograms(&self, _state: &TrainingState) -> HashMap<String, HistogramStats> {
        // Placeholder: `TrainingState` does not expose parameter tensors, so no
        // histograms are computed yet; an empty map is recorded instead.
        HashMap::new()
    }
}

impl Callback for HistogramCallback {
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if (epoch + 1).is_multiple_of(self.log_frequency) {
            let histograms = self.compute_histograms(state);

            if self.verbose {
                println!("\n--- Weight Histograms (Epoch {}) ---", epoch + 1);
                for (_name, stats) in histograms.iter() {
                    stats.display(40);
                }
            } else {
                println!(
                    "Epoch {}: Computed histograms for {} parameters",
                    epoch + 1,
                    histograms.len()
                );
            }

            self.history.push(histograms);
        }

        Ok(())
    }
}

#[derive(Debug, Clone, Default)]
pub struct ProfilingStats {
    pub total_time: f64,
    pub epoch_times: Vec<f64>,
    pub samples_per_sec: f64,
    pub batches_per_sec: f64,
    pub avg_batch_time: f64,
    pub peak_memory_mb: f64,
}

impl ProfilingStats {
    pub fn display(&self) {
        println!("\n=== Profiling Statistics ===");
        println!("Total time: {:.2}s", self.total_time);
        println!("Samples/sec: {:.2}", self.samples_per_sec);
        println!("Batches/sec: {:.2}", self.batches_per_sec);
        println!("Avg batch time: {:.4}s", self.avg_batch_time);

        if !self.epoch_times.is_empty() {
            let avg_epoch = self.epoch_times.iter().sum::<f64>() / self.epoch_times.len() as f64;
            let min_epoch = self
                .epoch_times
                .iter()
                .copied()
                .fold(f64::INFINITY, f64::min);
            let max_epoch = self
                .epoch_times
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max);

            println!("\nEpoch times:");
            println!("  Average: {:.2}s", avg_epoch);
            println!("  Min: {:.2}s", min_epoch);
            println!("  Max: {:.2}s", max_epoch);
        }
    }
}

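/// Measures wall-clock time per batch, per epoch, and for the whole run using
/// `std::time::Instant`, aggregating the results into [`ProfilingStats`].
///
/// Sketch:
///
/// ```ignore
/// let mut profiler = ProfilingCallback::new(true, 1);
/// // drive the hooks from the training loop; afterwards:
/// profiler.get_stats().display();
/// ```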
pub struct ProfilingCallback {
    verbose: bool,
    log_frequency: usize,
    start_time: Option<std::time::Instant>,
    epoch_start_time: Option<std::time::Instant>,
    batch_start_time: Option<std::time::Instant>,
    pub stats: ProfilingStats,
    current_epoch_batch_times: Vec<f64>,
    total_batches: usize,
}

impl ProfilingCallback {
    pub fn new(verbose: bool, log_frequency: usize) -> Self {
        Self {
            verbose,
            log_frequency,
            start_time: None,
            epoch_start_time: None,
            batch_start_time: None,
            stats: ProfilingStats::default(),
            current_epoch_batch_times: Vec::new(),
            total_batches: 0,
        }
    }

    pub fn get_stats(&self) -> &ProfilingStats {
        &self.stats
    }
}

impl Callback for ProfilingCallback {
    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
        self.start_time = Some(std::time::Instant::now());
        if self.verbose {
            println!("⏱️ Profiling started");
        }
        Ok(())
    }

    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        if let Some(start) = self.start_time {
            self.stats.total_time = start.elapsed().as_secs_f64();

            // `samples_per_sec` and `peak_memory_mb` are not set by this
            // callback; only timing-derived fields are filled in here.
            if self.total_batches > 0 {
                self.stats.avg_batch_time = self.stats.total_time / self.total_batches as f64;
                self.stats.batches_per_sec = self.total_batches as f64 / self.stats.total_time;
            }

            if self.verbose {
                println!("\n⏱️ Profiling completed");
                self.stats.display();
            }
        }
        Ok(())
    }

    fn on_epoch_begin(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        self.epoch_start_time = Some(std::time::Instant::now());
        self.current_epoch_batch_times.clear();

        if self.verbose && (epoch + 1).is_multiple_of(self.log_frequency) {
            println!("\n⏱️ Epoch {} profiling started", epoch + 1);
        }
        Ok(())
    }

    fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        if let Some(epoch_start) = self.epoch_start_time {
            let epoch_time = epoch_start.elapsed().as_secs_f64();
            self.stats.epoch_times.push(epoch_time);

            if self.verbose && (epoch + 1).is_multiple_of(self.log_frequency) {
                let avg_batch = if !self.current_epoch_batch_times.is_empty() {
                    self.current_epoch_batch_times.iter().sum::<f64>()
                        / self.current_epoch_batch_times.len() as f64
                } else {
                    0.0
                };

                println!("⏱️ Epoch {} completed:", epoch + 1);
                println!("  Time: {:.2}s", epoch_time);
                println!(
                    "  Batches: {} ({:.4}s avg)",
                    self.current_epoch_batch_times.len(),
                    avg_batch
                );
            }
        }
        Ok(())
    }

    fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        self.batch_start_time = Some(std::time::Instant::now());
        Ok(())
    }

    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        if let Some(batch_start) = self.batch_start_time {
            let batch_time = batch_start.elapsed().as_secs_f64();
            self.current_epoch_batch_times.push(batch_time);
            self.total_batches += 1;
        }
        Ok(())
    }
}

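/// Maintains an exponential moving average (EMA) of the model parameters:
/// `shadow = decay * shadow + (1 - decay) * param`, with an optional warmup
/// that uses `min(decay, (1 + n) / (10 + n))` for the n-th update.
///
/// Sketch of intended use (the `parameters` map comes from the caller's model,
/// which is an assumption here):
///
/// ```ignore
/// let mut ema = ModelEMACallback::new(0.999, true);
/// ema.initialize(&parameters);
/// // after each optimizer step:
/// ema.update(&parameters)?;
/// // for evaluation, copy the averaged weights back in:
/// ema.apply_shadow(&mut parameters);
/// ```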
pub struct ModelEMACallback {
    decay: f64,
    shadow_params: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    use_warmup: bool,
    num_updates: usize,
    initialized: bool,
}

impl ModelEMACallback {
    pub fn new(decay: f64, use_warmup: bool) -> Self {
        Self {
            decay,
            shadow_params: HashMap::new(),
            use_warmup,
            num_updates: 0,
            initialized: false,
        }
    }

    pub fn initialize(
        &mut self,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) {
        self.shadow_params.clear();
        for (name, param) in parameters {
            self.shadow_params.insert(name.clone(), param.clone());
        }
        self.initialized = true;
    }

    pub fn update(
        &mut self,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> TrainResult<()> {
        if !self.initialized {
            return Err(TrainError::CallbackError(
                "ModelEMA not initialized. Call initialize() first.".to_string(),
            ));
        }

        self.num_updates += 1;

        let decay = if self.use_warmup {
            // Ramp the decay up during early updates so the average is not
            // dominated by the initial parameter values.
            let warmup_decay = (1.0 + self.num_updates as f64) / (10.0 + self.num_updates as f64);
            warmup_decay.min(self.decay)
        } else {
            self.decay
        };

        for (name, param) in parameters {
            if let Some(shadow) = self.shadow_params.get_mut(name) {
                *shadow = &*shadow * decay + &(param * (1.0 - decay));
            }
        }

        Ok(())
    }

    pub fn get_shadow_params(
        &self,
    ) -> &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
        &self.shadow_params
    }

    pub fn apply_shadow(
        &self,
        parameters: &mut HashMap<
            String,
            scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>,
        >,
    ) {
        for (name, shadow) in &self.shadow_params {
            if let Some(param) = parameters.get_mut(name) {
                *param = shadow.clone();
            }
        }
    }
}

impl Callback for ModelEMACallback {
    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
        // No-op: the trainer must call `initialize()` with the actual parameters.
        Ok(())
    }

    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        // No-op: the trainer must call `update()` after each optimizer step,
        // since `TrainingState` does not carry the parameters.
        Ok(())
    }
}

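/// Accumulates gradients over several micro-batches to simulate a larger
/// effective batch size; the averaged gradients are handed back to the trainer
/// once `should_update()` returns true.
///
/// Sketch of the intended loop (the `gradients` maps and the optimizer step are
/// assumptions about the surrounding trainer):
///
/// ```ignore
/// let mut accum = GradientAccumulationCallback::new(4)?;
/// for gradients in micro_batch_gradients {
///     accum.accumulate(&gradients)?;
///     if accum.should_update() {
///         let averaged = accum.get_and_reset();
///         // apply `averaged` with the optimizer here
///     }
/// }
/// ```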
pub struct GradientAccumulationCallback {
    accumulation_steps: usize,
    current_step: usize,
    accumulated_grads: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    initialized: bool,
}

impl GradientAccumulationCallback {
    pub fn new(accumulation_steps: usize) -> TrainResult<Self> {
        if accumulation_steps == 0 {
            return Err(TrainError::CallbackError(
                "Accumulation steps must be greater than 0".to_string(),
            ));
        }

        Ok(Self {
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
            initialized: false,
        })
    }

    pub fn accumulate(
        &mut self,
        gradients: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> TrainResult<()> {
        if !self.initialized {
            for (name, grad) in gradients {
                self.accumulated_grads.insert(name.clone(), grad.clone());
            }
            self.initialized = true;
        } else {
            for (name, grad) in gradients {
                if let Some(acc_grad) = self.accumulated_grads.get_mut(name) {
                    *acc_grad = &*acc_grad + grad;
                }
            }
        }

        self.current_step += 1;
        Ok(())
    }

    pub fn should_update(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

    pub fn get_and_reset(
        &mut self,
    ) -> HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
        // Average the accumulated gradients, then clear state for the next window.
        let scale = 1.0 / self.accumulation_steps as f64;

        let mut averaged_grads = HashMap::new();
        for (name, grad) in &self.accumulated_grads {
            averaged_grads.insert(name.clone(), grad * scale);
        }

        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();

        averaged_grads
    }
}

impl Callback for GradientAccumulationCallback {
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        // Discard any partial accumulation left over from the previous epoch.
        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        Ok(())
    }
}

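/// Stochastic Weight Averaging (SWA): from `start_epoch` onward, keeps a running
/// average of the model parameters, updated every `update_frequency` epochs via
/// `avg_{n+1} = (avg_n * n + param) / (n + 1)`.
///
/// Sketch of intended use (the `parameters` map is assumed to come from the
/// caller's model):
///
/// ```ignore
/// let mut swa = SWACallback::new(10, 1, true);
/// // inside the epoch loop, after calling swa.on_epoch_end(epoch, &state)?:
/// swa.update_average(&parameters)?;
/// // after training:
/// if swa.is_ready() {
///     swa.apply_swa(&mut parameters);
/// }
/// ```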
pub struct SWACallback {
    start_epoch: usize,
    update_frequency: usize,
    swa_params: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    num_averaged: usize,
    active: bool,
    initialized: bool,
    verbose: bool,
}

impl SWACallback {
    pub fn new(start_epoch: usize, update_frequency: usize, verbose: bool) -> Self {
        Self {
            start_epoch,
            update_frequency,
            swa_params: HashMap::new(),
            num_averaged: 0,
            active: false,
            initialized: false,
            verbose,
        }
    }

    pub fn update_average(
        &mut self,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> TrainResult<()> {
        if !self.active {
            return Ok(());
        }

        if !self.initialized {
            for (name, param) in parameters {
                self.swa_params.insert(name.clone(), param.clone());
            }
            self.initialized = true;
            self.num_averaged = 1;

            if self.verbose {
                println!("📊 SWA: Initialized with model parameters");
            }
        } else {
            // Running mean: avg_{n+1} = (avg_n * n + param) / (n + 1).
            let n = self.num_averaged as f64;
            for (name, param) in parameters {
                if let Some(swa_param) = self.swa_params.get_mut(name) {
                    *swa_param = &(&*swa_param * n + param) / (n + 1.0);
                }
            }
            self.num_averaged += 1;

            if self.verbose {
                println!("📊 SWA: Updated average (n={})", self.num_averaged);
            }
        }

        Ok(())
    }

    pub fn get_swa_params(
        &self,
    ) -> &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
        &self.swa_params
    }

    pub fn apply_swa(
        &self,
        parameters: &mut HashMap<
            String,
            scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>,
        >,
    ) {
        if self.initialized {
            for (name, swa_param) in &self.swa_params {
                if let Some(param) = parameters.get_mut(name) {
                    *param = swa_param.clone();
                }
            }
        }
    }

    pub fn is_ready(&self) -> bool {
        self.initialized && self.num_averaged > 0
    }
}

impl Callback for SWACallback {
    fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        if epoch >= self.start_epoch && !self.active {
            self.active = true;
            if self.verbose {
                println!("\n📊 SWA: Activated at epoch {}", epoch + 1);
            }
        }

        if self.active && epoch >= self.start_epoch {
            let relative_epoch = epoch - self.start_epoch;
            if relative_epoch.is_multiple_of(self.update_frequency)
                && self.verbose
                && self.initialized
            {
                println!(
                    "📊 SWA: Ready to update at epoch {} (call update_average with parameters)",
                    epoch + 1
                );
            }
        }

        Ok(())
    }

    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        if self.verbose && self.initialized {
            println!(
                "\n📊 SWA: Training complete. Averaged {} models.",
                self.num_averaged
            );
            println!("📊 SWA: Call apply_swa() to use averaged parameters.");
        }
        Ok(())
    }
}

#[cfg(test)]
mod profiling_tests {
    use super::*;

    #[test]
    fn test_histogram_stats() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let stats = HistogramStats::compute("test", &values, 5);

        assert_eq!(stats.name, "test");
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 10.0);
        assert!((stats.mean - 5.5).abs() < 1e-6);
        assert_eq!(stats.bins.len(), 6);
        assert_eq!(stats.counts.len(), 5);
        assert_eq!(stats.counts.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_histogram_callback() {
        use std::collections::HashMap;
        let mut callback = HistogramCallback::new(2, 10, false);
        let state = TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 0.5,
            batch_loss: 0.5,
            val_loss: Some(0.6),
            learning_rate: 0.01,
            metrics: HashMap::new(),
        };

        callback.on_epoch_end(0, &state).unwrap();
        assert_eq!(callback.history.len(), 0);

        callback.on_epoch_end(1, &state).unwrap();
        assert_eq!(callback.history.len(), 1);
    }

    #[test]
    fn test_profiling_callback() {
        use std::collections::HashMap;
        let mut callback = ProfilingCallback::new(false, 1);
        let state = TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 0.5,
            batch_loss: 0.5,
            val_loss: Some(0.6),
            learning_rate: 0.01,
            metrics: HashMap::new(),
        };

        callback.on_train_begin(&state).unwrap();
        assert!(callback.start_time.is_some());

        callback.on_epoch_begin(0, &state).unwrap();
        assert!(callback.epoch_start_time.is_some());

        callback.on_batch_begin(0, &state).unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        callback.on_batch_end(0, &state).unwrap();

        assert_eq!(callback.total_batches, 1);
        assert_eq!(callback.current_epoch_batch_times.len(), 1);

        callback.on_epoch_end(0, &state).unwrap();
        assert_eq!(callback.stats.epoch_times.len(), 1);

        callback.on_train_end(&state).unwrap();
        assert!(callback.stats.total_time > 0.0);
    }
}