1use std::collections::HashMap;
8
9use candle_core::{DType, Device, Tensor};
10
11use crate::config::TernaryConfig;
12use crate::error::Result;
13use crate::ternary::{
14 calculate_memory_savings, ternary_quantize_deterministic, ternary_quantize_stochastic,
15};
16
/// Per-parameter running state for ternary gradient accumulation.
#[derive(Debug, Clone)]
struct AccumulatedGradient {
    // Element-wise sum of the ternary-quantized gradients seen since the
    // last reset. Stored as F32; each summand presumably has entries in
    // {-1, 0, +1} (produced by the ternary_quantize_* helpers) — confirm
    // against crate::ternary.
    ternary: Tensor,
    // Sum of the per-step quantization scales; averaged at read-out time
    // in get_accumulated().
    scale_sum: f32,
    // Parameter shape, kept only for the memory_savings() bookkeeping.
    shape: Vec<usize>,
}
27
/// Accumulates gradients in ternary-quantized form across multiple steps,
/// trading a small amount of precision for reduced accumulator memory.
pub struct TernaryGradientAccumulator {
    // Quantization and accumulation settings (threshold, rounding mode,
    // accumulation_steps).
    config: TernaryConfig,
    // Device the accumulator tensors were allocated on.
    device: Device,
    // Per-parameter accumulation state, keyed by parameter name.
    accumulators: HashMap<String, AccumulatedGradient>,
    // Number of accumulate() calls since the last reset().
    count: usize,
}
61
62impl TernaryGradientAccumulator {
63 pub fn new(
75 param_shapes: &[(String, Vec<usize>)],
76 config: TernaryConfig,
77 device: &Device,
78 ) -> Result<Self> {
79 let mut accumulators = HashMap::new();
80
81 for (name, shape) in param_shapes {
82 let ternary = Tensor::zeros(shape.as_slice(), DType::F32, device)?;
83 accumulators.insert(
84 name.clone(),
85 AccumulatedGradient {
86 ternary,
87 scale_sum: 0.0,
88 shape: shape.clone(),
89 },
90 );
91 }
92
93 Ok(Self {
94 config,
95 device: device.clone(),
96 accumulators,
97 count: 0,
98 })
99 }
100
101 pub fn accumulate(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
114 let threshold = Some(self.config.ternary_threshold);
115
116 for (name, grad) in gradients {
117 if let Some(accum) = self.accumulators.get_mut(name) {
118 let (ternary, scale) = if self.config.use_stochastic_rounding {
120 ternary_quantize_stochastic(grad, threshold)?
121 } else {
122 ternary_quantize_deterministic(grad, threshold)?
123 };
124
125 accum.ternary = accum.ternary.add(&ternary)?;
128 accum.scale_sum += scale;
129 }
130 }
131
132 self.count += 1;
133 Ok(())
134 }
135
136 #[allow(clippy::cast_precision_loss)]
149 pub fn get_accumulated(&self) -> Result<HashMap<String, Tensor>> {
150 let mut accumulated = HashMap::new();
151
152 for (name, accum) in &self.accumulators {
153 if self.count > 0 {
154 let avg_scale = accum.scale_sum / self.count as f32;
156 let result = (&accum.ternary * avg_scale as f64)?;
158 let result = (result / self.count as f64)?;
159 accumulated.insert(name.clone(), result);
160 } else {
161 accumulated.insert(name.clone(), accum.ternary.clone());
162 }
163 }
164
165 Ok(accumulated)
166 }
167
168 pub fn reset(&mut self) -> Result<()> {
174 for accum in self.accumulators.values_mut() {
175 accum.ternary = accum.ternary.zeros_like()?;
176 accum.scale_sum = 0.0;
177 }
178 self.count = 0;
179 Ok(())
180 }
181
182 #[must_use]
184 pub const fn count(&self) -> usize {
185 self.count
186 }
187
188 #[must_use]
194 pub fn memory_savings(&self) -> f32 {
195 let param_count: usize = self.accumulators.values().map(|a| a.shape.iter().product::<usize>()).sum();
196 let num_tensors = self.accumulators.len();
197 calculate_memory_savings(param_count, num_tensors)
198 }
199
200 #[must_use]
202 pub fn ready_for_update(&self) -> bool {
203 self.count >= self.config.accumulation_steps
204 }
205}
206
/// Convenience wrapper that pairs a [`TernaryGradientAccumulator`] with
/// step/update bookkeeping for a gradient-accumulation training loop.
pub struct TernaryOptimizerWrapper {
    // Shared settings; accumulation_steps decides the update cadence.
    config: TernaryConfig,
    // Underlying ternary gradient accumulator.
    accumulator: TernaryGradientAccumulator,
    // Total step() calls over the wrapper's lifetime (not reset per update).
    step_count: usize,
    // Number of get_gradients_for_update() calls completed.
    update_count: usize,
}
236
237impl TernaryOptimizerWrapper {
238 pub fn new(
250 param_shapes: &[(String, Vec<usize>)],
251 config: TernaryConfig,
252 device: &Device,
253 ) -> Result<Self> {
254 let accumulator = TernaryGradientAccumulator::new(param_shapes, config.clone(), device)?;
255
256 Ok(Self {
257 config,
258 accumulator,
259 step_count: 0,
260 update_count: 0,
261 })
262 }
263
264 pub fn step(&mut self, gradients: &HashMap<String, Tensor>) -> Result<bool> {
278 self.accumulator.accumulate(gradients)?;
280 self.step_count += 1;
281
282 Ok(self.step_count % self.config.accumulation_steps == 0)
284 }
285
286 pub fn get_gradients_for_update(&mut self) -> Result<HashMap<String, Tensor>> {
298 let grads = self.accumulator.get_accumulated()?;
299 self.accumulator.reset()?;
300 self.update_count += 1;
301 Ok(grads)
302 }
303
304 #[must_use]
306 pub fn get_stats(&self) -> OptimizerStats {
307 OptimizerStats {
308 step_count: self.step_count,
309 update_count: self.update_count,
310 memory_savings: self.accumulator.memory_savings(),
311 accumulation_steps: self.config.accumulation_steps,
312 }
313 }
314
315 #[must_use]
317 pub const fn step_count(&self) -> usize {
318 self.step_count
319 }
320
321 #[must_use]
323 pub const fn update_count(&self) -> usize {
324 self.update_count
325 }
326
327 pub fn reset_state(&mut self) {
329 self.step_count = 0;
330 self.update_count = 0;
331 }
332
333 pub fn load_state(&mut self, step_count: usize, update_count: usize) {
335 self.step_count = step_count;
336 self.update_count = update_count;
337 }
338}
339
/// Point-in-time statistics reported by [`TernaryOptimizerWrapper::get_stats`].
#[derive(Debug, Clone)]
pub struct OptimizerStats {
    /// Lifetime number of step() calls.
    pub step_count: usize,
    /// Number of optimizer updates performed.
    pub update_count: usize,
    /// Estimated fraction of accumulator memory saved (0.0..=1.0).
    pub memory_savings: f32,
    /// Steps accumulated between updates, from the config.
    pub accumulation_steps: usize,
}
352
353impl std::fmt::Display for OptimizerStats {
354 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 write!(
356 f,
357 "Steps: {} | Updates: {} | Memory saved: {:.1}%",
358 self.step_count,
359 self.update_count,
360 self.memory_savings * 100.0
361 )
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    // Three small parameters mimicking a two-layer network.
    fn create_param_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer1.weight".to_string(), vec![64, 128]),
            ("layer1.bias".to_string(), vec![64]),
            ("layer2.weight".to_string(), vec![32, 64]),
        ]
    }

    // Standard-normal random gradients matching create_param_shapes().
    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1.weight".to_string(),
            Tensor::randn(0.0f32, 1.0, (64, 128), device).unwrap(),
        );
        gradients.insert(
            "layer1.bias".to_string(),
            Tensor::randn(0.0f32, 1.0, 64, device).unwrap(),
        );
        gradients.insert(
            "layer2.weight".to_string(),
            Tensor::randn(0.0f32, 1.0, (32, 64), device).unwrap(),
        );
        gradients
    }

    // A fresh accumulator starts with a zero step count.
    #[test]
    fn test_accumulator_creation() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        assert_eq!(accumulator.count(), 0);
    }

    // Each accumulate() call bumps the step count by one.
    #[test]
    fn test_accumulator_accumulate() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let mut accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        accumulator.accumulate(&gradients).unwrap();
        assert_eq!(accumulator.count(), 1);

        accumulator.accumulate(&gradients).unwrap();
        assert_eq!(accumulator.count(), 2);
    }

    // get_accumulated() returns one tensor per registered parameter.
    #[test]
    fn test_accumulator_get_accumulated() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let mut accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        accumulator.accumulate(&gradients).unwrap();
        let accumulated = accumulator.get_accumulated().unwrap();

        assert_eq!(accumulated.len(), 3);
        for (name, _shape) in &shapes {
            assert!(accumulated.contains_key(name));
        }
    }

    // reset() zeroes the step count (tensor contents checked elsewhere).
    #[test]
    fn test_accumulator_reset() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let mut accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        accumulator.accumulate(&gradients).unwrap();
        assert_eq!(accumulator.count(), 1);

        accumulator.reset().unwrap();
        assert_eq!(accumulator.count(), 0);
    }

    // Ternary storage should save well over 90% versus full precision.
    #[test]
    fn test_accumulator_memory_savings() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default();

        let accumulator = TernaryGradientAccumulator::new(&shapes, config, &device).unwrap();
        let savings = accumulator.memory_savings();

        assert!(savings > 0.9, "Expected >90% savings, got {:.2}%", savings * 100.0);
    }

    // step() signals an update only on every accumulation_steps-th call,
    // and the cycle restarts after gradients are drained.
    #[test]
    fn test_optimizer_wrapper_step() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default().with_accumulation_steps(4);

        let mut wrapper = TernaryOptimizerWrapper::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        for _ in 0..3 {
            let should_update = wrapper.step(&gradients).unwrap();
            assert!(!should_update);
        }

        let should_update = wrapper.step(&gradients).unwrap();
        assert!(should_update);

        let accumulated = wrapper.get_gradients_for_update().unwrap();
        assert_eq!(accumulated.len(), 3);

        let should_update = wrapper.step(&gradients).unwrap();
        assert!(!should_update);
    }

    // get_stats() reflects lifetime steps and completed updates.
    #[test]
    fn test_optimizer_wrapper_stats() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = TernaryConfig::default().with_accumulation_steps(2);

        let mut wrapper = TernaryOptimizerWrapper::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        wrapper.step(&gradients).unwrap();
        wrapper.step(&gradients).unwrap();
        let _ = wrapper.get_gradients_for_update().unwrap();

        let stats = wrapper.get_stats();
        assert_eq!(stats.step_count, 2);
        assert_eq!(stats.update_count, 1);
        assert!(stats.memory_savings > 0.9);
    }

    // Smoke test: both rounding modes run end-to-end and produce a full
    // set of gradients (values intentionally not compared — stochastic
    // rounding is random).
    #[test]
    fn test_stochastic_vs_deterministic() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;

        let config_stochastic = TernaryConfig::default().with_stochastic_rounding(true);
        let mut acc_stochastic = TernaryGradientAccumulator::new(&shapes, config_stochastic, &device).unwrap();

        let config_deterministic = TernaryConfig::default().with_stochastic_rounding(false);
        let mut acc_deterministic = TernaryGradientAccumulator::new(&shapes, config_deterministic, &device).unwrap();

        let gradients = create_mock_gradients(&device);

        acc_stochastic.accumulate(&gradients).unwrap();
        acc_deterministic.accumulate(&gradients).unwrap();

        let result_stochastic = acc_stochastic.get_accumulated().unwrap();
        let result_deterministic = acc_deterministic.get_accumulated().unwrap();

        assert_eq!(result_stochastic.len(), 3);
        assert_eq!(result_deterministic.len(), 3);
    }
}