
tensorlogic_train/gradient_accumulation.rs

//! Gradient accumulation for effective large-batch training.
//!
//! Enables training with larger effective batch sizes than memory allows
//! by accumulating gradients over multiple micro-batches before applying
//! an optimizer update.
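//!
//! # Example
//!
//! A minimal sketch of the accumulate-then-update loop; the import path
//! assumes this module is exposed as `tensorlogic_train::gradient_accumulation`.
//!
//! ```
//! use tensorlogic_train::gradient_accumulation::{AccumulationConfig, GradientAccumulator};
//!
//! let mut acc = GradientAccumulator::new(AccumulationConfig::new(2));
//! acc.register("w", vec![2]);
//!
//! // Accumulate two micro-batch gradients before triggering an update.
//! acc.accumulate("w", &[1.0, 3.0]).unwrap();
//! assert!(!acc.should_update());
//! acc.accumulate("w", &[3.0, 5.0]).unwrap();
//! assert!(acc.should_update());
//!
//! // With the default `normalize = true`, gradients are averaged over the
//! // two micro-batches: [(1 + 3) / 2, (3 + 5) / 2] = [2.0, 4.0].
//! let grads = acc.get_gradients().unwrap();
//! assert_eq!(grads["w"], vec![2.0, 4.0]);
//!
//! // After the optimizer consumes `grads`, clear the buffers for the next cycle.
//! acc.reset();
//! ```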

use std::collections::HashMap;
use thiserror::Error;

/// Errors that can occur during gradient accumulation.
#[derive(Debug, Error)]
pub enum AccumulationError {
    /// Gradient shape mismatch when accumulating into a buffer.
    #[error("Gradient shape mismatch for '{name}': expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// Parameter name (may be empty for unnamed buffers).
        name: String,
        /// Expected shape of the gradient.
        expected: Vec<usize>,
        /// Actual shape of the gradient provided.
        got: Vec<usize>,
    },
    /// Attempted to retrieve gradients when none have been accumulated.
    #[error("No gradients accumulated")]
    NoGradients,
    /// The accumulator has already reached the configured number of micro-batches.
    #[error("Accumulator already full ({0} micro-batches)")]
    AccumulatorFull(usize),
}

/// Configuration for gradient accumulation.
#[derive(Debug, Clone)]
pub struct AccumulationConfig {
    /// Number of micro-batches to accumulate before triggering an update.
    pub accumulation_steps: usize,
    /// Whether to normalize (average) gradients across micro-batches.
    pub normalize: bool,
    /// Maximum gradient norm for clipping (None = no clipping).
    pub max_grad_norm: Option<f64>,
}

impl Default for AccumulationConfig {
    fn default() -> Self {
        AccumulationConfig {
            accumulation_steps: 4,
            normalize: true,
            max_grad_norm: None,
        }
    }
}

impl AccumulationConfig {
    /// Create a new config with the given number of accumulation steps.
    /// Clamps to a minimum of 1.
    pub fn new(steps: usize) -> Self {
        AccumulationConfig {
            accumulation_steps: steps.max(1),
            ..Default::default()
        }
    }

    /// Set whether to normalize (average) gradients.
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Set maximum gradient norm for clipping.
    pub fn with_max_grad_norm(mut self, norm: f64) -> Self {
        self.max_grad_norm = Some(norm);
        self
    }

    /// Compute the effective batch size given the micro-batch size.
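    ///
    /// For example (import path assumed to be
    /// `tensorlogic_train::gradient_accumulation`), accumulating over 4
    /// micro-batches of 32 samples each behaves like one batch of 128:
    ///
    /// ```
    /// use tensorlogic_train::gradient_accumulation::AccumulationConfig;
    ///
    /// let config = AccumulationConfig::new(4);
    /// assert_eq!(config.effective_batch_size(32), 128);
    /// ```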
    pub fn effective_batch_size(&self, micro_batch_size: usize) -> usize {
        micro_batch_size * self.accumulation_steps
    }
}

/// A single gradient buffer for one parameter.
#[derive(Debug, Clone)]
pub struct GradientBuffer {
    /// Accumulated gradient values (flattened).
    pub data: Vec<f64>,
    /// Shape of the gradient tensor.
    pub shape: Vec<usize>,
    /// Number of micro-batches accumulated so far.
    pub accumulated_count: usize,
}

impl GradientBuffer {
    /// Create a new gradient buffer initialized to zeros.
    pub fn new(shape: Vec<usize>) -> Self {
        let size: usize = shape.iter().product();
        GradientBuffer {
            data: vec![0.0; size],
            shape,
            accumulated_count: 0,
        }
    }

    /// Add a micro-batch gradient to the buffer.
    pub fn accumulate(&mut self, grad: &[f64]) -> Result<(), AccumulationError> {
        if grad.len() != self.data.len() {
            return Err(AccumulationError::ShapeMismatch {
                name: String::new(),
                expected: self.shape.clone(),
                got: vec![grad.len()],
            });
        }
        for (acc, &g) in self.data.iter_mut().zip(grad.iter()) {
            *acc += g;
        }
        self.accumulated_count += 1;
        Ok(())
    }

    /// Get the accumulated gradient, optionally normalized by the count.
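    ///
    /// A small sketch of the difference between the raw sum and the normalized
    /// (averaged) gradient (import path assumed to be
    /// `tensorlogic_train::gradient_accumulation`):
    ///
    /// ```
    /// use tensorlogic_train::gradient_accumulation::GradientBuffer;
    ///
    /// let mut buf = GradientBuffer::new(vec![2]);
    /// buf.accumulate(&[2.0, 4.0]).unwrap();
    /// buf.accumulate(&[6.0, 8.0]).unwrap();
    ///
    /// assert_eq!(buf.get(false), vec![8.0, 12.0]); // raw sum
    /// assert_eq!(buf.get(true), vec![4.0, 6.0]);   // averaged over 2 micro-batches
    /// ```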
    pub fn get(&self, normalize: bool) -> Vec<f64> {
        if normalize && self.accumulated_count > 0 {
            let scale = 1.0 / self.accumulated_count as f64;
            self.data.iter().map(|&v| v * scale).collect()
        } else {
            self.data.clone()
        }
    }

    /// Compute the L2 norm of the accumulated gradient.
    pub fn l2_norm(&self) -> f64 {
        self.data.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    /// Reset the buffer to zeros.
    pub fn reset(&mut self) {
        self.data.fill(0.0);
        self.accumulated_count = 0;
    }
}

/// Gradient accumulator managing multiple parameter gradients.
///
/// Provides micro-batching support by accumulating gradients across
/// multiple forward/backward passes before applying an optimizer step.
pub struct GradientAccumulator {
    config: AccumulationConfig,
    buffers: HashMap<String, GradientBuffer>,
    total_micro_batches: usize,
    total_updates: usize,
}

impl GradientAccumulator {
    /// Create a new gradient accumulator with the given configuration.
    pub fn new(config: AccumulationConfig) -> Self {
        GradientAccumulator {
            config,
            buffers: HashMap::new(),
            total_micro_batches: 0,
            total_updates: 0,
        }
    }

    /// Register a parameter with its gradient shape.
    ///
    /// If the parameter is already registered, this is a no-op.
    pub fn register(&mut self, name: impl Into<String>, shape: Vec<usize>) {
        let name = name.into();
        self.buffers
            .entry(name)
            .or_insert_with(|| GradientBuffer::new(shape));
    }

    /// Accumulate a gradient for a named parameter.
    ///
    /// Returns `AccumulationError::NoGradients` if the parameter has not been
    /// registered, `AccumulationError::AccumulatorFull` if the buffer already
    /// holds the configured number of micro-batches, and
    /// `AccumulationError::ShapeMismatch` if the gradient size does not match
    /// the registered shape.
    pub fn accumulate(&mut self, name: &str, grad: &[f64]) -> Result<(), AccumulationError> {
        if let Some(buf) = self.buffers.get_mut(name) {
            if buf.accumulated_count >= self.config.accumulation_steps {
                return Err(AccumulationError::AccumulatorFull(
                    self.config.accumulation_steps,
                ));
            }
            buf.accumulate(grad).map_err(|e| match e {
                AccumulationError::ShapeMismatch { expected, got, .. } => {
                    AccumulationError::ShapeMismatch {
                        name: name.to_string(),
                        expected,
                        got,
                    }
                }
                other => other,
            })
        } else {
            Err(AccumulationError::NoGradients)
        }
    }

    /// Check if enough micro-batches have been accumulated to trigger an update.
    pub fn should_update(&self) -> bool {
        self.buffers
            .values()
            .any(|b| b.accumulated_count >= self.config.accumulation_steps)
    }

    /// Get all accumulated gradients, optionally normalized and clipped.
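    ///
    /// A sketch of global-norm clipping with `max_grad_norm` (import path
    /// assumed to be `tensorlogic_train::gradient_accumulation`):
    ///
    /// ```
    /// use tensorlogic_train::gradient_accumulation::{AccumulationConfig, GradientAccumulator};
    ///
    /// let config = AccumulationConfig::new(1)
    ///     .with_normalize(false)
    ///     .with_max_grad_norm(5.0);
    /// let mut acc = GradientAccumulator::new(config);
    /// acc.register("w", vec![2]);
    ///
    /// // [30, 40] has L2 norm 50, so it is rescaled by 5 / 50 = 0.1.
    /// acc.accumulate("w", &[30.0, 40.0]).unwrap();
    /// let grads = acc.get_gradients().unwrap();
    /// assert!((grads["w"][0] - 3.0).abs() < 1e-10);
    /// assert!((grads["w"][1] - 4.0).abs() < 1e-10);
    /// ```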
    pub fn get_gradients(&self) -> Result<HashMap<String, Vec<f64>>, AccumulationError> {
        if self.buffers.is_empty() {
            return Err(AccumulationError::NoGradients);
        }
        let mut grads: HashMap<String, Vec<f64>> = self
            .buffers
            .iter()
            .map(|(name, buf)| (name.clone(), buf.get(self.config.normalize)))
            .collect();

        // Apply gradient clipping if configured
        if let Some(max_norm) = self.config.max_grad_norm {
            let total_norm: f64 = grads
                .values()
                .flat_map(|g| g.iter())
                .map(|v| v * v)
                .sum::<f64>()
                .sqrt();
            if total_norm > max_norm {
                let scale = max_norm / total_norm;
                for grad in grads.values_mut() {
                    for v in grad.iter_mut() {
                        *v *= scale;
                    }
                }
            }
        }
        Ok(grads)
    }

    /// Reset all buffers after an update step.
    pub fn reset(&mut self) {
        for buf in self.buffers.values_mut() {
            buf.reset();
        }
        self.total_updates += 1;
    }

    /// Accumulate a full micro-batch of gradients, returning `true` if an
    /// update should now be applied.
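    ///
    /// A sketch of the usual training-loop pattern (import path assumed to be
    /// `tensorlogic_train::gradient_accumulation`):
    ///
    /// ```
    /// use std::collections::HashMap;
    /// use tensorlogic_train::gradient_accumulation::{AccumulationConfig, GradientAccumulator};
    ///
    /// let mut acc = GradientAccumulator::new(AccumulationConfig::new(2));
    /// acc.register("w", vec![2]);
    ///
    /// // One backward pass yields one map of per-parameter gradients.
    /// let mut micro_batch_grads = HashMap::new();
    /// micro_batch_grads.insert("w".to_string(), vec![1.0, 1.0]);
    ///
    /// assert!(!acc.step(&micro_batch_grads).unwrap()); // 1 of 2 micro-batches
    /// assert!(acc.step(&micro_batch_grads).unwrap());  // 2 of 2: time to update
    ///
    /// let grads = acc.get_gradients().unwrap();
    /// // ... apply the optimizer update with `grads`, then clear the buffers ...
    /// acc.reset();
    /// ```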
    pub fn step(
        &mut self,
        gradients: &HashMap<String, Vec<f64>>,
    ) -> Result<bool, AccumulationError> {
        for (name, grad) in gradients {
            self.accumulate(name, grad)?;
        }
        self.total_micro_batches += 1;
        Ok(self.should_update())
    }

    /// Get statistics about the accumulation state.
    pub fn stats(&self) -> AccumulationStats {
        AccumulationStats {
            total_micro_batches: self.total_micro_batches,
            total_updates: self.total_updates,
            accumulation_steps: self.config.accumulation_steps,
            registered_params: self.buffers.len(),
            total_param_elements: self.buffers.values().map(|b| b.data.len()).sum(),
        }
    }
}

/// Statistics from gradient accumulation.
#[derive(Debug, Clone)]
pub struct AccumulationStats {
    /// Total number of micro-batches processed.
    pub total_micro_batches: usize,
    /// Total number of optimizer updates applied.
    pub total_updates: usize,
    /// Configured number of accumulation steps.
    pub accumulation_steps: usize,
    /// Number of registered parameters.
    pub registered_params: usize,
    /// Total number of scalar gradient elements across all parameters.
    pub total_param_elements: usize,
}

impl AccumulationStats {
    /// The effective batch size multiplier (same as accumulation_steps).
    pub fn effective_batch_multiplier(&self) -> usize {
        self.accumulation_steps
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = AccumulationConfig::default();
        assert_eq!(config.accumulation_steps, 4);
        assert!(config.normalize);
        assert!(config.max_grad_norm.is_none());
    }

    #[test]
    fn test_config_effective_batch_size() {
        let config = AccumulationConfig::new(4);
        assert_eq!(config.effective_batch_size(32), 128);
    }

    #[test]
    fn test_buffer_new() {
        let buf = GradientBuffer::new(vec![3, 4]);
        assert_eq!(buf.data.len(), 12);
        assert!(buf.data.iter().all(|&v| v == 0.0));
        assert_eq!(buf.accumulated_count, 0);
    }

    #[test]
    fn test_buffer_accumulate() {
        let mut buf = GradientBuffer::new(vec![3]);
        let grad = vec![1.0, 2.0, 3.0];
        buf.accumulate(&grad).expect("accumulate should succeed");
        assert_eq!(buf.data, vec![1.0, 2.0, 3.0]);
        assert_eq!(buf.accumulated_count, 1);

        buf.accumulate(&grad)
            .expect("second accumulate should succeed");
        assert_eq!(buf.data, vec![2.0, 4.0, 6.0]);
        assert_eq!(buf.accumulated_count, 2);
    }

    #[test]
    fn test_buffer_accumulate_shape_mismatch() {
        let mut buf = GradientBuffer::new(vec![3]);
        let grad = vec![1.0, 2.0];
        let result = buf.accumulate(&grad);
        assert!(result.is_err());
        match result {
            Err(AccumulationError::ShapeMismatch { .. }) => {}
            _ => panic!("expected ShapeMismatch error"),
        }
    }

    #[test]
    fn test_buffer_get_normalized() {
        let mut buf = GradientBuffer::new(vec![2]);
        buf.accumulate(&[2.0, 4.0]).expect("accumulate");
        buf.accumulate(&[6.0, 8.0]).expect("accumulate");
        let normalized = buf.get(true);
        assert_eq!(normalized, vec![4.0, 6.0]); // (2+6)/2=4, (4+8)/2=6
    }

    #[test]
    fn test_buffer_get_unnormalized() {
        let mut buf = GradientBuffer::new(vec![2]);
        buf.accumulate(&[2.0, 4.0]).expect("accumulate");
        buf.accumulate(&[6.0, 8.0]).expect("accumulate");
        let raw = buf.get(false);
        assert_eq!(raw, vec![8.0, 12.0]); // 2+6=8, 4+8=12
    }

    #[test]
    fn test_buffer_l2_norm() {
        let mut buf = GradientBuffer::new(vec![2]);
        buf.accumulate(&[3.0, 4.0]).expect("accumulate");
        let norm = buf.l2_norm();
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_buffer_reset() {
        let mut buf = GradientBuffer::new(vec![3]);
        buf.accumulate(&[1.0, 2.0, 3.0]).expect("accumulate");
        assert_eq!(buf.accumulated_count, 1);
        buf.reset();
        assert!(buf.data.iter().all(|&v| v == 0.0));
        assert_eq!(buf.accumulated_count, 0);
    }

    #[test]
    fn test_accumulator_register() {
        let mut acc = GradientAccumulator::new(AccumulationConfig::default());
        acc.register("weight", vec![3, 4]);
        assert_eq!(acc.buffers.len(), 1);
        assert!(acc.buffers.contains_key("weight"));
    }

    #[test]
    fn test_accumulator_accumulate() {
        let mut acc = GradientAccumulator::new(AccumulationConfig::default());
        acc.register("w", vec![2]);
        acc.accumulate("w", &[1.0, 2.0])
            .expect("accumulate should succeed");
        let buf = acc.buffers.get("w").expect("buffer should exist");
        assert_eq!(buf.data, vec![1.0, 2.0]);
    }

    #[test]
    fn test_accumulator_should_update() {
        let config = AccumulationConfig::new(2);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        assert!(!acc.should_update());
        acc.accumulate("w", &[1.0, 1.0]).expect("accumulate");
        assert!(!acc.should_update());
        acc.accumulate("w", &[1.0, 1.0]).expect("accumulate");
        assert!(acc.should_update());
    }

    #[test]
    fn test_accumulator_get_gradients() {
        let config = AccumulationConfig::new(2).with_normalize(true);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        acc.accumulate("w", &[2.0, 4.0]).expect("accumulate");
        acc.accumulate("w", &[6.0, 8.0]).expect("accumulate");
        let grads = acc.get_gradients().expect("get_gradients");
        let w_grad = grads.get("w").expect("w gradient");
        assert_eq!(w_grad, &vec![4.0, 6.0]);
    }

    #[test]
    fn test_accumulator_grad_clipping() {
        let config = AccumulationConfig::new(1)
            .with_normalize(false)
            .with_max_grad_norm(5.0);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        // gradient [30, 40] has norm 50, clip to 5 => scale by 5/50 = 0.1
        acc.accumulate("w", &[30.0, 40.0]).expect("accumulate");
        let grads = acc.get_gradients().expect("get_gradients");
        let w_grad = grads.get("w").expect("w gradient");
        assert!((w_grad[0] - 3.0).abs() < 1e-10);
        assert!((w_grad[1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_accumulator_reset() {
        let config = AccumulationConfig::new(2);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        acc.accumulate("w", &[1.0, 2.0]).expect("accumulate");
        acc.reset();
        let buf = acc.buffers.get("w").expect("buffer");
        assert!(buf.data.iter().all(|&v| v == 0.0));
        assert_eq!(buf.accumulated_count, 0);
        assert_eq!(acc.total_updates, 1);
    }

    #[test]
    fn test_accumulator_step() {
        let config = AccumulationConfig::new(2);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), vec![1.0, 1.0]);

        let should = acc.step(&grads).expect("step 1");
        assert!(!should);
        let should = acc.step(&grads).expect("step 2");
        assert!(should);
    }

    #[test]
    fn test_accumulator_stats() {
        let config = AccumulationConfig::new(3);
        let mut acc = GradientAccumulator::new(config);
        acc.register("a", vec![2, 3]);
        acc.register("b", vec![4]);

        let stats = acc.stats();
        assert_eq!(stats.total_micro_batches, 0);
        assert_eq!(stats.total_updates, 0);
        assert_eq!(stats.accumulation_steps, 3);
        assert_eq!(stats.registered_params, 2);
        assert_eq!(stats.total_param_elements, 10); // 6 + 4
        assert_eq!(stats.effective_batch_multiplier(), 3);
    }

    #[test]
    fn test_accumulator_empty_no_gradients() {
        let acc = GradientAccumulator::new(AccumulationConfig::default());
        let result = acc.get_gradients();
        assert!(result.is_err());
        match result {
            Err(AccumulationError::NoGradients) => {}
            _ => panic!("expected NoGradients error"),
        }
    }
}