use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

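/// Configuration for the MicroAdam optimizer.
///
/// Combines the standard Adam hyperparameters with knobs that control how
/// aggressively the optimizer state (momentum and variance) is compressed.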
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MicroAdamConfig {
    /// Learning rate.
    pub learning_rate: f32,
    /// Exponential decay rate for the first-moment (momentum) estimate.
    pub beta1: f32,
    /// Exponential decay rate for the second-moment (variance) estimate.
    pub beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f32,
    /// Decoupled (AdamW-style) weight-decay coefficient.
    pub weight_decay: f32,
    /// Target fraction of elements to keep when compressing state.
    pub compression_ratio: f32,
    /// Block size for block-wise compression; also acts as a floor on the
    /// number of retained elements.
    pub min_block_size: usize,
    /// Choose the compression scheme adaptively per gradient.
    pub adaptive_compression: bool,
    /// Magnitudes below this value are treated as negligible.
    pub compression_threshold: f32,
    /// Apply Adam bias correction to the moment estimates.
    pub bias_correction: bool,
    /// Maximum tolerated relative compression error before `update` fails.
    pub max_compression_error: f32,
}

impl Default for MicroAdamConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            compression_ratio: 0.1,
            min_block_size: 64,
            adaptive_compression: true,
            compression_threshold: 1e-6,
            bias_correction: true,
            max_compression_error: 1e-4,
        }
    }
}

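/// A sparse, scale-normalized representation of a gradient (or moment)
/// vector: only selected elements are stored, as parallel index/value arrays
/// with values normalized to [-1, 1] and a shared scale factor for
/// reconstruction.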
#[derive(Debug, Clone)]
struct CompressedGradient {
    /// Retained values, normalized by the maximum absolute value.
    compressed_data: Vec<f32>,
    /// Original positions of the retained values.
    indices: Vec<usize>,
    /// Multiplier that restores the original magnitudes on decompression.
    scale_factor: f32,
    /// Length of the uncompressed vector.
    original_size: usize,
    /// Scheme that produced this representation.
    compression_type: CompressionType,
}

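/// Compression schemes available to `CompressedGradient::compress`.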
#[derive(Debug, Clone, Copy)]
enum CompressionType {
    /// Keep the k largest-magnitude elements.
    TopK,
    /// Keep every element whose magnitude exceeds a threshold.
    Threshold,
    /// Keep the largest elements within each fixed-size block.
    BlockWise,
    /// Try several schemes and keep the best candidate.
    #[allow(dead_code)]
    Adaptive,
}

impl CompressedGradient {
    /// Compress `gradient` according to `config`, choosing the scheme
    /// adaptively when enabled.
    fn compress(gradient: &[f32], config: &MicroAdamConfig) -> Self {
        let original_size = gradient.len();
        let target_size = (original_size as f32 * config.compression_ratio) as usize;
        let target_size = target_size.max(config.min_block_size.min(original_size));

        let compression_type = if config.adaptive_compression {
            Self::choose_adaptive_compression(gradient, config)
        } else {
            CompressionType::TopK
        };

        match compression_type {
            CompressionType::TopK => Self::compress_topk(gradient, target_size),
            CompressionType::Threshold => Self::compress_threshold(gradient, config),
            CompressionType::BlockWise => Self::compress_blockwise(gradient, config),
            CompressionType::Adaptive => Self::compress_adaptive(gradient, config),
        }
    }

    /// Pick a scheme from simple gradient statistics: mostly-zero gradients
    /// favor thresholding, large dense gradients favor block-wise selection,
    /// and everything else falls back to top-k.
    fn choose_adaptive_compression(gradient: &[f32], config: &MicroAdamConfig) -> CompressionType {
        let mean_abs = gradient.iter().map(|x| x.abs()).sum::<f32>() / gradient.len() as f32;
        let sparsity = gradient.iter().filter(|&&x| x.abs() < config.compression_threshold).count()
            as f32
            / gradient.len() as f32;

        if sparsity > 0.8 {
            CompressionType::Threshold
        } else if mean_abs > 1e-3 {
            CompressionType::BlockWise
        } else {
            CompressionType::TopK
        }
    }

    /// Keep the `k` largest-magnitude elements.
    fn compress_topk(gradient: &[f32], k: usize) -> Self {
        let mut indexed_values: Vec<(usize, f32)> =
            gradient.iter().enumerate().map(|(i, &val)| (i, val.abs())).collect();

        // Sort by magnitude, descending; NaNs compare as equal.
        indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let k = k.min(indexed_values.len());
        let indices: Vec<usize> = indexed_values[..k].iter().map(|(i, _)| *i).collect();
        let compressed_data: Vec<f32> = indices.iter().map(|&i| gradient[i]).collect();

        // Normalize stored values to [-1, 1]; `scale_factor` undoes this on
        // decompression.
        let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };

        Self {
            compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
            indices,
            scale_factor: 1.0 / scale_factor,
            original_size: gradient.len(),
            compression_type: CompressionType::TopK,
        }
    }

    /// Keep every element whose magnitude is at least the configured
    /// threshold.
    fn compress_threshold(gradient: &[f32], config: &MicroAdamConfig) -> Self {
        let threshold = config.compression_threshold;
        let mut indices = Vec::new();
        let mut compressed_data = Vec::new();

        for (i, &val) in gradient.iter().enumerate() {
            if val.abs() >= threshold {
                indices.push(i);
                compressed_data.push(val);
            }
        }

        let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };

        Self {
            compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
            indices,
            scale_factor: 1.0 / scale_factor,
            original_size: gradient.len(),
            compression_type: CompressionType::Threshold,
        }
    }

    /// Split the gradient into fixed-size blocks and keep the
    /// largest-magnitude elements within each block, so retained values are
    /// spread across the whole vector.
    fn compress_blockwise(gradient: &[f32], config: &MicroAdamConfig) -> Self {
        let block_size = config.min_block_size;
        let num_blocks = gradient.len().div_ceil(block_size);
        let target_elements_per_block =
            ((block_size as f32 * config.compression_ratio) as usize).max(1);

        let mut indices = Vec::new();
        let mut compressed_data = Vec::new();

        for block_idx in 0..num_blocks {
            let start = block_idx * block_size;
            let end = (start + block_size).min(gradient.len());
            let block = &gradient[start..end];

            let mut block_indexed: Vec<(usize, f32)> =
                block.iter().enumerate().map(|(i, &val)| (start + i, val.abs())).collect();

            block_indexed
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            let k = target_elements_per_block.min(block_indexed.len());
            for i in 0..k {
                let global_idx = block_indexed[i].0;
                indices.push(global_idx);
                compressed_data.push(gradient[global_idx]);
            }
        }

        let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };

        Self {
            compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
            indices,
            scale_factor: 1.0 / scale_factor,
            original_size: gradient.len(),
            compression_type: CompressionType::BlockWise,
        }
    }

    /// Run all three schemes and return the candidate that meets the target
    /// compression ratio with the fewest retained elements, preferring top-k
    /// as the fallback.
    fn compress_adaptive(gradient: &[f32], config: &MicroAdamConfig) -> Self {
        let topk = Self::compress_topk(
            gradient,
            (gradient.len() as f32 * config.compression_ratio) as usize,
        );
        let threshold = Self::compress_threshold(gradient, config);
        let blockwise = Self::compress_blockwise(gradient, config);

        let topk_ratio = topk.compressed_data.len() as f32 / gradient.len() as f32;
        let threshold_ratio = threshold.compressed_data.len() as f32 / gradient.len() as f32;
        let blockwise_ratio = blockwise.compressed_data.len() as f32 / gradient.len() as f32;

        if threshold_ratio <= config.compression_ratio && threshold_ratio < topk_ratio {
            threshold
        } else if blockwise_ratio <= config.compression_ratio && blockwise_ratio < topk_ratio {
            blockwise
        } else {
            topk
        }
    }

    /// Reconstruct a dense vector, restoring magnitudes via `scale_factor`.
    /// Positions that were not retained decompress to zero.
    fn decompress(&self) -> Vec<f32> {
        let mut result = vec![0.0; self.original_size];
        for (i, &idx) in self.indices.iter().enumerate() {
            if idx < self.original_size && i < self.compressed_data.len() {
                result[idx] = self.compressed_data[i] * self.scale_factor;
            }
        }
        result
    }

    /// Fraction of elements retained (compressed size over original size).
    fn compression_ratio(&self) -> f32 {
        self.compressed_data.len() as f32 / self.original_size as f32
    }

    /// Relative L2 reconstruction error of the compressed representation.
    fn compression_error(&self, original: &[f32]) -> f32 {
        let decompressed = self.decompress();
        let mut error_sum = 0.0;
        let mut norm_sum = 0.0;

        for (orig, decomp) in original.iter().zip(decompressed.iter()) {
            error_sum += (orig - decomp).powi(2);
            norm_sum += orig.powi(2);
        }

        if norm_sum > 0.0 {
            (error_sum / norm_sum).sqrt()
        } else {
            0.0
        }
    }
}

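/// MicroAdam: an Adam variant that keeps its momentum and variance state in
/// compressed form to reduce optimizer memory.
///
/// A minimal usage sketch (the `use` paths are illustrative and depend on how
/// this crate is laid out and re-exported):
///
/// ```ignore
/// // Hypothetical imports; adjust to your crate layout.
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::traits::Optimizer;
///
/// let mut optimizer = MicroAdam::for_memory_constrained();
/// let mut parameter = Tensor::new(vec![1.0; 1024])?;
/// let gradient = Tensor::new(vec![0.01; 1024])?;
/// optimizer.update(&mut parameter, &gradient)?;
/// ```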
#[derive(Debug)]
pub struct MicroAdam {
    config: MicroAdamConfig,
    state: OptimizerState,
    /// Compressed first-moment estimates, keyed by parameter identity.
    momentum: HashMap<String, CompressedGradient>,
    /// Compressed second-moment estimates, keyed by parameter identity.
    variance: HashMap<String, CompressedGradient>,
    /// Running statistics about compression quality and memory use.
    compression_stats: CompressionStats,
}

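/// Aggregate statistics over all compressions performed so far.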
#[derive(Debug, Default)]
struct CompressionStats {
    total_parameters: usize,
    total_compressed_size: usize,
    average_compression_ratio: f32,
    average_compression_error: f32,
    compression_method_usage: HashMap<String, usize>,
}

impl MicroAdam {
    /// Create a MicroAdam optimizer with the default configuration.
    pub fn new() -> Self {
        Self::with_config(MicroAdamConfig::default())
    }

    /// Create a MicroAdam optimizer with a custom learning rate.
    pub fn new_with_lr(learning_rate: f32) -> Self {
        let config = MicroAdamConfig {
            learning_rate,
            ..Default::default()
        };
        Self::with_config(config)
    }

    /// Preset for large models: lower learning rate, more aggressive
    /// compression, and larger blocks.
    pub fn for_large_models() -> Self {
        let config = MicroAdamConfig {
            learning_rate: 1e-4,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            compression_ratio: 0.05,
            min_block_size: 128,
            adaptive_compression: true,
            compression_threshold: 1e-7,
            bias_correction: true,
            max_compression_error: 1e-5,
        };
        Self::with_config(config)
    }

    /// Preset for memory-constrained settings: very aggressive compression
    /// with small blocks.
    pub fn for_memory_constrained() -> Self {
        let config = MicroAdamConfig {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            compression_ratio: 0.02,
            min_block_size: 32,
            adaptive_compression: true,
            compression_threshold: 1e-6,
            bias_correction: true,
            max_compression_error: 1e-4,
        };
        Self::with_config(config)
    }

    /// Create a MicroAdam optimizer from an explicit configuration.
    pub fn with_config(config: MicroAdamConfig) -> Self {
        Self {
            config,
            state: OptimizerState::new(),
            momentum: HashMap::new(),
            variance: HashMap::new(),
            compression_stats: CompressionStats::default(),
        }
    }

    /// Fraction of optimizer-state memory saved relative to uncompressed
    /// Adam, which keeps two dense moment vectors per parameter.
    pub fn memory_savings_ratio(&self) -> f32 {
        if self.compression_stats.total_parameters > 0 {
            1.0 - (self.compression_stats.total_compressed_size as f32
                / (self.compression_stats.total_parameters * 2) as f32)
        } else {
            0.0
        }
    }

    /// Human-readable summary of the accumulated compression statistics.
    pub fn compression_statistics(&self) -> String {
        format!(
            "MicroAdam Compression Stats:\n\
             - Total parameters: {}\n\
             - Compressed size: {}\n\
             - Memory savings: {:.1}%\n\
             - Average compression ratio: {:.3}\n\
             - Average compression error: {:.2e}",
            self.compression_stats.total_parameters,
            self.compression_stats.total_compressed_size,
            self.memory_savings_ratio() * 100.0,
            self.compression_stats.average_compression_ratio,
            self.compression_stats.average_compression_error
        )
    }

    /// Fold one compression result into the running statistics, using
    /// parameter-count-weighted averages for ratio and error.
    fn update_compression_stats(
        &mut self,
        _param_id: &str,
        compressed: &CompressedGradient,
        original_gradient: &[f32],
    ) {
        self.compression_stats.total_parameters += compressed.original_size;
        self.compression_stats.total_compressed_size += compressed.compressed_data.len();

        let compression_ratio = compressed.compression_ratio();
        let compression_error = compressed.compression_error(original_gradient);

        // Weighted running averages over all parameters seen so far.
        let total_params = self.compression_stats.total_parameters as f32;
        self.compression_stats.average_compression_ratio =
            (self.compression_stats.average_compression_ratio
                * (total_params - compressed.original_size as f32)
                + compression_ratio * compressed.original_size as f32)
                / total_params;

        self.compression_stats.average_compression_error =
            (self.compression_stats.average_compression_error
                * (total_params - compressed.original_size as f32)
                + compression_error * compressed.original_size as f32)
                / total_params;

        let method_name = format!("{:?}", compressed.compression_type);
        *self.compression_stats.compression_method_usage.entry(method_name).or_insert(0) += 1;
    }
}

impl Default for MicroAdam {
    fn default() -> Self {
        Self::new()
    }
}

impl Optimizer for MicroAdam {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // Identify the parameter by its address; this assumes each parameter
        // tensor stays at a stable address across update calls.
        let param_id = format!("{:p}", parameter as *const Tensor);

        let grad_data = grad.data()?;

        // Compress the incoming gradient and reject the update if too much
        // information would be lost.
        let compressed_gradient = CompressedGradient::compress(&grad_data, &self.config);

        let compression_error = compressed_gradient.compression_error(&grad_data);
        if compression_error > self.config.max_compression_error {
            return Err(TrustformersError::tensor_op_error(
                &format!(
                    "Compression error {} exceeds maximum allowed {}",
                    compression_error, self.config.max_compression_error
                ),
                "MicroAdam::update",
            ));
        }

        self.update_compression_stats(&param_id, &compressed_gradient, &grad_data);

        // Lazily initialize the compressed state as all zeros on first use.
        let momentum = self.momentum.entry(param_id.clone()).or_insert_with(|| {
            CompressedGradient::compress(&vec![0.0; grad_data.len()], &self.config)
        });

        let variance = self.variance.entry(param_id.clone()).or_insert_with(|| {
            CompressedGradient::compress(&vec![0.0; grad_data.len()], &self.config)
        });

        // Work on dense copies of the moment estimates.
        let mut m = momentum.decompress();
        let mut v = variance.decompress();

        m.resize(grad_data.len(), 0.0);
        v.resize(grad_data.len(), 0.0);

        self.state.step();

        let bias_correction1 = if self.config.bias_correction {
            1.0 - self.config.beta1.powf(self.state.step as f32)
        } else {
            1.0
        };

        let bias_correction2 = if self.config.bias_correction {
            1.0 - self.config.beta2.powf(self.state.step as f32)
        } else {
            1.0
        };

        // First moment: exponential moving average of gradients.
        for i in 0..grad_data.len() {
            m[i] = self.config.beta1 * m[i] + (1.0 - self.config.beta1) * grad_data[i];
        }

        // Second moment: exponential moving average of squared gradients.
        for i in 0..grad_data.len() {
            v[i] = self.config.beta2 * v[i] + (1.0 - self.config.beta2) * grad_data[i].powi(2);
        }

        let mut param_data = parameter.data()?;
        for i in 0..grad_data.len() {
            let m_hat = m[i] / bias_correction1;
            let v_hat = v[i] / bias_correction2;
            let update_val =
                self.config.learning_rate * m_hat / (v_hat.sqrt() + self.config.epsilon);

            // Decoupled weight decay, applied directly to the parameter.
            if self.config.weight_decay > 0.0 {
                param_data[i] *= 1.0 - self.config.learning_rate * self.config.weight_decay;
            }

            param_data[i] -= update_val;
        }

        *parameter = Tensor::new(param_data)?;

        // Re-compress the updated moment estimates before storing them.
        *momentum = CompressedGradient::compress(&m, &self.config);
        *variance = CompressedGradient::compress(&v, &self.config);

        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients are managed by the caller; the optimizer holds none.
    }

    fn step(&mut self) {
        // The step counter is advanced inside `update`.
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

impl StatefulOptimizer for MicroAdam {
    type Config = MicroAdamConfig;
    type State = OptimizerState;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state_dict = HashMap::new();

        // Export the moment estimates densely so the state dict does not
        // depend on the compression format.
        for (param_id, momentum) in &self.momentum {
            let key = format!("momentum.{}", param_id);
            let tensor = Tensor::new(momentum.decompress())?;
            state_dict.insert(key, tensor);
        }

        for (param_id, variance) in &self.variance {
            let key = format!("variance.{}", param_id);
            let tensor = Tensor::new(variance.decompress())?;
            state_dict.insert(key, tensor);
        }

        state_dict.insert(
            "step".to_string(),
            Tensor::new(vec![self.state.step as f32])?,
        );

        Ok(state_dict)
    }

    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
        if let Some(step_tensor) = state_dict.get("step") {
            let step_data = step_tensor.data()?;
            if !step_data.is_empty() {
                self.state.step = step_data[0] as usize;
            }
        }

        // Re-compress the dense moment estimates on load.
        for (key, tensor) in &state_dict {
            if key.starts_with("momentum.") {
                let param_id = key.strip_prefix("momentum.").unwrap().to_string();
                let values = tensor.data()?;
                let compressed = CompressedGradient::compress(&values, &self.config);
                self.momentum.insert(param_id, compressed);
            } else if key.starts_with("variance.") {
                let param_id = key.strip_prefix("variance.").unwrap().to_string();
                let values = tensor.data()?;
                let compressed = CompressedGradient::compress(&values, &self.config);
                self.variance.insert(param_id, compressed);
            }
        }

        Ok(())
    }

    fn memory_usage(&self) -> StateMemoryStats {
        let momentum_size: usize = self.momentum.values().map(|m| m.compressed_data.len()).sum();
        let variance_size: usize = self.variance.values().map(|v| v.compressed_data.len()).sum();

        StateMemoryStats {
            momentum_elements: momentum_size,
            variance_elements: variance_size,
            third_moment_elements: 0,
            total_bytes: (momentum_size + variance_size) * std::mem::size_of::<f32>(),
            num_parameters: self.momentum.len(),
        }
    }

    fn reset_state(&mut self) {
        self.state.clear();
        self.momentum.clear();
        self.variance.clear();
        self.compression_stats = CompressionStats::default();
    }

    fn num_parameters(&self) -> usize {
        self.momentum.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_microadam_creation() {
        let optimizer = MicroAdam::new();
        assert_eq!(optimizer.config.learning_rate, 1e-3);
        assert_eq!(optimizer.config.beta1, 0.9);
        assert_eq!(optimizer.config.beta2, 0.999);
    }

    #[test]
    fn test_microadam_with_config() {
        let config = MicroAdamConfig {
            learning_rate: 2e-3,
            compression_ratio: 0.2,
            ..Default::default()
        };
        let optimizer = MicroAdam::with_config(config);
        assert_eq!(optimizer.config.learning_rate, 2e-3);
        assert_eq!(optimizer.config.compression_ratio, 0.2);
    }

    #[test]
    fn test_microadam_for_large_models() {
        let optimizer = MicroAdam::for_large_models();
        assert_eq!(optimizer.config.learning_rate, 1e-4);
        assert_eq!(optimizer.config.compression_ratio, 0.05);
        assert_eq!(optimizer.config.min_block_size, 128);
        assert!(optimizer.config.adaptive_compression);
    }

    #[test]
    fn test_microadam_for_memory_constrained() {
        let optimizer = MicroAdam::for_memory_constrained();
        assert_eq!(optimizer.config.compression_ratio, 0.02);
        assert_eq!(optimizer.config.min_block_size, 32);
        assert!(optimizer.config.adaptive_compression);
    }

    #[test]
    fn test_compressed_gradient_topk() {
        let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
        let compressed = CompressedGradient::compress_topk(&gradient, 3);

        assert_eq!(compressed.compressed_data.len(), 3);
        assert_eq!(compressed.indices.len(), 3);
        assert_eq!(compressed.original_size, 6);

        // The three largest magnitudes are 0.2, 0.15, and 0.1.
        let mut expected_indices = vec![2, 4, 0];
        let mut actual_indices = compressed.indices.clone();
        expected_indices.sort();
        actual_indices.sort();
        assert_eq!(actual_indices, expected_indices);
    }

    #[test]
    fn test_compressed_gradient_threshold() {
        let gradient = vec![0.1, 0.001, 0.2, 0.0001, 0.15, 0.0003];
        let config = MicroAdamConfig {
            compression_threshold: 0.05,
            ..Default::default()
        };
        let compressed = CompressedGradient::compress_threshold(&gradient, &config);

        // Only 0.1, 0.2, and 0.15 are at or above the 0.05 threshold.
        assert_eq!(compressed.compressed_data.len(), 3);
        assert_eq!(compressed.indices.len(), 3);

        let mut expected_indices = vec![0, 2, 4];
        let mut actual_indices = compressed.indices.clone();
        expected_indices.sort();
        actual_indices.sort();
        assert_eq!(actual_indices, expected_indices);
    }

    #[test]
    fn test_compression_decompress_cycle() {
        let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
        let config = MicroAdamConfig::default();
        let compressed = CompressedGradient::compress(&gradient, &config);
        let decompressed = compressed.decompress();

        assert_eq!(decompressed.len(), gradient.len());

        // Values carrying significant magnitude must survive compression.
        for (i, &original) in gradient.iter().enumerate() {
            if original.abs() > 0.08 {
                assert!(
                    decompressed[i].abs() > 0.0,
                    "Significant value at index {} was lost",
                    i
                );
            }
        }
    }

    #[test]
    fn test_compression_error_calculation() {
        let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
        let config = MicroAdamConfig::default();
        let compressed = CompressedGradient::compress(&gradient, &config);
        let error = compressed.compression_error(&gradient);

        // Relative error is bounded by 1 when elements are only dropped.
        assert!(error >= 0.0);
        assert!(error <= 1.0);
    }

    #[test]
    fn test_microadam_update() -> Result<()> {
        let mut optimizer = MicroAdam::new();
        let gradient_data = vec![0.1, -0.05, 0.2, -0.01];
        let gradient = Tensor::new(gradient_data.clone())?;
        let mut parameter = Tensor::new(vec![1.0, 1.0, 1.0, 1.0])?;

        optimizer.update(&mut parameter, &gradient)?;

        assert_eq!(optimizer.state().step, 1);

        let param_data = parameter.data()?;
        assert_eq!(param_data.len(), gradient_data.len());

        // The parameter must have moved away from its initial value.
        assert_ne!(param_data[0], 1.0);

        Ok(())
    }

    #[test]
    fn test_microadam_multiple_updates() -> Result<()> {
        let mut optimizer = MicroAdam::new();
        let gradient_data = vec![0.1, -0.05, 0.2, -0.01];
        let gradient = Tensor::new(gradient_data)?;
        let mut parameter = Tensor::new(vec![1.0, 1.0, 1.0, 1.0])?;

        for i in 1..=5 {
            optimizer.update(&mut parameter, &gradient)?;
            assert_eq!(optimizer.state().step, i);
        }

        Ok(())
    }

    #[test]
    fn test_memory_savings_ratio() {
        // Relax the error bound so aggressively compressing a uniform
        // gradient does not fail the update.
        let config = MicroAdamConfig {
            max_compression_error: 1.0,
            ..Default::default()
        };
        let mut optimizer = MicroAdam::with_config(config);

        assert_eq!(optimizer.memory_savings_ratio(), 0.0);

        let gradient_data = vec![0.1; 1000];
        let gradient = Tensor::new(gradient_data).unwrap();
        let mut parameter = Tensor::new(vec![1.0; 1000]).unwrap();
        optimizer.update(&mut parameter, &gradient).unwrap();

        let savings = optimizer.memory_savings_ratio();
        assert!(savings > 0.0, "Should show memory savings");
        assert!(savings < 1.0, "Savings ratio should be less than 100%");
    }

    #[test]
    fn test_compression_statistics() {
        let config = MicroAdamConfig {
            max_compression_error: 1.0,
            ..Default::default()
        };
        let mut optimizer = MicroAdam::with_config(config);
        let gradient_data = vec![0.1; 500];
        let gradient = Tensor::new(gradient_data).unwrap();
        let mut parameter = Tensor::new(vec![1.0; 500]).unwrap();

        optimizer.update(&mut parameter, &gradient).unwrap();

        let stats = optimizer.compression_statistics();
        assert!(stats.contains("MicroAdam Compression Stats"));
        assert!(stats.contains("Total parameters: 500"));
        assert!(stats.contains("Memory savings"));
        assert!(stats.contains("compression ratio"));
    }

    #[test]
    fn test_learning_rate_setter_getter() {
        let mut optimizer = MicroAdam::new();
        assert_eq!(optimizer.get_lr(), 1e-3);

        optimizer.set_lr(2e-3);
        assert_eq!(optimizer.get_lr(), 2e-3);
    }

    #[test]
    fn test_state_dict_operations() -> Result<()> {
        let mut optimizer = MicroAdam::new();
        let gradient_data = vec![0.1, -0.05, 0.2];
        let gradient = Tensor::new(gradient_data)?;
        let mut param1 = Tensor::new(vec![1.0, 1.0, 1.0])?;
        let mut param2 = Tensor::new(vec![2.0, 2.0, 2.0])?;

        optimizer.update(&mut param1, &gradient)?;
        optimizer.update(&mut param2, &gradient)?;

        let state_dict = optimizer.state_dict()?;
        assert!(state_dict.contains_key("step"));

        let mut new_optimizer = MicroAdam::new();
        new_optimizer.load_state_dict(state_dict)?;

        assert_eq!(new_optimizer.state().step, optimizer.state().step);

        Ok(())
    }

    #[test]
    fn test_memory_usage_tracking() -> Result<()> {
        let config = MicroAdamConfig {
            max_compression_error: 1.0,
            ..Default::default()
        };
        let mut optimizer = MicroAdam::with_config(config);
        let initial_usage = optimizer.memory_usage();

        let gradient_data = vec![0.1; 1000];
        let gradient = Tensor::new(gradient_data)?;
        let mut parameter = Tensor::new(vec![1.0; 1000])?;
        optimizer.update(&mut parameter, &gradient)?;

        let after_usage = optimizer.memory_usage();
        assert!(after_usage.total_bytes > initial_usage.total_bytes);
        assert!(after_usage.momentum_elements > 0);
        assert!(after_usage.variance_elements > 0);

        Ok(())
    }

    #[test]
    fn test_adaptive_compression_selection() {
        let sparse_gradient = vec![0.0; 1000];
        let dense_gradient = vec![0.1; 1000];
        let config = MicroAdamConfig {
            adaptive_compression: true,
            compression_threshold: 1e-6,
            ..Default::default()
        };

        let sparse_compression =
            CompressedGradient::choose_adaptive_compression(&sparse_gradient, &config);
        let dense_compression =
            CompressedGradient::choose_adaptive_compression(&dense_gradient, &config);

        // An all-zero gradient is maximally sparse, so thresholding is
        // selected; a dense gradient with large mean magnitude selects the
        // block-wise scheme.
        assert!(matches!(sparse_compression, CompressionType::Threshold));
        assert!(matches!(dense_compression, CompressionType::BlockWise));
    }
}