tensorlogic_train/callbacks/gradient.rs

//! Gradient monitoring and accumulation callbacks.

use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;

/// Gradient flow monitor for tracking gradient statistics during training.
///
/// This callback tracks the gradient norm, mean, and standard deviation, and flags vanishing/exploding gradients.
/// Useful for debugging training issues and understanding gradient flow through the network.
///
/// # Example
/// ```rust,ignore
/// use tensorlogic_train::{GradientMonitor, CallbackList};
///
/// let mut callbacks = CallbackList::new();
/// callbacks.add(Box::new(GradientMonitor::new(
///     10,      // log_frequency
///     1e-7,    // vanishing_threshold
///     100.0,   // exploding_threshold
/// )));
/// ```
pub struct GradientMonitor {
    /// Frequency of logging (every N batches).
    log_frequency: usize,
    /// Threshold for detecting vanishing gradients.
    vanishing_threshold: f64,
    /// Threshold for detecting exploding gradients.
    exploding_threshold: f64,
    /// History of gradient norms.
    pub gradient_norms: Vec<f64>,
    /// History of gradient means.
    pub gradient_means: Vec<f64>,
    /// History of gradient stds.
    pub gradient_stds: Vec<f64>,
    /// Count of vanishing gradient warnings.
    pub vanishing_count: usize,
    /// Count of exploding gradient warnings.
    pub exploding_count: usize,
    /// Current batch counter.
    batch_counter: usize,
}

impl GradientMonitor {
    /// Create a new gradient monitor.
    ///
    /// # Arguments
    /// * `log_frequency` - Log statistics every N batches
    /// * `vanishing_threshold` - Threshold below which gradients are considered vanishing
    /// * `exploding_threshold` - Threshold above which gradients are considered exploding
    pub fn new(log_frequency: usize, vanishing_threshold: f64, exploding_threshold: f64) -> Self {
        Self {
            log_frequency,
            vanishing_threshold,
            exploding_threshold,
            gradient_norms: Vec::new(),
            gradient_means: Vec::new(),
            gradient_stds: Vec::new(),
            vanishing_count: 0,
            exploding_count: 0,
            batch_counter: 0,
        }
    }

    /// Compute gradient statistics (placeholder - actual implementation needs gradient access).
    fn compute_gradient_stats(&mut self, _state: &TrainingState) -> (f64, f64, f64) {
        // In a real implementation, this would access actual gradients
        // For now, return placeholder values
        // (norm, mean, std)
        (1.0, 0.0, 0.1)
    }
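
    /// Minimal sketch of how `compute_gradient_stats` could look once gradients are
    /// exposed to callbacks, e.g. as a map from parameter name to `Array2<f64>`.
    /// Hypothetical and unused; the real `TrainingState` gradient accessor may differ.
    #[allow(dead_code)]
    fn compute_stats_from_grads(
        grads: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> (f64, f64, f64) {
        // Flatten all gradient elements into one list.
        let values: Vec<f64> = grads.values().flat_map(|g| g.iter().copied()).collect();
        if values.is_empty() {
            return (0.0, 0.0, 0.0);
        }
        let n = values.len() as f64;
        // Global L2 norm, mean, and standard deviation over all elements.
        let norm = values.iter().map(|x| x * x).sum::<f64>().sqrt();
        let mean = values.iter().sum::<f64>() / n;
        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        (norm, mean, variance.sqrt())
    }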

    /// Check for vanishing gradients.
    fn check_vanishing(&mut self, norm: f64) -> bool {
        if norm < self.vanishing_threshold {
            self.vanishing_count += 1;
            return true;
        }
        false
    }

    /// Check for exploding gradients.
    fn check_exploding(&mut self, norm: f64) -> bool {
        if norm > self.exploding_threshold {
            self.exploding_count += 1;
            return true;
        }
        false
    }

    /// Print gradient statistics.
    fn print_stats(&self, norm: f64, mean: f64, std: f64) {
        println!("Gradient Stats [Batch {}]:", self.batch_counter);
        println!("  Norm: {:.6e}, Mean: {:.6e}, Std: {:.6e}", norm, mean, std);

        if self.vanishing_count > 0 {
            println!(
                "  Warning: Vanishing gradient warnings: {}",
                self.vanishing_count
            );
        }

        if self.exploding_count > 0 {
            println!(
                "  Warning: Exploding gradient warnings: {}",
                self.exploding_count
            );
        }
    }

    /// Get summary statistics.
    pub fn summary(&self) -> GradientSummary {
        let avg_norm = if !self.gradient_norms.is_empty() {
            self.gradient_norms.iter().sum::<f64>() / self.gradient_norms.len() as f64
        } else {
            0.0
        };

        GradientSummary {
            total_batches: self.batch_counter,
            average_norm: avg_norm,
            vanishing_count: self.vanishing_count,
            exploding_count: self.exploding_count,
        }
    }
}

/// Summary of gradient statistics.
#[derive(Debug, Clone)]
pub struct GradientSummary {
    /// Total number of batches monitored.
    pub total_batches: usize,
    /// Average gradient norm.
    pub average_norm: f64,
    /// Number of vanishing gradient warnings.
    pub vanishing_count: usize,
    /// Number of exploding gradient warnings.
    pub exploding_count: usize,
}

impl Callback for GradientMonitor {
    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
        self.batch_counter += 1;

        // Compute gradient statistics
        let (norm, mean, std) = self.compute_gradient_stats(state);

        // Record statistics
        self.gradient_norms.push(norm);
        self.gradient_means.push(mean);
        self.gradient_stds.push(std);

        // Check for issues
        let vanishing = self.check_vanishing(norm);
        let exploding = self.check_exploding(norm);

        // Log if needed
        if self.batch_counter.is_multiple_of(self.log_frequency) {
            self.print_stats(norm, mean, std);
        } else if vanishing || exploding {
            // Always log warnings immediately
            self.print_stats(norm, mean, std);
        }

        Ok(())
    }

    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        let summary = self.summary();
        println!("\n=== Gradient Monitoring Summary ===");
        println!("Total batches: {}", summary.total_batches);
        println!("Average gradient norm: {:.6e}", summary.average_norm);
        println!("Vanishing gradient warnings: {}", summary.vanishing_count);
        println!("Exploding gradient warnings: {}", summary.exploding_count);
        println!("====================================\n");
        Ok(())
    }
}

/// Gradient scaling strategy for accumulation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GradientScalingStrategy {
    /// Divide by accumulation steps (default, maintains gradient magnitude)
    Average,
    /// Sum gradients without scaling (useful for some optimizers)
    Sum,
    /// Divide by the number of mini-batches actually accumulated (may be fewer than configured)
    Dynamic,
}
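
// Illustration of the scale applied by `get_and_reset` under each strategy,
// assuming accumulation_steps = 4 and, for `Dynamic`, only 3 mini-batches
// actually accumulated before the update:
//   Average -> 1.0 / 4 = 0.25
//   Sum     -> 1.0
//   Dynamic -> 1.0 / 3 ≈ 0.333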

/// Gradient Accumulation callback with advanced features.
///
/// Simulates larger batch sizes by accumulating gradients over multiple
/// mini-batches before updating parameters. This is useful when GPU memory
/// is limited but you want to train with effectively larger batches.
///
/// Effective batch size = mini_batch_size * accumulation_steps
///
/// # Features
/// - Memory-efficient in-place accumulation
/// - Multiple scaling strategies
/// - Gradient overflow detection
/// - Memory usage tracking
/// - Automatic gradient zeroing
///
/// # Example
/// ```rust,ignore
/// use tensorlogic_train::{GradientAccumulationCallback, GradientScalingStrategy};
///
/// let mut grad_accum = GradientAccumulationCallback::with_strategy(
///     4, // accumulate over 4 mini-batches
///     GradientScalingStrategy::Average,
/// ).unwrap();
/// ```
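///
/// Typical driving loop (sketch; `compute_gradients` and `optimizer.step` are
/// placeholders for whatever backward pass and optimizer API the caller uses):
/// ```rust,ignore
/// for batch in batches {
///     let grads = compute_gradients(&model, &batch); // hypothetical backward pass
///     grad_accum.accumulate(&grads)?;
///     if grad_accum.should_update() {
///         // Scaled per the chosen strategy (here: averaged over the 4 mini-batches)
///         let scaled = grad_accum.get_and_reset();
///         optimizer.step(&scaled); // hypothetical optimizer call
///     }
/// }
/// ```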
pub struct GradientAccumulationCallback {
    /// Number of steps to accumulate gradients before updating.
    accumulation_steps: usize,
    /// Current accumulation counter.
    current_step: usize,
    /// Accumulated gradients.
    accumulated_grads: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    /// Whether gradients are initialized.
    initialized: bool,
    /// Gradient scaling strategy.
    scaling_strategy: GradientScalingStrategy,
    /// Track maximum gradient norm seen during accumulation.
    max_grad_norm: f64,
    /// Track if overflow was detected.
    overflow_detected: bool,
    /// Total number of accumulation cycles completed.
    total_cycles: usize,
    /// Enable gradient clipping during accumulation.
    clip_grad_norm: Option<f64>,
}

impl GradientAccumulationCallback {
    /// Create a new Gradient Accumulation callback with default average scaling.
    ///
    /// # Arguments
    /// * `accumulation_steps` - Number of mini-batches to accumulate (e.g., 4, 8, 16)
    pub fn new(accumulation_steps: usize) -> TrainResult<Self> {
        Self::with_strategy(accumulation_steps, GradientScalingStrategy::Average)
    }

    /// Create a new Gradient Accumulation callback with specified scaling strategy.
    ///
    /// # Arguments
    /// * `accumulation_steps` - Number of mini-batches to accumulate
    /// * `scaling_strategy` - How to scale accumulated gradients
    pub fn with_strategy(
        accumulation_steps: usize,
        scaling_strategy: GradientScalingStrategy,
    ) -> TrainResult<Self> {
        if accumulation_steps == 0 {
            return Err(TrainError::CallbackError(
                "Accumulation steps must be greater than 0".to_string(),
            ));
        }

        Ok(Self {
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
            initialized: false,
            scaling_strategy,
            max_grad_norm: 0.0,
            overflow_detected: false,
            total_cycles: 0,
            clip_grad_norm: None,
        })
    }

    /// Enable gradient clipping during accumulation.
    ///
    /// # Arguments
    /// * `max_norm` - Maximum gradient norm before clipping
    pub fn with_grad_clipping(mut self, max_norm: f64) -> Self {
        self.clip_grad_norm = Some(max_norm);
        self
    }

    /// Accumulate gradients with optional clipping and overflow detection.
    pub fn accumulate(
        &mut self,
        gradients: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> TrainResult<()> {
        // Check for NaN/Inf before accumulation
        for grad in gradients.values() {
            if grad.iter().any(|&x| x.is_nan() || x.is_infinite()) {
                self.overflow_detected = true;
                return Err(TrainError::CallbackError(
                    "Gradient overflow detected (NaN or Inf)".to_string(),
                ));
            }
        }

        // Compute gradient norm for monitoring
        let grad_norm = self.compute_total_norm(gradients);
        self.max_grad_norm = self.max_grad_norm.max(grad_norm);

        if !self.initialized {
            // Initialize accumulation buffers on the first call (clipping if configured)
            for (name, grad) in gradients {
                let clipped_grad = if let Some(max_norm) = self.clip_grad_norm {
                    if grad_norm > max_norm {
                        let scale = max_norm / grad_norm;
                        grad * scale
                    } else {
                        grad.clone()
                    }
                } else {
                    grad.clone()
                };
                self.accumulated_grads.insert(name.clone(), clipped_grad);
            }
            self.initialized = true;
        } else {
            // In-place accumulation for memory efficiency
            for (name, grad) in gradients {
                if let Some(acc_grad) = self.accumulated_grads.get_mut(name) {
                    let grad_to_add = if let Some(max_norm) = self.clip_grad_norm {
                        if grad_norm > max_norm {
                            let scale = max_norm / grad_norm;
                            grad * scale
                        } else {
                            grad.clone()
                        }
                    } else {
                        grad.clone()
                    };

                    // In-place addition
                    *acc_grad += &grad_to_add;
                }
            }
        }

        self.current_step += 1;
        Ok(())
    }

    /// Compute the total L2 norm of all gradients.
    fn compute_total_norm(
        &self,
        gradients: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> f64 {
        let mut total_norm_sq = 0.0;
        for grad in gradients.values() {
            total_norm_sq += grad.iter().map(|&x| x * x).sum::<f64>();
        }
        total_norm_sq.sqrt()
    }

    /// Check if we should perform an optimizer step.
    pub fn should_update(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

    /// Get scaled accumulated gradients and reset state.
    pub fn get_and_reset(
        &mut self,
    ) -> HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
        let scale = match self.scaling_strategy {
            GradientScalingStrategy::Average => 1.0 / self.accumulation_steps as f64,
            GradientScalingStrategy::Sum => 1.0,
            GradientScalingStrategy::Dynamic => {
                // Dynamic scaling based on actual steps accumulated
                1.0 / self.current_step.max(1) as f64
            }
        };

        let mut scaled_grads = HashMap::new();
        for (name, grad) in &self.accumulated_grads {
            scaled_grads.insert(name.clone(), grad * scale);
        }

        // Update statistics
        self.total_cycles += 1;

        // Reset state
        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        self.max_grad_norm = 0.0;
        self.overflow_detected = false;

        scaled_grads
    }

    /// Get statistics about gradient accumulation.
    pub fn get_stats(&self) -> GradientAccumulationStats {
        let memory_usage = self.estimate_memory_usage();

        GradientAccumulationStats {
            accumulation_steps: self.accumulation_steps,
            current_step: self.current_step,
            total_cycles: self.total_cycles,
            max_grad_norm: self.max_grad_norm,
            overflow_detected: self.overflow_detected,
            num_parameters: self.accumulated_grads.len(),
            memory_usage_mb: memory_usage,
        }
    }

    /// Estimate memory usage of accumulated gradients in MB.
    fn estimate_memory_usage(&self) -> f64 {
        let mut total_elements = 0usize;
        for grad in self.accumulated_grads.values() {
            total_elements += grad.len();
        }
        // f64 = 8 bytes
        (total_elements * 8) as f64 / (1024.0 * 1024.0)
    }

    /// Reset all state without returning gradients (useful for error recovery).
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        self.max_grad_norm = 0.0;
        self.overflow_detected = false;
    }
}

/// Statistics for gradient accumulation.
#[derive(Debug, Clone)]
pub struct GradientAccumulationStats {
    /// Configured accumulation steps.
    pub accumulation_steps: usize,
    /// Current step in accumulation.
    pub current_step: usize,
    /// Total completed cycles.
    pub total_cycles: usize,
    /// Maximum gradient norm seen.
    pub max_grad_norm: f64,
    /// Whether overflow was detected.
    pub overflow_detected: bool,
    /// Number of parameters being accumulated.
    pub num_parameters: usize,
    /// Estimated memory usage in MB.
    pub memory_usage_mb: f64,
}

impl Callback for GradientAccumulationCallback {
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        // Reset all accumulation state at the beginning of each epoch
        self.reset();
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    fn create_test_gradients() -> HashMap<String, Array2<f64>> {
        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
        );
        grads.insert(
            "layer2".to_string(),
            Array2::from_shape_vec((2, 2), vec![0.5, 1.0, 1.5, 2.0]).unwrap(),
        );
        grads
    }

    #[test]
    fn test_gradient_accumulation_average_strategy() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        // First accumulation
        accum.accumulate(&grads).unwrap();
        assert_eq!(accum.current_step, 1);
        assert!(!accum.should_update());

        // Second accumulation
        accum.accumulate(&grads).unwrap();
        assert_eq!(accum.current_step, 2);
        assert!(accum.should_update());

        // Get averaged gradients
        let averaged = accum.get_and_reset();
        let layer1 = averaged.get("layer1").unwrap();

        // Should be average of 2 accumulations (same gradient twice)
        assert_eq!(layer1[[0, 0]], 1.0); // (1.0 + 1.0) / 2
        assert_eq!(layer1[[0, 1]], 2.0); // (2.0 + 2.0) / 2

        // Should be reset
        assert_eq!(accum.current_step, 0);
    }

    #[test]
    fn test_gradient_accumulation_sum_strategy() {
        let mut accum =
            GradientAccumulationCallback::with_strategy(2, GradientScalingStrategy::Sum).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();

        let summed = accum.get_and_reset();
        let layer1 = summed.get("layer1").unwrap();

        // Should be sum (no scaling)
        assert_eq!(layer1[[0, 0]], 2.0); // 1.0 + 1.0
        assert_eq!(layer1[[0, 1]], 4.0); // 2.0 + 2.0
    }

    #[test]
    fn test_gradient_accumulation_dynamic_strategy() {
        let mut accum =
            GradientAccumulationCallback::with_strategy(4, GradientScalingStrategy::Dynamic)
                .unwrap();
        let grads = create_test_gradients();

        // Accumulate only 3 times (less than configured 4)
        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();

        let scaled = accum.get_and_reset();
        let layer1 = scaled.get("layer1").unwrap();

        // Should scale by actual steps (3) not configured steps (4)
        assert_eq!(layer1[[0, 0]], 1.0); // (1.0 + 1.0 + 1.0) / 3
    }

    #[test]
    fn test_gradient_clipping_during_accumulation() {
        let mut accum = GradientAccumulationCallback::new(2)
            .unwrap()
            .with_grad_clipping(1.0); // Very small max norm

        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![10.0, 10.0, 10.0, 10.0]).unwrap(),
        );

        // Large gradients should be clipped
        accum.accumulate(&grads).unwrap();
        assert!(accum.max_grad_norm > 0.0);

        // Accumulated gradients should be clipped
        let accumulated = &accum.accumulated_grads["layer1"];
        let norm_sq: f64 = accumulated.iter().map(|&x| x * x).sum();
        let norm = norm_sq.sqrt();

        // Norm should be at or below clip threshold
        assert!(norm <= 1.1); // Small tolerance
    }

    #[test]
    fn test_overflow_detection() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();

        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![f64::NAN, 1.0, 2.0, 3.0]).unwrap(),
        );

        // Should detect NaN
        let result = accum.accumulate(&grads);
        assert!(result.is_err());
        assert!(accum.overflow_detected);
    }

    #[test]
    fn test_gradient_accumulation_stats() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();
        accum.get_and_reset();

        let stats = accum.get_stats();
        assert_eq!(stats.accumulation_steps, 2);
        assert_eq!(stats.total_cycles, 1);
        assert!(!stats.overflow_detected);
    }

    #[test]
    fn test_memory_usage_estimation() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();

        let stats = accum.get_stats();
        assert!(stats.memory_usage_mb > 0.0);
        assert_eq!(stats.num_parameters, 2); // 2 layers
    }

    #[test]
    fn test_gradient_accumulation_reset() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        accum.accumulate(&grads).unwrap();
        assert_eq!(accum.current_step, 1);

        accum.reset();
        assert_eq!(accum.current_step, 0);
        assert!(!accum.initialized);
        assert_eq!(accum.accumulated_grads.len(), 0);
    }

    #[test]
    fn test_gradient_accumulation_zero_steps_error() {
        let result = GradientAccumulationCallback::new(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_gradient_accumulation_multiple_cycles() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();
        let grads = create_test_gradients();

        // First cycle
        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();
        accum.get_and_reset();

        // Second cycle
        accum.accumulate(&grads).unwrap();
        accum.accumulate(&grads).unwrap();
        accum.get_and_reset();

        let stats = accum.get_stats();
        assert_eq!(stats.total_cycles, 2);
    }

    #[test]
    fn test_different_gradient_shapes() {
        let mut accum = GradientAccumulationCallback::new(2).unwrap();

        let mut grads1 = HashMap::new();
        grads1.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
        );

        let mut grads2 = HashMap::new();
        grads2.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 3), vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0]).unwrap(),
        );

        accum.accumulate(&grads1).unwrap();
        accum.accumulate(&grads2).unwrap();

        let averaged = accum.get_and_reset();
        let layer1 = averaged.get("layer1").unwrap();

        assert_eq!(layer1.dim(), (2, 3));
        assert_eq!(layer1[[0, 0]], 0.75); // (1.0 + 0.5) / 2
    }
}