use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;

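/// Callback that tracks per-batch gradient statistics (norm, mean, standard
/// deviation) and counts vanishing/exploding gradient warnings.
///
/// Statistics are printed every `log_frequency` batches, and immediately
/// whenever the gradient norm falls below `vanishing_threshold` or rises
/// above `exploding_threshold`.
///
/// A minimal construction sketch (how the callback is registered with a
/// training loop is crate-specific and not shown here):
///
/// ```ignore
/// // Log every 50 batches; warn below 1e-7 or above 1e3.
/// let monitor = GradientMonitor::new(50, 1e-7, 1e3);
/// ```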
pub struct GradientMonitor {
    /// Number of batches between periodic stat printouts.
    log_frequency: usize,
    /// Norm below which a gradient is considered vanishing.
    vanishing_threshold: f64,
    /// Norm above which a gradient is considered exploding.
    exploding_threshold: f64,
    /// Gradient norm recorded for each batch.
    pub gradient_norms: Vec<f64>,
    /// Gradient mean recorded for each batch.
    pub gradient_means: Vec<f64>,
    /// Gradient standard deviation recorded for each batch.
    pub gradient_stds: Vec<f64>,
    /// Number of batches flagged as vanishing.
    pub vanishing_count: usize,
    /// Number of batches flagged as exploding.
    pub exploding_count: usize,
    /// Total number of batches observed.
    batch_counter: usize,
}

impl GradientMonitor {
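    /// Creates a new monitor.
    ///
    /// `log_frequency` is the number of batches between periodic printouts;
    /// `vanishing_threshold` and `exploding_threshold` are the gradient-norm
    /// bounds below/above which a warning is counted.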
    pub fn new(log_frequency: usize, vanishing_threshold: f64, exploding_threshold: f64) -> Self {
        Self {
            log_frequency,
            vanishing_threshold,
            exploding_threshold,
            gradient_norms: Vec::new(),
            gradient_means: Vec::new(),
            gradient_stds: Vec::new(),
            vanishing_count: 0,
            exploding_count: 0,
            batch_counter: 0,
        }
    }

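    /// Computes `(norm, mean, std)` for the current batch's gradients.
    ///
    /// Currently returns fixed placeholder values; the training state is not
    /// consulted.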
    fn compute_gradient_stats(&mut self, _state: &TrainingState) -> (f64, f64, f64) {
        (1.0, 0.0, 0.1)
    }

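    /// Returns `true` and bumps the warning counter if `norm` is below the
    /// vanishing threshold.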
    fn check_vanishing(&mut self, norm: f64) -> bool {
        if norm < self.vanishing_threshold {
            self.vanishing_count += 1;
            return true;
        }
        false
    }

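    /// Returns `true` and bumps the warning counter if `norm` is above the
    /// exploding threshold.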
    fn check_exploding(&mut self, norm: f64) -> bool {
        if norm > self.exploding_threshold {
            self.exploding_count += 1;
            return true;
        }
        false
    }

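    /// Prints the current batch's gradient statistics and any accumulated
    /// vanishing/exploding warning counts.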
    fn print_stats(&self, norm: f64, mean: f64, std: f64) {
        println!("Gradient Stats [Batch {}]:", self.batch_counter);
        println!("  Norm: {:.6e}, Mean: {:.6e}, Std: {:.6e}", norm, mean, std);

        if self.vanishing_count > 0 {
            println!(
                "  Warning: Vanishing gradient warnings: {}",
                self.vanishing_count
            );
        }

        if self.exploding_count > 0 {
            println!(
                "  Warning: Exploding gradient warnings: {}",
                self.exploding_count
            );
        }
    }

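    /// Summarizes the statistics collected so far.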
    pub fn summary(&self) -> GradientSummary {
        let avg_norm = if !self.gradient_norms.is_empty() {
            self.gradient_norms.iter().sum::<f64>() / self.gradient_norms.len() as f64
        } else {
            0.0
        };

        GradientSummary {
            total_batches: self.batch_counter,
            average_norm: avg_norm,
            vanishing_count: self.vanishing_count,
            exploding_count: self.exploding_count,
        }
    }
}

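/// Aggregate statistics produced by [`GradientMonitor::summary`].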
#[derive(Debug, Clone)]
pub struct GradientSummary {
    pub total_batches: usize,
    pub average_norm: f64,
    pub vanishing_count: usize,
    pub exploding_count: usize,
}

impl Callback for GradientMonitor {
    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
        self.batch_counter += 1;

        let (norm, mean, std) = self.compute_gradient_stats(state);

        self.gradient_norms.push(norm);
        self.gradient_means.push(mean);
        self.gradient_stds.push(std);

        let vanishing = self.check_vanishing(norm);
        let exploding = self.check_exploding(norm);

        // Log on the regular schedule, and immediately on any warning.
        if self.batch_counter.is_multiple_of(self.log_frequency) || vanishing || exploding {
            self.print_stats(norm, mean, std);
        }

        Ok(())
    }

    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        let summary = self.summary();
        println!("\n=== Gradient Monitoring Summary ===");
        println!("Total batches: {}", summary.total_batches);
        println!("Average gradient norm: {:.6e}", summary.average_norm);
        println!("Vanishing gradient warnings: {}", summary.vanishing_count);
        println!("Exploding gradient warnings: {}", summary.exploding_count);
        println!("====================================\n");
        Ok(())
    }
}

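/// How accumulated gradients are rescaled when they are retrieved with
/// [`GradientAccumulationCallback::get_and_reset`].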
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GradientScalingStrategy {
    /// Divide accumulated gradients by the configured number of accumulation steps.
    Average,
    /// Use the raw accumulated sum without rescaling.
    Sum,
    /// Divide by the number of steps actually accumulated so far.
    Dynamic,
}

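/// Callback that accumulates gradients over several micro-batches so a larger
/// effective batch size can be simulated before each optimizer update.
///
/// Incoming gradients are checked for NaN/Inf overflow, optionally clipped by
/// global norm, summed per parameter, and rescaled according to the configured
/// [`GradientScalingStrategy`] when [`get_and_reset`](Self::get_and_reset) is
/// called.
///
/// A minimal usage sketch (the training loop, `compute_gradients` helper, and
/// `optimizer` are assumptions, not part of this module):
///
/// ```ignore
/// let mut accum = GradientAccumulationCallback::new(4)?.with_grad_clipping(5.0);
///
/// for batch in batches {
///     let grads = compute_gradients(&model, &batch); // hypothetical helper
///     accum.accumulate(&grads)?;
///     if accum.should_update() {
///         let effective_grads = accum.get_and_reset();
///         optimizer.step(&effective_grads); // hypothetical optimizer call
///     }
/// }
/// ```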
pub struct GradientAccumulationCallback {
    /// Number of micro-batches to accumulate before an optimizer update.
    accumulation_steps: usize,
    /// Number of micro-batches accumulated in the current cycle.
    current_step: usize,
    /// Running sum of gradients, keyed by parameter name.
    accumulated_grads: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    /// Whether `accumulated_grads` has been seeded for the current cycle.
    initialized: bool,
    /// How the accumulated sum is rescaled when retrieved.
    scaling_strategy: GradientScalingStrategy,
    /// Largest per-call gradient norm seen in the current cycle.
    max_grad_norm: f64,
    /// Whether a NaN/Inf gradient was encountered in the current cycle.
    overflow_detected: bool,
    /// Number of completed accumulate-and-reset cycles.
    total_cycles: usize,
    /// Optional global-norm clipping threshold applied per call.
    clip_grad_norm: Option<f64>,
}

impl GradientAccumulationCallback {
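    /// Creates an accumulator that averages gradients over `accumulation_steps`
    /// micro-batches ([`GradientScalingStrategy::Average`]).
    ///
    /// Returns an error if `accumulation_steps` is zero.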
    pub fn new(accumulation_steps: usize) -> TrainResult<Self> {
        Self::with_strategy(accumulation_steps, GradientScalingStrategy::Average)
    }

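    /// Creates an accumulator with an explicit scaling strategy.
    ///
    /// Returns an error if `accumulation_steps` is zero.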
    pub fn with_strategy(
        accumulation_steps: usize,
        scaling_strategy: GradientScalingStrategy,
    ) -> TrainResult<Self> {
        if accumulation_steps == 0 {
            return Err(TrainError::CallbackError(
                "Accumulation steps must be greater than 0".to_string(),
            ));
        }

        Ok(Self {
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
            initialized: false,
            scaling_strategy,
            max_grad_norm: 0.0,
            overflow_detected: false,
            total_cycles: 0,
            clip_grad_norm: None,
        })
    }

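    /// Enables global-norm gradient clipping with the given threshold, applied
    /// to each incoming batch of gradients before accumulation.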
    pub fn with_grad_clipping(mut self, max_norm: f64) -> Self {
        self.clip_grad_norm = Some(max_norm);
        self
    }

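    /// Adds one micro-batch of gradients to the running accumulation.
    ///
    /// Returns an error (and sets the overflow flag) if any gradient contains
    /// NaN or infinite values. If clipping is enabled, the batch is scaled down
    /// so its global norm does not exceed the configured maximum.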
    pub fn accumulate(
        &mut self,
        gradients: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> TrainResult<()> {
        // Reject the whole micro-batch if any gradient contains NaN or Inf.
        for grad in gradients.values() {
            if grad.iter().any(|&x| x.is_nan() || x.is_infinite()) {
                self.overflow_detected = true;
                return Err(TrainError::CallbackError(
                    "Gradient overflow detected (NaN or Inf)".to_string(),
                ));
            }
        }

        let grad_norm = self.compute_total_norm(gradients);
        self.max_grad_norm = self.max_grad_norm.max(grad_norm);

        for (name, grad) in gradients {
            // Clip by global norm if a threshold is configured and exceeded.
            let clipped_grad = match self.clip_grad_norm {
                Some(max_norm) if grad_norm > max_norm => grad * (max_norm / grad_norm),
                _ => grad.clone(),
            };

            if self.initialized {
                if let Some(acc_grad) = self.accumulated_grads.get_mut(name) {
                    *acc_grad = &*acc_grad + &clipped_grad;
                }
            } else {
                self.accumulated_grads.insert(name.clone(), clipped_grad);
            }
        }
        self.initialized = true;

        self.current_step += 1;
        Ok(())
    }

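    /// Computes the global L2 norm over all gradients in the map.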
    fn compute_total_norm(
        &self,
        gradients: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> f64 {
        let mut total_norm_sq = 0.0;
        for grad in gradients.values() {
            total_norm_sq += grad.iter().map(|&x| x * x).sum::<f64>();
        }
        total_norm_sq.sqrt()
    }

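    /// Returns `true` once enough micro-batches have been accumulated for an
    /// optimizer update.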
    pub fn should_update(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

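    /// Returns the accumulated gradients rescaled according to the scaling
    /// strategy, then clears all per-cycle state for the next cycle.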
    pub fn get_and_reset(
        &mut self,
    ) -> HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
        let scale = match self.scaling_strategy {
            GradientScalingStrategy::Average => 1.0 / self.accumulation_steps as f64,
            GradientScalingStrategy::Sum => 1.0,
            GradientScalingStrategy::Dynamic => 1.0 / self.current_step.max(1) as f64,
        };

        let mut scaled_grads = HashMap::new();
        for (name, grad) in &self.accumulated_grads {
            scaled_grads.insert(name.clone(), grad * scale);
        }

        self.total_cycles += 1;

        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        self.max_grad_norm = 0.0;
        self.overflow_detected = false;

        scaled_grads
    }

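    /// Returns a snapshot of the accumulator's current statistics.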
    pub fn get_stats(&self) -> GradientAccumulationStats {
        let memory_usage = self.estimate_memory_usage();

        GradientAccumulationStats {
            accumulation_steps: self.accumulation_steps,
            current_step: self.current_step,
            total_cycles: self.total_cycles,
            max_grad_norm: self.max_grad_norm,
            overflow_detected: self.overflow_detected,
            num_parameters: self.accumulated_grads.len(),
            memory_usage_mb: memory_usage,
        }
    }

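    /// Estimates the memory held by the accumulated gradients in megabytes,
    /// assuming 8 bytes per `f64` element.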
    fn estimate_memory_usage(&self) -> f64 {
        let mut total_elements = 0usize;
        for grad in self.accumulated_grads.values() {
            total_elements += grad.len();
        }
        (total_elements * 8) as f64 / (1024.0 * 1024.0)
    }

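    /// Discards any partially accumulated gradients and clears per-cycle state.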
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        self.max_grad_norm = 0.0;
        self.overflow_detected = false;
    }
}

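/// Snapshot of [`GradientAccumulationCallback`] state returned by
/// [`GradientAccumulationCallback::get_stats`].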
#[derive(Debug, Clone)]
pub struct GradientAccumulationStats {
    pub accumulation_steps: usize,
    pub current_step: usize,
    pub total_cycles: usize,
    pub max_grad_norm: f64,
    pub overflow_detected: bool,
    pub num_parameters: usize,
    pub memory_usage_mb: f64,
}

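// Resets any partially accumulated gradients at the start of each epoch so
// stale micro-batches never leak across epoch boundaries.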
impl Callback for GradientAccumulationCallback {
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    fn create_test_gradients() -> HashMap<String, Array2<f64>> {
        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
        );
        grads.insert(
            "layer2".to_string(),
            Array2::from_shape_vec((2, 2), vec![0.5, 1.0, 1.5, 2.0]).unwrap(),
        );
        grads
    }

    #[test]
    fn test_gradient_accumulation_average_strategy() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        assert_eq!(accum.current_step, 1);
        assert!(!accum.should_update());

        accum.accumulate(&grads).unwrap();
        assert_eq!(accum.current_step, 2);
        assert!(accum.should_update());

        let averaged = accum.get_and_reset();
        let layer1 = averaged.get("layer1").unwrap();

        // Two identical micro-batches averaged over 2 steps: values unchanged.
        assert_eq!(layer1[[0, 0]], 1.0);
        assert_eq!(layer1[[0, 1]], 2.0);
        // The accumulator resets after retrieval.
        assert_eq!(accum.current_step, 0);
    }

    #[test]
    fn test_gradient_accumulation_sum_strategy() {
        let mut accum =
            GradientAccumulationCallback::with_strategy(2, GradientScalingStrategy::Sum).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();

        let summed = accum.get_and_reset();
        let layer1 = summed.get("layer1").unwrap();

        // Sum strategy keeps the raw accumulated sum of the two batches.
        assert_eq!(layer1[[0, 0]], 2.0);
        assert_eq!(layer1[[0, 1]], 4.0);
    }

    #[test]
    fn test_gradient_accumulation_dynamic_strategy() {
        let mut accum =
            GradientAccumulationCallback::with_strategy(4, GradientScalingStrategy::Dynamic)
                .unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();

        let scaled = accum.get_and_reset();
        let layer1 = scaled.get("layer1").unwrap();

        // Dynamic scaling divides by the 3 steps actually accumulated: 3.0 / 3 = 1.0.
        assert_eq!(layer1[[0, 0]], 1.0);
    }

    #[test]
    fn test_gradient_clipping_during_accumulation() {
        let mut accum = GradientAccumulationCallback::new(2)
            .unwrap()
            .with_grad_clipping(1.0);

        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![10.0, 10.0, 10.0, 10.0]).unwrap(),
        );

        accum.accumulate(&grads).unwrap();
        assert!(accum.max_grad_norm > 0.0);

        // The accumulated gradient should have been clipped to the max norm of 1.0.
        let accumulated = &accum.accumulated_grads["layer1"];
        let norm_sq: f64 = accumulated.iter().map(|&x| x * x).sum();
        let norm = norm_sq.sqrt();

        assert!(norm <= 1.1);
    }

    #[test]
    fn test_overflow_detection() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();

        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![f64::NAN, 1.0, 2.0, 3.0]).unwrap(),
        );

        let result = accum.accumulate(&grads);
        assert!(result.is_err());
        assert!(accum.overflow_detected);
    }

    #[test]
    fn test_gradient_accumulation_stats() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();
        accum.get_and_reset();

        let stats = accum.get_stats();
        assert_eq!(stats.accumulation_steps, 2);
        assert_eq!(stats.total_cycles, 1);
        assert!(!stats.overflow_detected);
    }

    #[test]
    fn test_memory_usage_estimation() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();

        let stats = accum.get_stats();
        assert!(stats.memory_usage_mb > 0.0);
        assert_eq!(stats.num_parameters, 2);
    }

    #[test]
    fn test_gradient_accumulation_reset() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        assert_eq!(accum.current_step, 1);

        accum.reset();
        assert_eq!(accum.current_step, 0);
        assert!(!accum.initialized);
        assert_eq!(accum.accumulated_grads.len(), 0);
    }

    #[test]
    fn test_gradient_accumulation_zero_steps_error() {
        let result = GradientAccumulationCallback::new(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_gradient_accumulation_multiple_cycles() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();
        accum.get_and_reset();

        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();
        accum.get_and_reset();

        let stats = accum.get_stats();
        assert_eq!(stats.total_cycles, 2);
    }

    #[test]
    fn test_different_gradient_shapes() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();

        let mut grads1 = HashMap::new();
        grads1.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
        );

        let mut grads2 = HashMap::new();
        grads2.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 3), vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0]).unwrap(),
        );

        accum.accumulate(&grads1).unwrap();
        accum.accumulate(&grads2).unwrap();

        let averaged = accum.get_and_reset();
        let layer1 = averaged.get("layer1").unwrap();

        assert_eq!(layer1.dim(), (2, 3));
        // (1.0 + 0.5) / 2 = 0.75 for the first element.
        assert_eq!(layer1[[0, 0]], 0.75);
    }
}