scirs2_neural/training/gradient_accumulation.rs

//! Gradient accumulation utilities

use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

/// Configuration for gradient accumulation
#[derive(Debug, Clone)]
pub struct GradientAccumulationConfig {
    /// Number of micro-batches to accumulate gradients over before an optimizer update
    pub accumulation_steps: usize,
    /// Whether to normalize accumulated gradients by dividing by `accumulation_steps`
    pub normalize: bool,
}

impl Default for GradientAccumulationConfig {
    fn default() -> Self {
        Self {
            accumulation_steps: 1,
            normalize: true,
        }
    }
}

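// Note: with gradient accumulation, the optimizer effectively sees a batch of
// micro_batch_size * accumulation_steps. For example, a per-step batch of 8
// with accumulation_steps = 4 behaves like a batch of 32 per optimizer update,
// and `normalize` keeps the gradient scale comparable to a true batch of that
// size.
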
/// Statistics for gradient norm tracking
#[derive(Debug, Clone)]
pub struct GradientStats<F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive> {
    /// Average gradient norm
    pub avg_grad_norm: F,
    /// Maximum gradient norm
    pub max_grad_norm: F,
    /// Minimum gradient norm
    pub min_grad_norm: F,
}

/// Gradient accumulator for training
#[derive(Debug)]
pub struct GradientAccumulator<F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive> {
    /// Configuration
    pub config: GradientAccumulationConfig,
    /// Current step within the accumulation cycle
    pub current_step: usize,
    /// Gradient statistics for the current cycle, if any have been recorded
    pub stats: Option<GradientStats<F>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive> GradientAccumulator<F> {
    /// Create a new gradient accumulator
    pub fn new(config: GradientAccumulationConfig) -> Self {
        Self {
            config,
            current_step: 0,
            stats: None,
        }
    }

    /// Reset the accumulator
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.stats = None;
    }

    /// Check if we should apply accumulated gradients
    pub fn should_update(&self) -> bool {
        self.current_step >= self.config.accumulation_steps
    }
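
    // --- Sketch: as given, nothing in this section ever advances
    // `current_step`, so `should_update` could not fire. The two methods
    // below are a minimal, hypothetical completion (the names `accumulate`
    // and `scale_factor` are assumptions, not confirmed crate API). They
    // assume the gradient tensors themselves are buffered by the caller or
    // optimizer, with this type tracking only step count and norm stats. ---

    /// Record the gradient norm of one micro-batch and advance the cycle.
    pub fn accumulate(&mut self, grad_norm: F) {
        let n = F::from_usize(self.current_step).unwrap_or_else(F::zero);
        self.stats = Some(match self.stats.take() {
            // First norm of the cycle: it is min, max, and average at once.
            None => GradientStats {
                avg_grad_norm: grad_norm,
                max_grad_norm: grad_norm,
                min_grad_norm: grad_norm,
            },
            // Fold the new norm into the running statistics.
            Some(s) => GradientStats {
                avg_grad_norm: (s.avg_grad_norm * n + grad_norm) / (n + F::one()),
                max_grad_norm: s.max_grad_norm.max(grad_norm),
                min_grad_norm: s.min_grad_norm.min(grad_norm),
            },
        });
        self.current_step += 1;
    }

    /// Factor to scale accumulated gradients by before the optimizer step:
    /// `1 / accumulation_steps` when `normalize` is set, otherwise `1`.
    pub fn scale_factor(&self) -> F {
        if self.config.normalize {
            F::one() / F::from_usize(self.config.accumulation_steps).unwrap_or_else(F::one)
        } else {
            F::one()
        }
    }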
}
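
// Illustrative usage, written as a test. This is a sketch that relies on the
// hypothetical `accumulate` / `scale_factor` methods above, not on confirmed
// crate API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn accumulation_cycle() {
        let config = GradientAccumulationConfig {
            accumulation_steps: 4,
            normalize: true,
        };
        let mut acc: GradientAccumulator<f32> = GradientAccumulator::new(config);

        // Accumulate over four micro-batches; no update should fire early.
        for norm in [1.0f32, 2.0, 3.0, 4.0] {
            assert!(!acc.should_update());
            acc.accumulate(norm);
        }

        // After four steps the caller would scale gradients by 1/4 and step.
        assert!(acc.should_update());
        assert_eq!(acc.scale_factor(), 0.25);
        let stats = acc.stats.as_ref().expect("stats recorded");
        assert_eq!(stats.avg_grad_norm, 2.5);
        assert_eq!(stats.max_grad_norm, 4.0);
        assert_eq!(stats.min_grad_norm, 1.0);

        // Reset begins the next accumulation cycle.
        acc.reset();
        assert!(!acc.should_update());
        assert!(acc.stats.is_none());
    }
}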