vsa_optim_rs/ternary/accumulator.rs

//! Ternary gradient accumulator for memory-efficient training.
//!
//! Gradient accumulation over many steps can be memory-intensive.
//! Using ternary representation with scale factors reduces memory by ~10x
//! while maintaining accuracy through the scale factors.
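//!
//! As a rough, illustrative back-of-the-envelope: a packed ternary value needs
//! about 2 bits versus 32 bits for an `f32`, i.e. up to 16x, with per-tensor
//! scale factors and packing overhead bringing practical savings closer to the
//! ~10x figure above.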

use std::collections::HashMap;

use candle_core::{DType, Device, Tensor};

use crate::config::TernaryConfig;
use crate::error::Result;
use crate::ternary::{
    calculate_memory_savings, ternary_quantize_deterministic, ternary_quantize_stochastic,
};

/// Accumulated state for a single parameter.
#[derive(Debug, Clone)]
struct AccumulatedGradient {
    /// Accumulated ternary direction (sum of per-step ternary tensors).
    ternary: Tensor,
    /// Running sum of per-step scale factors.
    scale_sum: f32,
    /// Original shape for reconstruction.
    shape: Vec<usize>,
}

/// Accumulate gradients using a ternary representation.
///
/// The accumulator keeps a ternary "direction" tensor and a running scale sum
/// per parameter. New gradients are quantized to ternary and summed into this
/// representation. Full-precision reconstruction happens only at update time.
///
/// # Example
///
/// ```ignore
/// use candle_core::Device;
/// use vsa_optim_rs::ternary::TernaryGradientAccumulator;
/// use vsa_optim_rs::TernaryConfig;
///
/// let shapes = vec![
///     ("layer1.weight".to_string(), vec![64, 128]),
///     ("layer1.bias".to_string(), vec![64]),
/// ];
/// let mut accumulator =
///     TernaryGradientAccumulator::new(&shapes, TernaryConfig::default(), &Device::Cpu)?;
///
/// // Accumulate gradients (e.g., after each backward pass)
/// accumulator.accumulate(&gradients)?;
///
/// // Get the full-precision result, then reset for the next cycle
/// let accumulated = accumulator.get_accumulated()?;
/// accumulator.reset()?;
/// ```
pub struct TernaryGradientAccumulator {
    config: TernaryConfig,
    device: Device,
    /// Accumulated gradients per parameter.
    accumulators: HashMap<String, AccumulatedGradient>,
    /// Number of accumulation steps.
    count: usize,
}

impl TernaryGradientAccumulator {
    /// Create a new gradient accumulator.
    ///
    /// # Arguments
    ///
    /// * `param_shapes` - List of (name, shape) tuples for parameters
    /// * `config` - Ternary configuration
    /// * `device` - Device for tensor storage
    ///
    /// # Errors
    ///
    /// Returns an error if tensor creation fails.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: TernaryConfig,
        device: &Device,
    ) -> Result<Self> {
        let mut accumulators = HashMap::new();

        for (name, shape) in param_shapes {
            // Stored as f32 so the running sum can grow beyond {-1, 0, 1}.
            let ternary = Tensor::zeros(shape.as_slice(), DType::F32, device)?;
            accumulators.insert(
                name.clone(),
                AccumulatedGradient {
                    ternary,
                    scale_sum: 0.0,
                    shape: shape.clone(),
                },
            );
        }

        Ok(Self {
            config,
            device: device.clone(),
            accumulators,
            count: 0,
        })
    }

    /// Accumulate gradients in ternary form.
    ///
    /// Quantizes each gradient to ternary, accumulates the direction,
    /// and tracks the scale. Call this after each backward pass.
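    ///
    /// In symbols (an illustrative sketch of the math): each gradient is
    /// approximated as `g ≈ s · t` with `t ∈ {-1, 0, 1}` element-wise; `t` is
    /// added to the running direction and `s` to the running scale sum.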
    ///
    /// # Arguments
    ///
    /// * `gradients` - Map of parameter names to gradient tensors
    ///
    /// # Errors
    ///
    /// Returns an error if quantization or tensor operations fail.
    pub fn accumulate(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        let threshold = Some(self.config.ternary_threshold);

        for (name, grad) in gradients {
            // Parameters not registered at construction are skipped silently.
            if let Some(accum) = self.accumulators.get_mut(name) {
                // Quantize the gradient to a ternary tensor and a scale factor
                let (ternary, scale) = if self.config.use_stochastic_rounding {
                    ternary_quantize_stochastic(grad, threshold)?
                } else {
                    ternary_quantize_deterministic(grad, threshold)?
                };

                // Accumulate the direction via element-wise sum. A sum of ternary
                // values is no longer ternary, but after n steps it still fits in
                // ~log2(2n + 1) bits, far fewer than full precision.
                accum.ternary = accum.ternary.add(&ternary)?;
                accum.scale_sum += scale;
            }
        }

        self.count += 1;
        Ok(())
    }

    /// Get full-precision accumulated gradients.
    ///
    /// Reconstructs full-precision gradients from ternary accumulation.
    /// The scale is averaged and applied to the accumulated direction.
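    ///
    /// With per-step quantization `g_i ≈ s_i · t_i` over `n` steps (again an
    /// illustrative sketch of the math), the reconstruction is
    /// `(Σ s_i / n) · (Σ t_i) / n`, i.e. an estimate of the mean gradient.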
    ///
    /// # Returns
    ///
    /// A map from parameter names to accumulated gradient tensors.
    ///
    /// # Errors
    ///
    /// Returns an error if tensor operations fail.
    #[allow(clippy::cast_precision_loss)]
    pub fn get_accumulated(&self) -> Result<HashMap<String, Tensor>> {
        let mut accumulated = HashMap::new();

        for (name, accum) in &self.accumulators {
            if self.count > 0 {
                // Average scale across steps
                let avg_scale = accum.scale_sum / self.count as f32;
                // Reconstruct the mean gradient: direction * avg_scale / count
                let result = (&accum.ternary * avg_scale as f64)?;
                let result = (result / self.count as f64)?;
                accumulated.insert(name.clone(), result);
            } else {
                // Nothing accumulated yet; return the zero direction as-is.
                accumulated.insert(name.clone(), accum.ternary.clone());
            }
        }

        Ok(accumulated)
    }

    /// Reset the accumulator for the next accumulation cycle.
    ///
    /// # Errors
    ///
    /// Returns an error if tensor zeroing fails.
    pub fn reset(&mut self) -> Result<()> {
        for accum in self.accumulators.values_mut() {
            accum.ternary = accum.ternary.zeros_like()?;
            accum.scale_sum = 0.0;
        }
        self.count = 0;
        Ok(())
    }

    /// Get the number of accumulated steps.
    #[must_use]
    pub const fn count(&self) -> usize {
        self.count
    }

    /// Calculate memory savings from ternary representation.
    ///
    /// # Returns
    ///
    /// Fraction of memory saved (0 to 1).
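    ///
    /// For the shapes used in this file's tests (roughly 10k parameters across
    /// three tensors), the reported savings exceed 0.9; see
    /// `test_accumulator_memory_savings` below.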
    #[must_use]
    pub fn memory_savings(&self) -> f32 {
        let param_count: usize = self
            .accumulators
            .values()
            .map(|a| a.shape.iter().product::<usize>())
            .sum();
        let num_tensors = self.accumulators.len();
        calculate_memory_savings(param_count, num_tensors)
    }

    /// Check if ready for optimizer update.
    #[must_use]
    pub fn ready_for_update(&self) -> bool {
        self.count >= self.config.accumulation_steps
    }
}

/// Optimizer wrapper with ternary gradient accumulation.
///
/// Combines ternary accumulation with gradient management for
/// memory-efficient training. Useful for large-batch training where
/// gradient accumulation is necessary.
///
/// # Example
///
/// ```ignore
/// use vsa_optim_rs::ternary::TernaryOptimizerWrapper;
/// use vsa_optim_rs::TernaryConfig;
///
/// let mut wrapper =
///     TernaryOptimizerWrapper::new(&param_shapes, TernaryConfig::default(), &device)?;
///
/// for batch in &batches {
///     // Compute gradients for this batch...
///     let should_update = wrapper.step(&gradients)?;
///     if should_update {
///         let accumulated = wrapper.get_gradients_for_update()?;
///         // Apply to the optimizer...
///     }
/// }
/// ```
pub struct TernaryOptimizerWrapper {
    config: TernaryConfig,
    accumulator: TernaryGradientAccumulator,
    step_count: usize,
    update_count: usize,
}

impl TernaryOptimizerWrapper {
    /// Create a new ternary optimizer wrapper.
    ///
    /// # Arguments
    ///
    /// * `param_shapes` - List of (name, shape) tuples for parameters
    /// * `config` - Ternary configuration
    /// * `device` - Device for tensor storage
    ///
    /// # Errors
    ///
    /// Returns an error if accumulator creation fails.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: TernaryConfig,
        device: &Device,
    ) -> Result<Self> {
        let accumulator = TernaryGradientAccumulator::new(param_shapes, config.clone(), device)?;

        Ok(Self {
            config,
            accumulator,
            step_count: 0,
            update_count: 0,
        })
    }

    /// Accumulate gradients and check whether an update is due.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Current step gradients
    ///
    /// # Returns
    ///
    /// `true` if an optimizer update should be performed, `false` if the
    /// gradients were only accumulated.
    ///
    /// # Errors
    ///
    /// Returns an error if accumulation fails.
    pub fn step(&mut self, gradients: &HashMap<String, Tensor>) -> Result<bool> {
        // Accumulate current gradients
        self.accumulator.accumulate(gradients)?;
        self.step_count += 1;

        // An update is due every `accumulation_steps` steps
        Ok(self.step_count % self.config.accumulation_steps == 0)
    }

    /// Get accumulated gradients for an optimizer update.
    ///
    /// Call this when `step()` returns `true`. Note that this also resets the
    /// internal accumulator for the next cycle.
    ///
    /// # Returns
    ///
    /// Accumulated full-precision gradients.
    ///
    /// # Errors
    ///
    /// Returns an error if reconstruction fails.
    pub fn get_gradients_for_update(&mut self) -> Result<HashMap<String, Tensor>> {
        let grads = self.accumulator.get_accumulated()?;
        self.accumulator.reset()?;
        self.update_count += 1;
        Ok(grads)
    }

    /// Get optimization statistics.
    #[must_use]
    pub fn get_stats(&self) -> OptimizerStats {
        OptimizerStats {
            step_count: self.step_count,
            update_count: self.update_count,
            memory_savings: self.accumulator.memory_savings(),
            accumulation_steps: self.config.accumulation_steps,
        }
    }

    /// Get the step count.
    #[must_use]
    pub const fn step_count(&self) -> usize {
        self.step_count
    }

    /// Get the update count.
    #[must_use]
    pub const fn update_count(&self) -> usize {
        self.update_count
    }

    /// Reset state for checkpointing.
    pub fn reset_state(&mut self) {
        self.step_count = 0;
        self.update_count = 0;
    }

    /// Load state from checkpoint.
    pub fn load_state(&mut self, step_count: usize, update_count: usize) {
        self.step_count = step_count;
        self.update_count = update_count;
    }
}

/// Optimization statistics.
#[derive(Debug, Clone)]
pub struct OptimizerStats {
    /// Total number of steps.
    pub step_count: usize,
    /// Number of optimizer updates.
    pub update_count: usize,
    /// Memory savings fraction.
    pub memory_savings: f32,
    /// Configured accumulation steps.
    pub accumulation_steps: usize,
}

impl std::fmt::Display for OptimizerStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Steps: {} | Updates: {} | Memory saved: {:.1}%",
            self.step_count,
            self.update_count,
            self.memory_savings * 100.0
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_param_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer1.weight".to_string(), vec![64, 128]),
            ("layer1.bias".to_string(), vec![64]),
            ("layer2.weight".to_string(), vec![32, 64]),
        ]
    }

    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1.weight".to_string(),
            Tensor::randn(0.0f32, 1.0, (64, 128), device).unwrap(),
        );
        gradients.insert(
            "layer1.bias".to_string(),
            Tensor::randn(0.0f32, 1.0, 64, device).unwrap(),
        );
        gradients.insert(
            "layer2.weight".to_string(),
            Tensor::randn(0.0f32, 1.0, (32, 64), device).unwrap(),
        );
        gradients
    }

    #[test]
    fn test_accumulator_creation() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        assert_eq!(accumulator.count(), 0);
    }

    #[test]
    fn test_accumulator_accumulate() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let mut accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        accumulator.accumulate(&gradients).unwrap();
        assert_eq!(accumulator.count(), 1);

        accumulator.accumulate(&gradients).unwrap();
        assert_eq!(accumulator.count(), 2);
    }

    #[test]
    fn test_accumulator_get_accumulated() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let mut accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        accumulator.accumulate(&gradients).unwrap();
        let accumulated = accumulator.get_accumulated().unwrap();

        assert_eq!(accumulated.len(), 3);
        for (name, _shape) in &shapes {
            assert!(accumulated.contains_key(name));
        }
    }
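
    // Reconstruction should preserve each parameter's shape; `Tensor::dims`
    // exposes a tensor's shape as `&[usize]`. A small sanity check using only
    // the APIs exercised above.
    #[test]
    fn test_accumulated_shapes_match() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let mut accumulator =
            TernaryGradientAccumulator::new(&shapes, TernaryConfig::default(), &device).unwrap();

        accumulator.accumulate(&create_mock_gradients(&device)).unwrap();
        let accumulated = accumulator.get_accumulated().unwrap();

        for (name, shape) in &shapes {
            assert_eq!(accumulated[name].dims(), shape.as_slice());
        }
    }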

    #[test]
    fn test_accumulator_reset() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let mut accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        accumulator.accumulate(&gradients).unwrap();
        assert_eq!(accumulator.count(), 1);

        accumulator.reset().unwrap();
        assert_eq!(accumulator.count(), 0);
    }

    #[test]
    fn test_accumulator_memory_savings() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        let savings = accumulator.memory_savings();

        // Should save ~90%+ for reasonable sizes
        assert!(savings > 0.9, "Expected >90% savings, got {:.2}%", savings * 100.0);
    }

    #[test]
    fn test_optimizer_wrapper_step() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default().with_accumulation_steps(4);

        let mut wrapper = TernaryOptimizerWrapper::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Steps 1-3: accumulate only
        for _ in 0..3 {
            let should_update = wrapper.step(&gradients).unwrap();
            assert!(!should_update);
        }

        // Step 4: should update
        let should_update = wrapper.step(&gradients).unwrap();
        assert!(should_update);

        // Get gradients and verify
        let accumulated = wrapper.get_gradients_for_update().unwrap();
        assert_eq!(accumulated.len(), 3);

        // Step 5: back to accumulating
        let should_update = wrapper.step(&gradients).unwrap();
        assert!(!should_update);
    }

    #[test]
    fn test_optimizer_wrapper_stats() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default().with_accumulation_steps(2);

        let mut wrapper = TernaryOptimizerWrapper::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        wrapper.step(&gradients).unwrap();
        wrapper.step(&gradients).unwrap();
        let _ = wrapper.get_gradients_for_update().unwrap();

        let stats = wrapper.get_stats();
        assert_eq!(stats.step_count, 2);
        assert_eq!(stats.update_count, 1);
        assert!(stats.memory_savings > 0.9);
    }

    #[test]
    fn test_stochastic_vs_deterministic() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;

        // Test stochastic
        let config_stochastic = TernaryConfig::default().with_stochastic_rounding(true);
        let mut acc_stochastic =
            TernaryGradientAccumulator::new(&shapes, config_stochastic, &device).unwrap();

        // Test deterministic
        let config_deterministic = TernaryConfig::default().with_stochastic_rounding(false);
        let mut acc_deterministic =
            TernaryGradientAccumulator::new(&shapes, config_deterministic, &device).unwrap();

        let gradients = create_mock_gradients(&device);

        acc_stochastic.accumulate(&gradients).unwrap();
        acc_deterministic.accumulate(&gradients).unwrap();

        // Both should produce valid results
        let result_stochastic = acc_stochastic.get_accumulated().unwrap();
        let result_deterministic = acc_deterministic.get_accumulated().unwrap();

        assert_eq!(result_stochastic.len(), 3);
        assert_eq!(result_deterministic.len(), 3);
    }
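
    // A minimal check of the `Display` impl for `OptimizerStats`; the expected
    // string mirrors the `write!` format above. 0.5 is used for
    // `memory_savings` so the `{:.1}` formatting is exact.
    #[test]
    fn test_stats_display_format() {
        let stats = OptimizerStats {
            step_count: 8,
            update_count: 2,
            memory_savings: 0.5,
            accumulation_steps: 4,
        };
        assert_eq!(
            stats.to_string(),
            "Steps: 8 | Updates: 2 | Memory saved: 50.0%"
        );
    }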
}