scirs2_neural/training/
gradient_accumulation.rs1use scirs2_core::ndarray::ScalarOperand;
4use scirs2_core::numeric::{Float, FromPrimitive};
5use std::fmt::Debug;
6
/// Settings that control gradient accumulation.
#[derive(Debug, Clone)]
pub struct GradientAccumulationConfig {
    /// Step threshold: `GradientAccumulator::should_update` returns `true`
    /// once the accumulator's step counter reaches this value.
    pub accumulation_steps: usize,
    /// Normalization switch.
    ///
    /// NOTE(review): presumably selects whether accumulated gradients are
    /// averaged over the accumulated steps; the code consuming this flag is
    /// not visible here, so confirm against it.
    pub normalize: bool,
}

impl Default for GradientAccumulationConfig {
    /// Builds the default configuration: a threshold of one step and
    /// `normalize` switched on.
    fn default() -> Self {
        GradientAccumulationConfig {
            normalize: true,
            accumulation_steps: 1,
        }
    }
}
24
/// Summary statistics of gradient norms.
///
/// The type parameter is deliberately left unbounded: this is a plain data
/// holder, so numeric trait requirements belong on the `impl` blocks and
/// functions that actually compute with the values. The derived `Debug` and
/// `Clone` impls are available whenever `F` itself implements them.
///
/// NOTE(review): which gradients the statistics are taken over (and which
/// norm is used) is decided by the code that fills this struct in; that code
/// is not visible here.
#[derive(Debug, Clone)]
pub struct GradientStats<F> {
    /// Average gradient norm.
    pub avg_grad_norm: F,
    /// Largest gradient norm.
    pub max_grad_norm: F,
    /// Smallest gradient norm.
    pub min_grad_norm: F,
}
35
36#[derive(Debug)]
38pub struct GradientAccumulator<F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive> {
39 pub config: GradientAccumulationConfig,
41 pub current_step: usize,
43 pub stats: Option<GradientStats<F>>,
45}
46
47impl<F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive> GradientAccumulator<F> {
48 pub fn new(config: GradientAccumulationConfig) -> Self {
50 Self {
51 config,
52 current_step: 0,
53 stats: None,
54 }
55 }
56
57 pub fn reset(&mut self) {
59 self.current_step = 0;
60 self.stats = None;
61 }
62
63 pub fn should_update(&self) -> bool {
65 self.current_step >= self.config.accumulation_steps
66 }
67}