
// trustformers_optim/microadam.rs

//! # MicroAdam Optimizer
//!
//! Implementation of MicroAdam (Modoranu et al., NeurIPS 2024: "MicroAdam: Accurate
//! Adaptive Optimization with Low Space Overhead and Provable Convergence").
//! This optimizer provides Adam-like convergence guarantees while significantly reducing
//! memory overhead through compressed gradient storage and efficient state management.
//!
//! ## Key Features
//!
//! - **Low Memory Overhead**: Compressed storage with provable convergence guarantees
//! - **Adaptive Compression**: Dynamic compression based on gradient characteristics
//! - **Theoretical Guarantees**: Maintains Adam's convergence properties with reduced space
//! - **Efficient State Updates**: Optimized state transitions with minimal memory allocation
//!
//! ## Research Background
//!
//! MicroAdam addresses the memory bottleneck in large-scale optimization by introducing
//! efficient compression techniques that preserve the essential information needed for
//! convergence while dramatically reducing storage requirements.
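//!
//! ## Example
//!
//! A minimal usage sketch (import paths are illustrative; the `Tensor::new(Vec<f32>)`
//! constructor and `Optimizer::update` call mirror this module's tests):
//!
//! ```ignore
//! use trustformers_core::tensor::Tensor;
//! use trustformers_core::traits::Optimizer;
//!
//! let mut optimizer = MicroAdam::for_large_models();
//! let mut parameter = Tensor::new(vec![1.0; 1024])?;
//! let gradient = Tensor::new(vec![0.01; 1024])?;
//! optimizer.update(&mut parameter, &gradient)?;
//! println!("{}", optimizer.compression_statistics());
//! ```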

use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

/// Configuration for MicroAdam optimizer
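///
/// All fields have defaults (see `Default`); override only what you need with
/// struct update syntax, as this module's tests do:
///
/// ```ignore
/// let config = MicroAdamConfig {
///     learning_rate: 5e-4,
///     compression_ratio: 0.05, // keep ~5% of gradient entries
///     ..Default::default()
/// };
/// let optimizer = MicroAdam::with_config(config);
/// ```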
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MicroAdamConfig {
    /// Learning rate (default: 1e-3)
    pub learning_rate: f32,
    /// Coefficient for computing first moment (default: 0.9)
    pub beta1: f32,
    /// Coefficient for computing second moment (default: 0.999)
    pub beta2: f32,
    /// Small constant for numerical stability (default: 1e-8)
    pub epsilon: f32,
    /// Weight decay coefficient (default: 0.01)
    pub weight_decay: f32,
    /// Compression ratio for gradient storage (default: 0.1 = 90% compression)
    pub compression_ratio: f32,
    /// Minimum compression block size (default: 64)
    pub min_block_size: usize,
    /// Enable adaptive compression based on gradient sparsity (default: true)
    pub adaptive_compression: bool,
    /// Threshold for gradient compression (default: 1e-6)
    pub compression_threshold: f32,
    /// Use bias correction (default: true)
    pub bias_correction: bool,
    /// Maximum compression error tolerance (default: 1e-4)
    pub max_compression_error: f32,
}

impl Default for MicroAdamConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            compression_ratio: 0.1,
            min_block_size: 64,
            adaptive_compression: true,
            compression_threshold: 1e-6,
            bias_correction: true,
            max_compression_error: 1e-4,
        }
    }
}

/// Compressed gradient storage for memory efficiency
#[derive(Debug, Clone)]
struct CompressedGradient {
    /// Compressed gradient values
    compressed_data: Vec<f32>,
    /// Indices of significant gradient components
    indices: Vec<usize>,
    /// Scale factor for reconstruction
    scale_factor: f32,
    /// Original gradient size
    original_size: usize,
    /// Compression method used
    compression_type: CompressionType,
}

/// Available compression methods
#[derive(Debug, Clone, Copy)]
enum CompressionType {
    /// Top-K sparsification with adaptive threshold
    TopK,
    /// Magnitude-based thresholding
    Threshold,
    /// Block-wise compression
    BlockWise,
    /// Adaptive hybrid compression
    #[allow(dead_code)]
    Adaptive,
}
impl CompressedGradient {
    /// Compress gradient using specified method and ratio
    fn compress(gradient: &[f32], config: &MicroAdamConfig) -> Self {
        let original_size = gradient.len();
        let target_size = (original_size as f32 * config.compression_ratio) as usize;
        let target_size = target_size.max(config.min_block_size.min(original_size));

        let compression_type = if config.adaptive_compression {
            // Choose compression method based on gradient characteristics
            Self::choose_adaptive_compression(gradient, config)
        } else {
            CompressionType::TopK
        };

        match compression_type {
            CompressionType::TopK => Self::compress_topk(gradient, target_size),
            CompressionType::Threshold => Self::compress_threshold(gradient, config),
            CompressionType::BlockWise => Self::compress_blockwise(gradient, config),
            CompressionType::Adaptive => Self::compress_adaptive(gradient, config),
        }
    }

    /// Choose optimal compression method based on gradient characteristics
    fn choose_adaptive_compression(gradient: &[f32], config: &MicroAdamConfig) -> CompressionType {
        let mean_abs = gradient.iter().map(|x| x.abs()).sum::<f32>() / gradient.len() as f32;
        let sparsity = gradient.iter().filter(|&&x| x.abs() < config.compression_threshold).count()
            as f32
            / gradient.len() as f32;

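        // Illustrative routing: with the default threshold of 1e-6, a gradient
        // where 90% of entries are (near-)zero has sparsity 0.9 > 0.8 and is
        // routed to Threshold; a dense gradient with mean |g| = 0.1 > 1e-3 is
        // routed to BlockWise; everything else falls back to TopK.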
        if sparsity > 0.8 {
            CompressionType::Threshold
        } else if mean_abs > 1e-3 {
            CompressionType::BlockWise
        } else {
            CompressionType::TopK
        }
    }

    /// Top-K compression with magnitude-based selection
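    ///
    /// For example, `compress_topk(&[0.1, 0.05, 0.2, 0.01], 2)` keeps indices 2
    /// and 0 (the two largest magnitudes), stores their values normalized by the
    /// largest kept magnitude, and records that magnitude as `scale_factor`.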
    fn compress_topk(gradient: &[f32], k: usize) -> Self {
        let mut indexed_values: Vec<(usize, f32)> =
            gradient.iter().enumerate().map(|(i, &val)| (i, val.abs())).collect();

        // Sort by magnitude (descending)
        indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let k = k.min(indexed_values.len());
        let indices: Vec<usize> = indexed_values[..k].iter().map(|(i, _)| *i).collect();
        let compressed_data: Vec<f32> = indices.iter().map(|&i| gradient[i]).collect();

        // Calculate scale factor for better reconstruction
        let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };

        Self {
            compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
            indices,
            scale_factor: 1.0 / scale_factor,
            original_size: gradient.len(),
            compression_type: CompressionType::TopK,
        }
    }

    /// Threshold-based compression
    fn compress_threshold(gradient: &[f32], config: &MicroAdamConfig) -> Self {
        let threshold = config.compression_threshold;
        let mut indices = Vec::new();
        let mut compressed_data = Vec::new();

        for (i, &val) in gradient.iter().enumerate() {
            if val.abs() >= threshold {
                indices.push(i);
                compressed_data.push(val);
            }
        }

        let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };

        Self {
            compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
            indices,
            scale_factor: 1.0 / scale_factor,
            original_size: gradient.len(),
            compression_type: CompressionType::Threshold,
        }
    }

    /// Block-wise compression with local optimization
    fn compress_blockwise(gradient: &[f32], config: &MicroAdamConfig) -> Self {
        let block_size = config.min_block_size;
        let num_blocks = gradient.len().div_ceil(block_size);
        let target_elements_per_block =
            ((block_size as f32 * config.compression_ratio) as usize).max(1);

        let mut indices = Vec::new();
        let mut compressed_data = Vec::new();

        for block_idx in 0..num_blocks {
            let start = block_idx * block_size;
            let end = (start + block_size).min(gradient.len());
            let block = &gradient[start..end];

            // Find top elements in this block
            let mut block_indexed: Vec<(usize, f32)> =
                block.iter().enumerate().map(|(i, &val)| (start + i, val.abs())).collect();

            block_indexed
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            let k = target_elements_per_block.min(block_indexed.len());
            for i in 0..k {
                let global_idx = block_indexed[i].0;
                indices.push(global_idx);
                compressed_data.push(gradient[global_idx]);
            }
        }

        let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };

        Self {
            compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
            indices,
            scale_factor: 1.0 / scale_factor,
            original_size: gradient.len(),
            compression_type: CompressionType::BlockWise,
        }
    }

    /// Adaptive compression combining multiple methods
    fn compress_adaptive(gradient: &[f32], config: &MicroAdamConfig) -> Self {
        // Try multiple compression methods and choose the best one
        let topk = Self::compress_topk(
            gradient,
            (gradient.len() as f32 * config.compression_ratio) as usize,
        );
        let threshold = Self::compress_threshold(gradient, config);
        let blockwise = Self::compress_blockwise(gradient, config);

        // Choose based on compression efficiency and error
        let topk_ratio = topk.compressed_data.len() as f32 / gradient.len() as f32;
        let threshold_ratio = threshold.compressed_data.len() as f32 / gradient.len() as f32;
        let blockwise_ratio = blockwise.compressed_data.len() as f32 / gradient.len() as f32;

        if threshold_ratio <= config.compression_ratio && threshold_ratio < topk_ratio {
            threshold
        } else if blockwise_ratio <= config.compression_ratio && blockwise_ratio < topk_ratio {
            blockwise
        } else {
            topk
        }
    }

    /// Decompress gradient back to original size
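    ///
    /// Entries that were not retained by compression are reconstructed as 0.0.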
    fn decompress(&self) -> Vec<f32> {
        let mut result = vec![0.0; self.original_size];
        for (i, &idx) in self.indices.iter().enumerate() {
            if idx < self.original_size && i < self.compressed_data.len() {
                result[idx] = self.compressed_data[i] * self.scale_factor;
            }
        }
        result
    }

    /// Calculate compression ratio achieved
    fn compression_ratio(&self) -> f32 {
        self.compressed_data.len() as f32 / self.original_size as f32
    }

    /// Estimate compression error
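    ///
    /// Returns the relative L2 error `||g - decompress(g)|| / ||g||`: 0.0 for a
    /// lossless round trip, approaching 1.0 as most of the gradient mass is dropped.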
    fn compression_error(&self, original: &[f32]) -> f32 {
        let decompressed = self.decompress();
        let mut error_sum = 0.0;
        let mut norm_sum = 0.0;

        for (orig, decomp) in original.iter().zip(decompressed.iter()) {
            error_sum += (orig - decomp).powi(2);
            norm_sum += orig.powi(2);
        }

        if norm_sum > 0.0 {
            (error_sum / norm_sum).sqrt()
        } else {
            0.0
        }
    }
}

/// MicroAdam optimizer implementation
///
/// Provides memory-efficient Adam optimization through compressed gradient storage
/// while maintaining convergence guarantees through careful state management.
#[derive(Debug)]
pub struct MicroAdam {
    config: MicroAdamConfig,
    state: OptimizerState,
    /// First moment estimates (compressed)
    momentum: HashMap<String, CompressedGradient>,
    /// Second moment estimates (compressed)
    variance: HashMap<String, CompressedGradient>,
    /// Compression statistics for monitoring
    compression_stats: CompressionStats,
}

/// Statistics for monitoring compression performance
#[derive(Debug, Default)]
struct CompressionStats {
    total_parameters: usize,
    total_compressed_size: usize,
    average_compression_ratio: f32,
    average_compression_error: f32,
    compression_method_usage: HashMap<String, usize>,
}
impl MicroAdam {
    /// Create a new MicroAdam optimizer with default configuration
    pub fn new() -> Self {
        Self::with_config(MicroAdamConfig::default())
    }

    /// Create MicroAdam with custom learning rate
    pub fn new_with_lr(learning_rate: f32) -> Self {
        let config = MicroAdamConfig {
            learning_rate,
            ..Default::default()
        };
        Self::with_config(config)
    }

    /// Create MicroAdam for large language models with optimized compression
    pub fn for_large_models() -> Self {
        let config = MicroAdamConfig {
            learning_rate: 1e-4,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            compression_ratio: 0.05, // Higher compression for large models
            min_block_size: 128,
            adaptive_compression: true,
            compression_threshold: 1e-7,
            bias_correction: true,
            max_compression_error: 1e-5,
        };
        Self::with_config(config)
    }

    /// Create MicroAdam for memory-constrained environments
    pub fn for_memory_constrained() -> Self {
        let config = MicroAdamConfig {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            compression_ratio: 0.02, // Aggressive compression
            min_block_size: 32,
            adaptive_compression: true,
            compression_threshold: 1e-6,
            bias_correction: true,
            max_compression_error: 1e-4,
        };
        Self::with_config(config)
    }

    /// Create MicroAdam with custom configuration
    pub fn with_config(config: MicroAdamConfig) -> Self {
        Self {
            config,
            state: OptimizerState::new(),
            momentum: HashMap::new(),
            variance: HashMap::new(),
            compression_stats: CompressionStats::default(),
        }
    }

    /// Get memory savings compared to standard Adam
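    ///
    /// Standard Adam keeps two dense f32 moment vectors, so the baseline is
    /// `2 * total_parameters` elements (as accumulated by
    /// `update_compression_stats`). For example, 1000 tracked parameter elements
    /// with a 100-element compressed record give `1.0 - 100.0 / 2000.0 = 0.95`.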
    pub fn memory_savings_ratio(&self) -> f32 {
        if self.compression_stats.total_parameters > 0 {
            1.0 - (self.compression_stats.total_compressed_size as f32
                / (self.compression_stats.total_parameters * 2) as f32)
        } else {
            0.0
        }
    }

    /// Get compression statistics
    pub fn compression_statistics(&self) -> String {
        format!(
            "MicroAdam Compression Stats:\n\
             - Total parameters: {}\n\
             - Compressed size: {}\n\
             - Memory savings: {:.1}%\n\
             - Average compression ratio: {:.3}\n\
             - Average compression error: {:.2e}",
            self.compression_stats.total_parameters,
            self.compression_stats.total_compressed_size,
            self.memory_savings_ratio() * 100.0,
            self.compression_stats.average_compression_ratio,
            self.compression_stats.average_compression_error
        )
    }

    /// Update compression statistics
    fn update_compression_stats(
        &mut self,
        _param_id: &str,
        compressed: &CompressedGradient,
        original_gradient: &[f32],
    ) {
        self.compression_stats.total_parameters += compressed.original_size;
        self.compression_stats.total_compressed_size += compressed.compressed_data.len();

        let compression_ratio = compressed.compression_ratio();
        let compression_error = compressed.compression_error(original_gradient);

        // Update averages
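        // Incremental weighted mean: with N = new total element count and
        // n = this gradient's size, new_avg = (old_avg * (N - n) + x * n) / N.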
        let total_params = self.compression_stats.total_parameters as f32;
        self.compression_stats.average_compression_ratio =
            (self.compression_stats.average_compression_ratio
                * (total_params - compressed.original_size as f32)
                + compression_ratio * compressed.original_size as f32)
                / total_params;

        self.compression_stats.average_compression_error =
            (self.compression_stats.average_compression_error
                * (total_params - compressed.original_size as f32)
                + compression_error * compressed.original_size as f32)
                / total_params;

        // Track compression method usage
        let method_name = format!("{:?}", compressed.compression_type);
        *self.compression_stats.compression_method_usage.entry(method_name).or_insert(0) += 1;
    }
}

impl Default for MicroAdam {
    fn default() -> Self {
        Self::new()
    }
}

impl Optimizer for MicroAdam {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // Generate a parameter ID from the tensor's memory address. Note: this
        // is only stable while the parameter tensor stays at the same address;
        // a moved or reallocated tensor will be tracked as a new parameter.
        let param_id = format!("{:p}", parameter as *const Tensor);

        // Extract gradient data
        let grad_data = grad.data()?;

        // Compress gradient for storage efficiency
        let compressed_gradient = CompressedGradient::compress(&grad_data, &self.config);

        // Check compression error
        let compression_error = compressed_gradient.compression_error(&grad_data);
        if compression_error > self.config.max_compression_error {
            return Err(TrustformersError::tensor_op_error(
                &format!(
                    "Compression error {} exceeds maximum allowed {}",
                    compression_error, self.config.max_compression_error
                ),
                "MicroAdam::update",
            ));
        }

        // Update compression statistics
        self.update_compression_stats(&param_id, &compressed_gradient, &grad_data);

        // Get or initialize compressed momentum
        let momentum = self.momentum.entry(param_id.clone()).or_insert_with(|| {
            CompressedGradient::compress(&vec![0.0; grad_data.len()], &self.config)
        });

        // Get or initialize compressed variance
        let variance = self.variance.entry(param_id.clone()).or_insert_with(|| {
            CompressedGradient::compress(&vec![0.0; grad_data.len()], &self.config)
        });

        // Decompress for computation
        let mut m = momentum.decompress();
        let mut v = variance.decompress();

        // Ensure sizes match
        m.resize(grad_data.len(), 0.0);
        v.resize(grad_data.len(), 0.0);

        // Update step count
        self.state.step();

        // Compute bias correction factors
        let bias_correction1 = if self.config.bias_correction {
            1.0 - self.config.beta1.powf(self.state.step as f32)
        } else {
            1.0
        };

        let bias_correction2 = if self.config.bias_correction {
            1.0 - self.config.beta2.powf(self.state.step as f32)
        } else {
            1.0
        };

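        // Standard Adam recurrences (Kingma & Ba, 2015), applied to the
        // decompressed moments:
        //   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        //   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        //   theta_t = theta_{t-1} - lr * m_hat / (sqrt(v_hat) + eps),
        //   where m_hat = m_t / (1 - beta1^t) and v_hat = v_t / (1 - beta2^t).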
        // Update biased first moment estimate
        for i in 0..grad_data.len() {
            m[i] = self.config.beta1 * m[i] + (1.0 - self.config.beta1) * grad_data[i];
        }

        // Update biased second moment estimate
        for i in 0..grad_data.len() {
            v[i] = self.config.beta2 * v[i] + (1.0 - self.config.beta2) * grad_data[i].powi(2);
        }

        // Apply parameter updates directly
        let mut param_data = parameter.data()?;
        for i in 0..grad_data.len() {
            let m_hat = m[i] / bias_correction1;
            let v_hat = v[i] / bias_correction2;
            let update_val =
                self.config.learning_rate * m_hat / (v_hat.sqrt() + self.config.epsilon);

            // Apply weight decay if specified
            if self.config.weight_decay > 0.0 {
                param_data[i] *= 1.0 - self.config.learning_rate * self.config.weight_decay;
            }

            // Apply the update
            param_data[i] -= update_val;
        }

        // Update parameter with new data
        *parameter = Tensor::new(param_data)?;

        // Recompress and store updated moments
        *momentum = CompressedGradient::compress(&m, &self.config);
        *variance = CompressedGradient::compress(&v, &self.config);

        Ok(())
    }

    fn zero_grad(&mut self) {
        // MicroAdam doesn't accumulate gradients in the traditional sense
        // as it compresses them immediately
    }

    fn step(&mut self) {
        // Updates are handled in the update() method
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

impl StatefulOptimizer for MicroAdam {
    type Config = MicroAdamConfig;
    type State = OptimizerState;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state_dict = HashMap::new();

        // Store compressed momentum
        for (param_id, momentum) in &self.momentum {
            let key = format!("momentum.{}", param_id);
            let tensor = Tensor::new(momentum.decompress())?;
            state_dict.insert(key, tensor);
        }

        // Store compressed variance
        for (param_id, variance) in &self.variance {
            let key = format!("variance.{}", param_id);
            let tensor = Tensor::new(variance.decompress())?;
            state_dict.insert(key, tensor);
        }

        // Store step count
        state_dict.insert(
            "step".to_string(),
            Tensor::new(vec![self.state.step as f32])?,
        );

        Ok(state_dict)
    }

    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
        // Load step count
        if let Some(step_tensor) = state_dict.get("step") {
            let step_data = step_tensor.data()?;
            if !step_data.is_empty() {
                self.state.step = step_data[0] as usize;
            }
        }

        // Load and compress momentum states
        for (key, tensor) in &state_dict {
            if key.starts_with("momentum.") {
                let param_id = key.strip_prefix("momentum.").unwrap().to_string();
                let values = tensor.data()?;
                let compressed = CompressedGradient::compress(&values, &self.config);
                self.momentum.insert(param_id, compressed);
            } else if key.starts_with("variance.") {
                let param_id = key.strip_prefix("variance.").unwrap().to_string();
                let values = tensor.data()?;
                let compressed = CompressedGradient::compress(&values, &self.config);
                self.variance.insert(param_id, compressed);
            }
        }

        Ok(())
    }

    fn memory_usage(&self) -> StateMemoryStats {
        let momentum_size: usize = self.momentum.values().map(|m| m.compressed_data.len()).sum();
        let variance_size: usize = self.variance.values().map(|v| v.compressed_data.len()).sum();

        StateMemoryStats {
            momentum_elements: momentum_size,
            variance_elements: variance_size,
            third_moment_elements: 0,
            total_bytes: (momentum_size + variance_size) * std::mem::size_of::<f32>(),
            num_parameters: self.momentum.len(),
        }
    }

    fn reset_state(&mut self) {
        self.state.clear();
        self.momentum.clear();
        self.variance.clear();
        self.compression_stats = CompressionStats::default();
    }

    fn num_parameters(&self) -> usize {
        self.momentum.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_microadam_creation() {
        let optimizer = MicroAdam::new();
        assert_eq!(optimizer.config.learning_rate, 1e-3);
        assert_eq!(optimizer.config.beta1, 0.9);
        assert_eq!(optimizer.config.beta2, 0.999);
    }

    #[test]
    fn test_microadam_with_config() {
        let config = MicroAdamConfig {
            learning_rate: 2e-3,
            compression_ratio: 0.2,
            ..Default::default()
        };
        let optimizer = MicroAdam::with_config(config);
        assert_eq!(optimizer.config.learning_rate, 2e-3);
        assert_eq!(optimizer.config.compression_ratio, 0.2);
    }

    #[test]
    fn test_microadam_for_large_models() {
        let optimizer = MicroAdam::for_large_models();
        assert_eq!(optimizer.config.learning_rate, 1e-4);
        assert_eq!(optimizer.config.compression_ratio, 0.05);
        assert_eq!(optimizer.config.min_block_size, 128);
        assert!(optimizer.config.adaptive_compression);
    }

    #[test]
    fn test_microadam_for_memory_constrained() {
        let optimizer = MicroAdam::for_memory_constrained();
        assert_eq!(optimizer.config.compression_ratio, 0.02);
        assert_eq!(optimizer.config.min_block_size, 32);
        assert!(optimizer.config.adaptive_compression);
    }

    #[test]
    fn test_compressed_gradient_topk() {
        let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
        let compressed = CompressedGradient::compress_topk(&gradient, 3);

        assert_eq!(compressed.compressed_data.len(), 3);
        assert_eq!(compressed.indices.len(), 3);
        assert_eq!(compressed.original_size, 6);

        // Should select indices 2, 4, 0 (values 0.2, 0.15, 0.1)
        let mut expected_indices = vec![2, 4, 0];
        let mut actual_indices = compressed.indices.clone();
        expected_indices.sort();
        actual_indices.sort();
        assert_eq!(actual_indices, expected_indices);
    }

    #[test]
    fn test_compressed_gradient_threshold() {
        let gradient = vec![0.1, 0.001, 0.2, 0.0001, 0.15, 0.0003];
        let config = MicroAdamConfig {
            compression_threshold: 0.05,
            ..Default::default()
        };
        let compressed = CompressedGradient::compress_threshold(&gradient, &config);

        // Should keep values >= 0.05: indices 0, 2, 4 (values 0.1, 0.2, 0.15)
        assert_eq!(compressed.compressed_data.len(), 3);
        assert_eq!(compressed.indices.len(), 3);

        let mut expected_indices = vec![0, 2, 4];
        let mut actual_indices = compressed.indices.clone();
        expected_indices.sort();
        actual_indices.sort();
        assert_eq!(actual_indices, expected_indices);
    }

    #[test]
    fn test_compression_decompress_cycle() {
        let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
        let config = MicroAdamConfig::default();
        let compressed = CompressedGradient::compress(&gradient, &config);
        let decompressed = compressed.decompress();

        assert_eq!(decompressed.len(), gradient.len());

        // Check that significant values are preserved
        for (i, &original) in gradient.iter().enumerate() {
            if original.abs() > 0.08 {
                // Significant values
                assert!(
                    decompressed[i].abs() > 0.0,
                    "Significant value at index {} was lost",
                    i
                );
            }
        }
    }

    #[test]
    fn test_compression_error_calculation() {
        let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
        let config = MicroAdamConfig::default();
        let compressed = CompressedGradient::compress(&gradient, &config);
        let error = compressed.compression_error(&gradient);

        assert!(error >= 0.0);
        assert!(error <= 1.0); // Relative error should be reasonable
    }

    #[test]
    fn test_microadam_update() -> Result<()> {
        let mut optimizer = MicroAdam::new();
        let gradient_data = vec![0.1, -0.05, 0.2, -0.01];
        let gradient = Tensor::new(gradient_data.clone())?;
        let mut parameter = Tensor::new(vec![1.0, 1.0, 1.0, 1.0])?;

        optimizer.update(&mut parameter, &gradient)?;

        // Check that optimizer state was updated
        assert_eq!(optimizer.state().step, 1);

        // Check that parameter was updated
        let param_data = parameter.data()?;
        assert_eq!(param_data.len(), gradient_data.len());

        // Parameter values should have changed from initial [1.0, 1.0, 1.0, 1.0]
        assert_ne!(param_data[0], 1.0);

        Ok(())
    }

    #[test]
    fn test_microadam_multiple_updates() -> Result<()> {
        let mut optimizer = MicroAdam::new();
        let gradient_data = vec![0.1, -0.05, 0.2, -0.01];
        let gradient = Tensor::new(gradient_data)?;
        let mut parameter = Tensor::new(vec![1.0, 1.0, 1.0, 1.0])?;

        // Multiple updates
        for i in 1..=5 {
            optimizer.update(&mut parameter, &gradient)?;
            assert_eq!(optimizer.state().step, i);
        }

        Ok(())
    }

    #[test]
    fn test_memory_savings_ratio() {
        let config = MicroAdamConfig {
            max_compression_error: 1.0, // Allow higher compression error for tests
            ..Default::default()
        };
        let mut optimizer = MicroAdam::with_config(config);

        // Initially no savings
        assert_eq!(optimizer.memory_savings_ratio(), 0.0);

        // After processing some parameters, should show savings
        let gradient_data = vec![0.1; 1000]; // Large gradient
        let gradient = Tensor::new(gradient_data).unwrap();
        let mut parameter = Tensor::new(vec![1.0; 1000]).unwrap();
        optimizer.update(&mut parameter, &gradient).unwrap();

        let savings = optimizer.memory_savings_ratio();
        assert!(savings > 0.0, "Should show memory savings");
        assert!(savings < 1.0, "Savings ratio should be less than 100%");
    }

    #[test]
    fn test_compression_statistics() {
        let config = MicroAdamConfig {
            max_compression_error: 1.0, // Allow higher compression error for tests
            ..Default::default()
        };
        let mut optimizer = MicroAdam::with_config(config);
        let gradient_data = vec![0.1; 500];
        let gradient = Tensor::new(gradient_data).unwrap();
        let mut parameter = Tensor::new(vec![1.0; 500]).unwrap();

        optimizer.update(&mut parameter, &gradient).unwrap();

        let stats = optimizer.compression_statistics();
        assert!(stats.contains("MicroAdam Compression Stats"));
        assert!(stats.contains("Total parameters: 500"));
        assert!(stats.contains("Memory savings"));
        assert!(stats.contains("compression ratio"));
    }

    #[test]
    fn test_learning_rate_setter_getter() {
        let mut optimizer = MicroAdam::new();
        assert_eq!(optimizer.get_lr(), 1e-3);

        optimizer.set_lr(2e-3);
        assert_eq!(optimizer.get_lr(), 2e-3);
    }

    #[test]
    fn test_state_dict_operations() -> Result<()> {
        let mut optimizer = MicroAdam::new();
        let gradient_data = vec![0.1, -0.05, 0.2];
        let gradient = Tensor::new(gradient_data)?;
        let mut param1 = Tensor::new(vec![1.0, 1.0, 1.0])?;
        let mut param2 = Tensor::new(vec![2.0, 2.0, 2.0])?;

        // Update to create state
        optimizer.update(&mut param1, &gradient)?;
        optimizer.update(&mut param2, &gradient)?;

        // Save state
        let state_dict = optimizer.state_dict()?;
        assert!(state_dict.contains_key("step"));

        // Create new optimizer and load state
        let mut new_optimizer = MicroAdam::new();
        new_optimizer.load_state_dict(state_dict)?;

        assert_eq!(new_optimizer.state().step, optimizer.state().step);

        Ok(())
    }

    #[test]
    fn test_memory_usage_tracking() -> Result<()> {
        let config = MicroAdamConfig {
            max_compression_error: 1.0, // Allow higher compression error for tests
            ..Default::default()
        };
        let mut optimizer = MicroAdam::with_config(config);
        let initial_usage = optimizer.memory_usage();

        let gradient_data = vec![0.1; 1000];
        let gradient = Tensor::new(gradient_data)?;
        let mut parameter = Tensor::new(vec![1.0; 1000])?;
        optimizer.update(&mut parameter, &gradient)?;

        let after_usage = optimizer.memory_usage();
        assert!(after_usage.total_bytes > initial_usage.total_bytes);
        assert!(after_usage.momentum_elements > 0);
        assert!(after_usage.variance_elements > 0);

        Ok(())
    }

    #[test]
    fn test_adaptive_compression_selection() {
        let sparse_gradient = vec![0.0; 1000]; // Very sparse
        let dense_gradient = vec![0.1; 1000]; // Dense

        let config = MicroAdamConfig {
            adaptive_compression: true,
            compression_threshold: 1e-6,
            ..Default::default()
        };

        let sparse_compression =
            CompressedGradient::choose_adaptive_compression(&sparse_gradient, &config);
        let dense_compression =
            CompressedGradient::choose_adaptive_compression(&dense_gradient, &config);

        // Should choose different methods for different gradient characteristics
        // This test mainly ensures the selection logic runs without panicking
        match sparse_compression {
            CompressionType::Threshold
            | CompressionType::TopK
            | CompressionType::BlockWise
            | CompressionType::Adaptive => {},
        }

        match dense_compression {
            CompressionType::Threshold
            | CompressionType::TopK
            | CompressionType::BlockWise
            | CompressionType::Adaptive => {},
        }
    }
}
929}