1use std::collections::HashMap;
8use thiserror::Error;
9
/// Errors that can occur while accumulating gradients across micro-batches.
#[derive(Debug, Error)]
pub enum AccumulationError {
    /// A submitted gradient does not match the registered parameter size.
    /// `got` carries the flattened length of the offending slice (as a
    /// one-element vec), not a full shape.
    #[error("Gradient shape mismatch for '{name}': expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        // Parameter name; empty when raised by `GradientBuffer` directly
        // (the accumulator fills it in).
        name: String,
        // Registered shape of the parameter.
        expected: Vec<usize>,
        // Flattened length of the rejected gradient slice.
        got: Vec<usize>,
    },
    /// No parameters are registered / no gradients are available.
    #[error("No gradients accumulated")]
    NoGradients,
    /// The buffer already holds `accumulation_steps` micro-batches.
    #[error("Accumulator already full ({0} micro-batches)")]
    AccumulatorFull(usize),
}
30
/// Configuration for gradient accumulation.
#[derive(Debug, Clone)]
pub struct AccumulationConfig {
    /// Number of micro-batches to accumulate before an optimizer update.
    pub accumulation_steps: usize,
    /// When true, accumulated gradients are averaged over the count on read.
    pub normalize: bool,
    /// Optional global-L2-norm clipping threshold applied in `get_gradients`.
    pub max_grad_norm: Option<f64>,
}

impl Default for AccumulationConfig {
    /// Defaults: 4 accumulation steps, normalization enabled, no clipping.
    fn default() -> Self {
        Self {
            accumulation_steps: 4,
            normalize: true,
            max_grad_norm: None,
        }
    }
}

impl AccumulationConfig {
    /// Creates a config with `steps` accumulation steps, clamped to at
    /// least 1; all other fields take their defaults.
    pub fn new(steps: usize) -> Self {
        Self {
            accumulation_steps: steps.max(1),
            ..Self::default()
        }
    }

    /// Builder-style toggle for averaging on read.
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Builder-style setter for the global-norm clipping threshold.
    pub fn with_max_grad_norm(mut self, norm: f64) -> Self {
        self.max_grad_norm = Some(norm);
        self
    }

    /// Batch size the optimizer effectively sees: the micro-batch size
    /// multiplied by the number of accumulation steps.
    pub fn effective_batch_size(&self, micro_batch_size: usize) -> usize {
        micro_batch_size * self.accumulation_steps
    }
}
79
/// Flat accumulation buffer for one parameter's gradients.
#[derive(Debug, Clone)]
pub struct GradientBuffer {
    /// Summed gradient values, flattened; length is the product of `shape`.
    pub data: Vec<f64>,
    /// Logical shape of the parameter this buffer accumulates for.
    pub shape: Vec<usize>,
    /// Number of micro-batch gradients summed into `data` since last reset.
    pub accumulated_count: usize,
}
90
91impl GradientBuffer {
92 pub fn new(shape: Vec<usize>) -> Self {
94 let size: usize = shape.iter().product();
95 GradientBuffer {
96 data: vec![0.0; size],
97 shape,
98 accumulated_count: 0,
99 }
100 }
101
102 pub fn accumulate(&mut self, grad: &[f64]) -> Result<(), AccumulationError> {
104 if grad.len() != self.data.len() {
105 return Err(AccumulationError::ShapeMismatch {
106 name: String::new(),
107 expected: self.shape.clone(),
108 got: vec![grad.len()],
109 });
110 }
111 for (acc, &g) in self.data.iter_mut().zip(grad.iter()) {
112 *acc += g;
113 }
114 self.accumulated_count += 1;
115 Ok(())
116 }
117
118 pub fn get(&self, normalize: bool) -> Vec<f64> {
120 if normalize && self.accumulated_count > 0 {
121 let scale = 1.0 / self.accumulated_count as f64;
122 self.data.iter().map(|&v| v * scale).collect()
123 } else {
124 self.data.clone()
125 }
126 }
127
128 pub fn l2_norm(&self) -> f64 {
130 self.data.iter().map(|v| v * v).sum::<f64>().sqrt()
131 }
132
133 pub fn reset(&mut self) {
135 self.data.fill(0.0);
136 self.accumulated_count = 0;
137 }
138}
139
/// Accumulates gradients per named parameter across micro-batches, with
/// optional averaging and global-norm clipping when gradients are read out.
pub struct GradientAccumulator {
    // Accumulation behavior: steps, normalization, clipping threshold.
    config: AccumulationConfig,
    // One buffer per registered parameter name.
    buffers: HashMap<String, GradientBuffer>,
    // Lifetime count of micro-batches processed via `step`.
    total_micro_batches: usize,
    // Lifetime count of `reset` calls (treated as completed updates).
    total_updates: usize,
}
150
151impl GradientAccumulator {
152 pub fn new(config: AccumulationConfig) -> Self {
154 GradientAccumulator {
155 config,
156 buffers: HashMap::new(),
157 total_micro_batches: 0,
158 total_updates: 0,
159 }
160 }
161
162 pub fn register(&mut self, name: impl Into<String>, shape: Vec<usize>) {
166 let name = name.into();
167 self.buffers
168 .entry(name)
169 .or_insert_with(|| GradientBuffer::new(shape));
170 }
171
172 pub fn accumulate(&mut self, name: &str, grad: &[f64]) -> Result<(), AccumulationError> {
177 if let Some(buf) = self.buffers.get_mut(name) {
178 if buf.accumulated_count >= self.config.accumulation_steps {
179 return Err(AccumulationError::AccumulatorFull(
180 self.config.accumulation_steps,
181 ));
182 }
183 buf.accumulate(grad).map_err(|e| match e {
184 AccumulationError::ShapeMismatch { expected, got, .. } => {
185 AccumulationError::ShapeMismatch {
186 name: name.to_string(),
187 expected,
188 got,
189 }
190 }
191 other => other,
192 })
193 } else {
194 Err(AccumulationError::NoGradients)
195 }
196 }
197
198 pub fn should_update(&self) -> bool {
200 self.buffers
201 .values()
202 .any(|b| b.accumulated_count >= self.config.accumulation_steps)
203 }
204
205 pub fn get_gradients(&self) -> Result<HashMap<String, Vec<f64>>, AccumulationError> {
207 if self.buffers.is_empty() {
208 return Err(AccumulationError::NoGradients);
209 }
210 let mut grads: HashMap<String, Vec<f64>> = self
211 .buffers
212 .iter()
213 .map(|(name, buf)| (name.clone(), buf.get(self.config.normalize)))
214 .collect();
215
216 if let Some(max_norm) = self.config.max_grad_norm {
218 let total_norm: f64 = grads
219 .values()
220 .flat_map(|g| g.iter())
221 .map(|v| v * v)
222 .sum::<f64>()
223 .sqrt();
224 if total_norm > max_norm {
225 let scale = max_norm / total_norm;
226 for grad in grads.values_mut() {
227 for v in grad.iter_mut() {
228 *v *= scale;
229 }
230 }
231 }
232 }
233 Ok(grads)
234 }
235
236 pub fn reset(&mut self) {
238 for buf in self.buffers.values_mut() {
239 buf.reset();
240 }
241 self.total_updates += 1;
242 }
243
244 pub fn step(
247 &mut self,
248 gradients: &HashMap<String, Vec<f64>>,
249 ) -> Result<bool, AccumulationError> {
250 for (name, grad) in gradients {
251 self.accumulate(name, grad)?;
252 }
253 self.total_micro_batches += 1;
254 Ok(self.should_update())
255 }
256
257 pub fn stats(&self) -> AccumulationStats {
259 AccumulationStats {
260 total_micro_batches: self.total_micro_batches,
261 total_updates: self.total_updates,
262 accumulation_steps: self.config.accumulation_steps,
263 registered_params: self.buffers.len(),
264 total_param_elements: self.buffers.values().map(|b| b.data.len()).sum(),
265 }
266 }
267}
268
/// Point-in-time snapshot of a gradient accumulator's counters.
#[derive(Debug, Clone)]
pub struct AccumulationStats {
    /// Micro-batches processed over the accumulator's lifetime.
    pub total_micro_batches: usize,
    /// Completed updates recorded (incremented on each reset).
    pub total_updates: usize,
    /// Configured number of micro-batches per optimizer update.
    pub accumulation_steps: usize,
    /// Number of registered parameter buffers.
    pub registered_params: usize,
    /// Total element count across all registered buffers.
    pub total_param_elements: usize,
}

impl AccumulationStats {
    /// Factor by which accumulation multiplies the micro-batch size; equal
    /// to the configured number of accumulation steps.
    pub fn effective_batch_multiplier(&self) -> usize {
        self.accumulation_steps
    }
}
290
#[cfg(test)]
mod tests {
    // Unit tests for config defaults, buffer accumulation, and the
    // accumulator's registration / update / clipping / stats behavior.
    use super::*;

    #[test]
    fn test_config_default() {
        let config = AccumulationConfig::default();
        assert_eq!(config.accumulation_steps, 4);
        assert!(config.normalize);
        assert!(config.max_grad_norm.is_none());
    }

    #[test]
    fn test_config_effective_batch_size() {
        let config = AccumulationConfig::new(4);
        assert_eq!(config.effective_batch_size(32), 128);
    }

    #[test]
    fn test_buffer_new() {
        let buf = GradientBuffer::new(vec![3, 4]);
        assert_eq!(buf.data.len(), 12);
        assert!(buf.data.iter().all(|&v| v == 0.0));
        assert_eq!(buf.accumulated_count, 0);
    }

    #[test]
    fn test_buffer_accumulate() {
        let mut buf = GradientBuffer::new(vec![3]);
        let grad = vec![1.0, 2.0, 3.0];
        buf.accumulate(&grad).expect("accumulate should succeed");
        assert_eq!(buf.data, vec![1.0, 2.0, 3.0]);
        assert_eq!(buf.accumulated_count, 1);

        buf.accumulate(&grad)
            .expect("second accumulate should succeed");
        assert_eq!(buf.data, vec![2.0, 4.0, 6.0]);
        assert_eq!(buf.accumulated_count, 2);
    }

    #[test]
    fn test_buffer_accumulate_shape_mismatch() {
        let mut buf = GradientBuffer::new(vec![3]);
        let grad = vec![1.0, 2.0];
        let result = buf.accumulate(&grad);
        assert!(result.is_err());
        match result {
            Err(AccumulationError::ShapeMismatch { .. }) => {}
            _ => panic!("expected ShapeMismatch error"),
        }
    }

    #[test]
    fn test_buffer_get_normalized() {
        let mut buf = GradientBuffer::new(vec![2]);
        buf.accumulate(&[2.0, 4.0]).expect("accumulate");
        buf.accumulate(&[6.0, 8.0]).expect("accumulate");
        let normalized = buf.get(true);
        // Average of two accumulated gradients.
        assert_eq!(normalized, vec![4.0, 6.0]);
    }

    #[test]
    fn test_buffer_get_unnormalized() {
        let mut buf = GradientBuffer::new(vec![2]);
        buf.accumulate(&[2.0, 4.0]).expect("accumulate");
        buf.accumulate(&[6.0, 8.0]).expect("accumulate");
        let raw = buf.get(false);
        // Raw sum, no averaging.
        assert_eq!(raw, vec![8.0, 12.0]);
    }

    #[test]
    fn test_buffer_l2_norm() {
        let mut buf = GradientBuffer::new(vec![2]);
        buf.accumulate(&[3.0, 4.0]).expect("accumulate");
        let norm = buf.l2_norm();
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_buffer_reset() {
        let mut buf = GradientBuffer::new(vec![3]);
        buf.accumulate(&[1.0, 2.0, 3.0]).expect("accumulate");
        assert_eq!(buf.accumulated_count, 1);
        buf.reset();
        assert!(buf.data.iter().all(|&v| v == 0.0));
        assert_eq!(buf.accumulated_count, 0);
    }

    #[test]
    fn test_accumulator_register() {
        let mut acc = GradientAccumulator::new(AccumulationConfig::default());
        acc.register("weight", vec![3, 4]);
        assert_eq!(acc.buffers.len(), 1);
        assert!(acc.buffers.contains_key("weight"));
    }

    #[test]
    fn test_accumulator_accumulate() {
        let mut acc = GradientAccumulator::new(AccumulationConfig::default());
        acc.register("w", vec![2]);
        acc.accumulate("w", &[1.0, 2.0])
            .expect("accumulate should succeed");
        let buf = acc.buffers.get("w").expect("buffer should exist");
        assert_eq!(buf.data, vec![1.0, 2.0]);
    }

    #[test]
    fn test_accumulator_should_update() {
        let config = AccumulationConfig::new(2);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        assert!(!acc.should_update());
        acc.accumulate("w", &[1.0, 1.0]).expect("accumulate");
        assert!(!acc.should_update());
        acc.accumulate("w", &[1.0, 1.0]).expect("accumulate");
        assert!(acc.should_update());
    }

    #[test]
    fn test_accumulator_get_gradients() {
        let config = AccumulationConfig::new(2).with_normalize(true);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        acc.accumulate("w", &[2.0, 4.0]).expect("accumulate");
        acc.accumulate("w", &[6.0, 8.0]).expect("accumulate");
        let grads = acc.get_gradients().expect("get_gradients");
        let w_grad = grads.get("w").expect("w gradient");
        assert_eq!(w_grad, &vec![4.0, 6.0]);
    }

    #[test]
    fn test_accumulator_grad_clipping() {
        let config = AccumulationConfig::new(1)
            .with_normalize(false)
            .with_max_grad_norm(5.0);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        acc.accumulate("w", &[30.0, 40.0]).expect("accumulate");

        // Norm 50 clipped to 5 scales each component by 0.1.
        let grads = acc.get_gradients().expect("get_gradients");
        let w_grad = grads.get("w").expect("w gradient");
        assert!((w_grad[0] - 3.0).abs() < 1e-10);
        assert!((w_grad[1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_accumulator_reset() {
        let config = AccumulationConfig::new(2);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        acc.accumulate("w", &[1.0, 2.0]).expect("accumulate");
        acc.reset();
        let buf = acc.buffers.get("w").expect("buffer");
        assert!(buf.data.iter().all(|&v| v == 0.0));
        assert_eq!(buf.accumulated_count, 0);
        assert_eq!(acc.total_updates, 1);
    }

    #[test]
    fn test_accumulator_step() {
        let config = AccumulationConfig::new(2);
        let mut acc = GradientAccumulator::new(config);
        acc.register("w", vec![2]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), vec![1.0, 1.0]);

        let should = acc.step(&grads).expect("step 1");
        assert!(!should);
        let should = acc.step(&grads).expect("step 2");
        assert!(should);
    }

    #[test]
    fn test_accumulator_stats() {
        let config = AccumulationConfig::new(3);
        let mut acc = GradientAccumulator::new(config);
        acc.register("a", vec![2, 3]);
        acc.register("b", vec![4]);

        let stats = acc.stats();
        assert_eq!(stats.total_micro_batches, 0);
        assert_eq!(stats.total_updates, 0);
        assert_eq!(stats.accumulation_steps, 3);
        assert_eq!(stats.registered_params, 2);
        assert_eq!(stats.total_param_elements, 10);
        assert_eq!(stats.effective_batch_multiplier(), 3);
    }

    #[test]
    fn test_accumulator_empty_no_gradients() {
        let acc = GradientAccumulator::new(AccumulationConfig::default());
        let result = acc.get_gradients();
        assert!(result.is_err());
        match result {
            Err(AccumulationError::NoGradients) => {}
            _ => panic!("expected NoGradients error"),
        }
    }
}