// trustformers_optim/amacp.rs
//! # aMacP: Adaptive Momentum and Consecutive Parameters Optimizer
//!
//! This module implements the aMacP optimizer from 2025 research, which addresses
//! limitations in existing optimizers by incorporating the average of both momentums
//! and consecutive parameters to adaptively change the step size.
//!
//! ## Key Innovations
//!
//! - **Dual Momentum Averaging**: Combines first and second moment estimates
//! - **Consecutive Parameter Averaging**: Uses parameter history for adaptive updates
//! - **Gradient Heterogeneity Handling**: Superior performance on transformer architectures
//! - **Adaptive Step Size**: Dynamic learning rate adjustment based on parameter trends
//!
//! ## Research Citation
//!
//! "aMacP: An adaptive optimization algorithm for Deep Neural Network"
//! Cyber Security and Applications, Volume 3, 2025

use crate::{
    common::{BiasCorrection, OptimizerState, ParameterUpdate, StateMemoryStats},
    traits::StatefulOptimizer,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};

/// Configuration for aMacP optimizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AMacPConfig {
    /// Base learning rate (scaled at step time by warmup and adaptive factors)
    pub learning_rate: f32,
    /// First momentum coefficient (gradient averaging)
    pub beta1: f32,
    /// Second momentum coefficient (squared gradient averaging)
    pub beta2: f32,
    /// Consecutive parameter averaging coefficient
    pub gamma: f32,
    /// Dual momentum weighting factor: blends `m_hat` with `sqrt(v_hat)`
    pub alpha: f32,
    /// Gradient heterogeneity adaptation strength
    pub eta: f32,
    /// Small constant for numerical stability
    pub epsilon: f32,
    /// Weight decay coefficient
    pub weight_decay: f32,
    /// Maximum gradient norm for clipping (`None` disables clipping)
    pub max_grad_norm: Option<f32>,
    /// Enable adaptive step size based on parameter trends
    pub adaptive_step_size: bool,
    /// Warmup steps for gradient stabilization (LR ramps linearly from 0)
    pub warmup_steps: usize,
}

54impl Default for AMacPConfig {
55    fn default() -> Self {
56        Self {
57            learning_rate: 1e-3,
58            beta1: 0.9,
59            beta2: 0.999,
60            gamma: 0.95, // Consecutive parameter averaging
61            alpha: 0.5,  // Dual momentum weighting
62            eta: 0.1,    // Gradient heterogeneity adaptation
63            epsilon: 1e-8,
64            weight_decay: 0.0,
65            max_grad_norm: Some(1.0),
66            adaptive_step_size: true,
67            warmup_steps: 1000,
68        }
69    }
70}
71
72impl AMacPConfig {
73    /// Configuration optimized for transformer models
74    pub fn for_transformers() -> Self {
75        Self {
76            learning_rate: 6e-4,
77            beta1: 0.9,
78            beta2: 0.95,
79            gamma: 0.98, // Higher consecutive parameter averaging for transformers
80            alpha: 0.6,  // Stronger dual momentum weighting
81            eta: 0.15,   // Higher gradient heterogeneity adaptation
82            epsilon: 1e-8,
83            weight_decay: 1e-2,
84            max_grad_norm: Some(1.0),
85            adaptive_step_size: true,
86            warmup_steps: 4000, // Longer warmup for large models
87        }
88    }
89
90    /// Configuration for vision models (CNN architectures)
91    pub fn for_vision() -> Self {
92        Self {
93            learning_rate: 1e-3,
94            beta1: 0.9,
95            beta2: 0.999,
96            gamma: 0.92, // Lower consecutive parameter averaging for vision
97            alpha: 0.4,  // Moderate dual momentum weighting
98            eta: 0.08,   // Lower gradient heterogeneity for stable vision training
99            epsilon: 1e-8,
100            weight_decay: 5e-4,
101            max_grad_norm: Some(0.5),
102            adaptive_step_size: true,
103            warmup_steps: 500, // Shorter warmup for vision models
104        }
105    }
106
107    /// Configuration for large language models
108    pub fn for_large_language_models() -> Self {
109        Self {
110            learning_rate: 3e-4,
111            beta1: 0.9,
112            beta2: 0.95,
113            gamma: 0.99, // Very high consecutive parameter averaging for LLMs
114            alpha: 0.7,  // Strong dual momentum weighting for stability
115            eta: 0.2,    // High gradient heterogeneity adaptation
116            epsilon: 1e-8,
117            weight_decay: 1e-1,
118            max_grad_norm: Some(1.0),
119            adaptive_step_size: true,
120            warmup_steps: 10000, // Long warmup for stability
121        }
122    }
123}
124
/// aMacP Optimizer implementation
#[derive(Debug)]
pub struct AMacP {
    // Hyperparameter configuration.
    config: AMacPConfig,
    // Shared Adam-style state (momentum/variance buffers and step counter).
    state: OptimizerState,
    /// Previous parameters for consecutive averaging
    // NOTE(review): `step_batch` currently stores the dual-momentum buffer
    // here rather than actual parameter values — confirm intent.
    previous_params: HashMap<String, Vec<f32>>,
    /// Dual momentum buffers
    // EMA of the alpha-blended `m_hat` / `sqrt(v_hat)` values, per parameter.
    dual_momentum: HashMap<String, Vec<f32>>,
    /// Gradient heterogeneity tracking
    // EMA of `std(grad) / ||grad||`, one scalar per parameter.
    gradient_heterogeneity: HashMap<String, f32>,
    /// Step size adaptation factors
    // EMA-smoothed multiplier in the 0.5..=1.5 range, per parameter.
    step_size_factors: HashMap<String, f32>,
    /// Current step number
    current_step: usize,
}

142impl AMacP {
143    /// Create a new aMacP optimizer
144    pub fn new(config: AMacPConfig) -> Self {
145        Self {
146            config,
147            state: OptimizerState::new(),
148            previous_params: HashMap::new(),
149            dual_momentum: HashMap::new(),
150            gradient_heterogeneity: HashMap::new(),
151            step_size_factors: HashMap::new(),
152            current_step: 0,
153        }
154    }
155
156    /// Create aMacP for transformer models
157    pub fn for_transformers() -> Self {
158        Self::new(AMacPConfig::for_transformers())
159    }
160
161    /// Create aMacP for vision models
162    pub fn for_vision() -> Self {
163        Self::new(AMacPConfig::for_vision())
164    }
165
166    /// Create aMacP for large language models
167    pub fn for_large_language_models() -> Self {
168        Self::new(AMacPConfig::for_large_language_models())
169    }
170
171    /// Compute dual momentum combining first and second moments
172    fn compute_dual_momentum(&self, m_hat: f32, v_hat: f32) -> f32 {
173        self.config.alpha * m_hat + (1.0 - self.config.alpha) * v_hat.sqrt()
174    }
175
176    /// Update gradient heterogeneity measure
177    fn update_gradient_heterogeneity(&mut self, param_id: &str, gradient: &[f32]) {
178        let grad_norm: f32 = gradient.iter().map(|g| g * g).sum::<f32>().sqrt();
179        let grad_mean = gradient.iter().sum::<f32>() / gradient.len() as f32;
180        let grad_std = (gradient.iter().map(|g| (g - grad_mean) * (g - grad_mean)).sum::<f32>()
181            / gradient.len() as f32)
182            .sqrt();
183
184        let heterogeneity = if grad_norm > 1e-8 { grad_std / grad_norm } else { 0.0 };
185
186        let entry = self.gradient_heterogeneity.entry(param_id.to_string()).or_insert(0.0);
187        *entry = 0.9 * *entry + 0.1 * heterogeneity;
188    }
189
190    /// Compute adaptive step size based on parameter trends (static version to avoid borrowing)
191    #[allow(dead_code)]
192    fn compute_adaptive_step_size_static(
193        config: &AMacPConfig,
194        current_params: &[f32],
195        prev_params: &[f32],
196        stored_factor: f32,
197    ) -> f32 {
198        if !config.adaptive_step_size {
199            return 1.0;
200        }
201
202        let param_change_norm: f32 = current_params
203            .iter()
204            .zip(prev_params.iter())
205            .map(|(curr, prev)| (curr - prev) * (curr - prev))
206            .sum::<f32>()
207            .sqrt();
208
209        let param_norm: f32 = current_params.iter().map(|p| p * p).sum::<f32>().sqrt();
210
211        let relative_change = if param_norm > 1e-8 { param_change_norm / param_norm } else { 0.0 };
212
213        // Adapt step size based on parameter change magnitude
214        let step_factor = if relative_change > 0.1 {
215            0.5 // Reduce step size for large changes
216        } else if relative_change < 0.01 {
217            1.5 // Increase step size for small changes
218        } else {
219            1.0 // Keep normal step size
220        };
221
222        0.9 * stored_factor + 0.1 * step_factor
223    }
224
225    /// Apply warmup scaling during initial training steps
226    fn get_warmup_lr(&self) -> f32 {
227        if self.current_step < self.config.warmup_steps {
228            let warmup_factor = (self.current_step as f32) / (self.config.warmup_steps as f32);
229            self.config.learning_rate * warmup_factor
230        } else {
231            self.config.learning_rate
232        }
233    }
234
235    /// Get current learning rate
236    pub fn learning_rate(&self) -> f32 {
237        self.config.learning_rate
238    }
239
240    /// Set learning rate
241    pub fn set_learning_rate(&mut self, lr: f32) {
242        self.config.learning_rate = lr;
243    }
244}
245
impl Optimizer for AMacP {
    /// Per-parameter update hook required by the `Optimizer` trait.
    ///
    /// NOTE(review): this is a no-op stub — the aMacP math lives in the
    /// non-trait `step_batch` method. Callers driving the optimizer through
    /// this trait entry point will see no parameter changes.
    fn update(&mut self, _parameter: &mut Tensor, _gradient: &Tensor) -> Result<()> {
        // Implementation for single parameter update
        // This is called by the training framework for each parameter
        Ok(())
    }

    /// Advance both the local and shared-state step counters.
    fn step(&mut self) {
        // Step counter increment - called after all parameter updates
        self.current_step += 1;
        self.state.step();
    }

    /// Intentionally empty: gradient zeroing is the framework's job.
    fn zero_grad(&mut self) {
        // Gradients are typically zeroed by the training framework
        // This method can be used for any optimizer-specific cleanup
    }

    /// Current base learning rate (warmup/adaptive scaling not applied).
    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    /// Set the base learning rate.
    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

// Additional method for batch parameter updates (non-trait)
impl AMacP {
    /// Process multiple parameters at once (non-trait method for convenience)
    ///
    /// Pipeline per parameter: clip gradient -> track heterogeneity ->
    /// Adam-style moment EMAs -> bias correction -> dual-momentum EMA ->
    /// adaptive step-size / consecutive averaging.
    ///
    /// NOTE(review): the final update value is computed but never applied to a
    /// parameter tensor (see `_update` below) — this method currently only
    /// maintains optimizer state.
    pub fn step_batch(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        let warmup_lr = self.get_warmup_lr();
        // Bias correction uses the 1-based step of the update being performed.
        let current_step = self.current_step + 1;

        // Process each parameter individually to avoid borrowing conflicts
        for (param_name, gradient) in gradients.iter() {
            let grad_data = gradient.data()?;
            if grad_data.is_empty() {
                continue;
            }

            // Apply gradient clipping if enabled
            let mut clipped_grad = grad_data.clone();
            if let Some(max_norm) = self.config.max_grad_norm {
                let grad_norm: f32 = clipped_grad.iter().map(|g| g * g).sum::<f32>().sqrt();
                if grad_norm > max_norm {
                    // Rescale so the L2 norm equals max_norm.
                    let scale = max_norm / grad_norm;
                    for g in clipped_grad.iter_mut() {
                        *g *= scale;
                    }
                }
            }

            // Update gradient heterogeneity
            self.update_gradient_heterogeneity(param_name, &clipped_grad);

            let param_size = clipped_grad.len();

            // Get momentum and variance separately to avoid multiple mutable borrows
            // (cloned out of state, mutated locally, and stored back below).
            let momentum = {
                let momentum = self.state.get_or_create_momentum(param_name.clone(), param_size);
                momentum.clone()
            };

            let variance = {
                let variance = self.state.get_or_create_variance(param_name.clone(), param_size);
                variance.clone()
            };

            // Compute bias corrections
            let (bias_correction1, bias_correction2) = BiasCorrection::compute_adam_corrections(
                self.config.beta1,
                self.config.beta2,
                current_step,
            );

            // Update momentum and variance (standard Adam updates)
            let mut updated_momentum = momentum;
            let mut updated_variance = variance;
            for i in 0..param_size {
                ParameterUpdate::update_ema(
                    &mut updated_momentum[i],
                    clipped_grad[i],
                    self.config.beta1,
                );
                ParameterUpdate::update_ema(
                    &mut updated_variance[i],
                    clipped_grad[i] * clipped_grad[i],
                    self.config.beta2,
                );
            }

            // Compute bias-corrected estimates
            let m_hat: Vec<f32> = updated_momentum.iter().map(|m| m / bias_correction1).collect();
            let v_hat: Vec<f32> = updated_variance.iter().map(|v| v / bias_correction2).collect();

            // Update dual momentum (aMacP innovation)
            // Cloned to sidestep holding a mutable borrow of self.dual_momentum
            // while calling &self methods; written back at the end of the loop.
            let mut dual_momentum = self
                .dual_momentum
                .entry(param_name.clone())
                .or_insert_with(|| vec![0.0; param_size])
                .clone();

            for i in 0..param_size {
                let dual_momentum_value = self.compute_dual_momentum(m_hat[i], v_hat[i]);
                // EMA with gamma as the decay coefficient.
                ParameterUpdate::update_ema(
                    &mut dual_momentum[i],
                    dual_momentum_value,
                    self.config.gamma,
                );
            }

            // Apply consecutive parameter averaging if previous parameters exist
            // (skipped on the very first step for a parameter, when no history
            // has been recorded yet).
            if let Some(prev_params) = self.previous_params.get(param_name).cloned() {
                let step_factor = {
                    if !self.config.adaptive_step_size {
                        1.0
                    } else {
                        // Relative change of the dual-momentum buffer vs. the
                        // stored history (mirrors
                        // compute_adaptive_step_size_static).
                        let param_change_norm: f32 = dual_momentum
                            .iter()
                            .zip(prev_params.iter())
                            .map(|(curr, prev)| (curr - prev) * (curr - prev))
                            .sum::<f32>()
                            .sqrt();

                        let param_norm: f32 =
                            dual_momentum.iter().map(|p| p * p).sum::<f32>().sqrt();

                        let relative_change =
                            if param_norm > 1e-8 { param_change_norm / param_norm } else { 0.0 };

                        let step_factor = if relative_change > 0.1 {
                            0.5 // Reduce step size for large changes
                        } else if relative_change < 0.01 {
                            1.5 // Increase step size for small changes
                        } else {
                            1.0 // Keep normal step size
                        };

                        // Smooth the per-parameter factor with an EMA.
                        let entry = self.step_size_factors.entry(param_name.clone()).or_insert(1.0);
                        *entry = 0.9 * *entry + 0.1 * step_factor;
                        *entry
                    }
                };

                // Scale LR up for parameters with heterogeneous gradients.
                let heterogeneity_factor = 1.0
                    + self.config.eta * self.gradient_heterogeneity.get(param_name).unwrap_or(&0.0);

                let effective_lr = warmup_lr * step_factor * heterogeneity_factor;

                // aMacP parameter update using dual momentum and consecutive averaging
                for i in 0..param_size {
                    let averaged_param = self.config.gamma * prev_params[i]
                        + (1.0 - self.config.gamma) * dual_momentum[i];

                    // Update parameter using averaged momentum and consecutive parameters
                    let _update =
                        effective_lr * averaged_param / (v_hat[i].sqrt() + self.config.epsilon);
                    // Note: In real implementation, this would update the actual parameters
                    // Here we just track the update for state management
                }
            }

            // Store updated states back
            // NOTE(review): `previous_params` receives the dual-momentum buffer,
            // not actual parameter values — confirm this matches the paper.
            self.state.momentum.insert(param_name.clone(), updated_momentum);
            self.state.variance.insert(param_name.clone(), updated_variance);
            self.dual_momentum.insert(param_name.clone(), dual_momentum.clone());
            self.previous_params.insert(param_name.clone(), dual_momentum);
        }

        // Update step counter after processing all parameters
        self.current_step = current_step;
        self.state.step = current_step;

        Ok(())
    }
}

impl StatefulOptimizer for AMacP {
    type Config = AMacPConfig;
    type State = OptimizerState;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    /// Serialize all optimizer state into a flat `name -> Tensor` map.
    ///
    /// Key scheme: `"step"` plus prefixed per-parameter entries —
    /// `momentum_*`, `variance_*`, `dual_momentum_*`, `prev_params_*`,
    /// `heterogeneity_*` (scalar), `step_factor_*` (scalar).
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        // Save step count
        // NOTE(review): stored as f32, so the round-trip is exact only up to
        // 2^24 steps.
        state.insert(
            "step".to_string(),
            Tensor::new(vec![self.current_step as f32])?,
        );

        // Save momentum and variance states
        for (name, momentum) in &self.state.momentum {
            let shape = vec![momentum.len()];
            state.insert(
                format!("momentum_{}", name),
                Tensor::from_vec(momentum.clone(), &shape)?,
            );
        }
        for (name, variance) in &self.state.variance {
            let shape = vec![variance.len()];
            state.insert(
                format!("variance_{}", name),
                Tensor::from_vec(variance.clone(), &shape)?,
            );
        }

        // Save aMacP-specific states
        for (name, dual_mom) in &self.dual_momentum {
            let shape = vec![dual_mom.len()];
            state.insert(
                format!("dual_momentum_{}", name),
                Tensor::from_vec(dual_mom.clone(), &shape)?,
            );
        }
        for (name, prev_params) in &self.previous_params {
            let shape = vec![prev_params.len()];
            state.insert(
                format!("prev_params_{}", name),
                Tensor::from_vec(prev_params.clone(), &shape)?,
            );
        }
        for (name, heterogeneity) in &self.gradient_heterogeneity {
            state.insert(
                format!("heterogeneity_{}", name),
                Tensor::new(vec![*heterogeneity])?,
            );
        }
        for (name, factor) in &self.step_size_factors {
            state.insert(format!("step_factor_{}", name), Tensor::new(vec![*factor])?);
        }

        Ok(state)
    }

    /// Restore state previously produced by `state_dict`.
    ///
    /// Keys are dispatched by prefix; unrecognized keys are silently ignored
    /// and missing keys leave the corresponding state untouched. Tensors whose
    /// data cannot be read are skipped rather than reported as errors.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        // Load step count
        if let Some(step_tensor) = state.get("step") {
            if let Ok(step_data) = step_tensor.data() {
                if !step_data.is_empty() {
                    self.current_step = step_data[0] as usize;
                    self.state.step = self.current_step;
                }
            }
        }

        // Load momentum and variance states
        // ("dual_momentum_*" keys do not collide with the "momentum_" prefix
        // because str::strip_prefix anchors at the start of the key.)
        for (key, tensor) in &state {
            if let Some(name) = key.strip_prefix("momentum_") {
                if let Ok(data) = tensor.data() {
                    self.state.momentum.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("variance_") {
                if let Ok(data) = tensor.data() {
                    self.state.variance.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("dual_momentum_") {
                if let Ok(data) = tensor.data() {
                    self.dual_momentum.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("prev_params_") {
                if let Ok(data) = tensor.data() {
                    self.previous_params.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("heterogeneity_") {
                if let Ok(data) = tensor.data() {
                    if !data.is_empty() {
                        self.gradient_heterogeneity.insert(name.to_string(), data[0]);
                    }
                }
            } else if let Some(name) = key.strip_prefix("step_factor_") {
                if let Ok(data) = tensor.data() {
                    if !data.is_empty() {
                        self.step_size_factors.insert(name.to_string(), data[0]);
                    }
                }
            }
        }

        Ok(())
    }

    /// Approximate memory footprint of all state buffers.
    ///
    /// aMacP-specific buffers are folded into the base stats: dual-momentum
    /// and previous-parameter elements are counted as momentum elements, and
    /// the per-parameter scalars (heterogeneity, step factors) are reported
    /// via `third_moment_elements` since the struct has no dedicated field.
    fn memory_usage(&self) -> StateMemoryStats {
        let base_stats = self.state.memory_usage();

        // Add aMacP-specific memory usage
        let dual_momentum_elements: usize = self.dual_momentum.values().map(|v| v.len()).sum();
        let prev_params_elements: usize = self.previous_params.values().map(|v| v.len()).sum();
        let scalar_elements = self.gradient_heterogeneity.len() + self.step_size_factors.len();

        StateMemoryStats {
            momentum_elements: base_stats.momentum_elements
                + dual_momentum_elements
                + prev_params_elements,
            variance_elements: base_stats.variance_elements,
            third_moment_elements: scalar_elements,
            total_bytes: base_stats.total_bytes
                + (dual_momentum_elements + prev_params_elements + scalar_elements)
                    * std::mem::size_of::<f32>(),
            num_parameters: base_stats.num_parameters,
        }
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    /// Clear all shared and aMacP-specific state and rewind the step counter.
    fn reset_state(&mut self) {
        self.state.clear();
        self.previous_params.clear();
        self.dual_momentum.clear();
        self.gradient_heterogeneity.clear();
        self.step_size_factors.clear();
        self.current_step = 0;
    }

    /// Number of tracked parameter tensors (not scalar element count).
    fn num_parameters(&self) -> usize {
        self.state.momentum.len()
    }
}

/// Statistics specific to aMacP optimizer
#[derive(Debug, Clone)]
pub struct AMacPStats {
    // Number of optimization steps taken so far.
    pub current_step: usize,
    // Mean of the per-parameter heterogeneity EMAs (0.0 when none tracked).
    pub average_gradient_heterogeneity: f32,
    // Mean of the per-parameter step-size factors (1.0 when none tracked).
    pub average_step_size_factor: f32,
    // Number of tracked parameter tensors.
    pub total_parameters: usize,
    // Warmup completion fraction, clamped to [0, 1].
    pub warmup_progress: f32,
    // L2 norm over all dual-momentum buffers combined.
    pub dual_momentum_norm: f32,
}

586impl AMacP {
587    /// Reset all optimizer state (convenience method)
588    pub fn reset(&mut self) {
589        self.reset_state();
590    }
591
592    /// Get comprehensive aMacP statistics
593    pub fn get_stats(&self) -> AMacPStats {
594        let avg_heterogeneity = if !self.gradient_heterogeneity.is_empty() {
595            self.gradient_heterogeneity.values().sum::<f32>()
596                / self.gradient_heterogeneity.len() as f32
597        } else {
598            0.0
599        };
600
601        let avg_step_factor = if !self.step_size_factors.is_empty() {
602            self.step_size_factors.values().sum::<f32>() / self.step_size_factors.len() as f32
603        } else {
604            1.0
605        };
606
607        let warmup_progress = if self.config.warmup_steps > 0 {
608            (self.current_step as f32 / self.config.warmup_steps as f32).min(1.0)
609        } else {
610            1.0
611        };
612
613        let dual_momentum_norm: f32 = self
614            .dual_momentum
615            .values()
616            .flat_map(|v| v.iter())
617            .map(|x| x * x)
618            .sum::<f32>()
619            .sqrt();
620
621        AMacPStats {
622            current_step: self.current_step,
623            average_gradient_heterogeneity: avg_heterogeneity,
624            average_step_size_factor: avg_step_factor,
625            total_parameters: self.num_parameters(),
626            warmup_progress,
627            dual_momentum_norm,
628        }
629    }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_amacp_creation() {
        let opt = AMacP::new(AMacPConfig::default());
        assert_eq!(opt.learning_rate(), 1e-3);
        let cfg = &opt.config;
        assert_eq!(cfg.beta1, 0.9);
        assert_eq!(cfg.beta2, 0.999);
        assert_eq!(cfg.gamma, 0.95);
    }

    #[test]
    fn test_amacp_presets() {
        // (optimizer, expected learning rate, expected warmup steps)
        let cases = [
            (AMacP::for_transformers(), 6e-4, 4000usize),
            (AMacP::for_vision(), 1e-3, 500),
            (AMacP::for_large_language_models(), 3e-4, 10000),
        ];
        for (opt, lr, warmup) in cases {
            assert_eq!(opt.config.learning_rate, lr);
            assert_eq!(opt.config.warmup_steps, warmup);
        }
    }

    #[test]
    fn test_dual_momentum_computation() {
        let opt = AMacP::new(AMacPConfig::default());
        let got = opt.compute_dual_momentum(0.1, 0.01);
        // Default alpha = 0.5 blends m_hat with sqrt(v_hat) equally.
        let want = 0.5 * 0.1 + 0.5 * 0.01_f32.sqrt();
        assert!((got - want).abs() < 1e-6);
    }

    #[test]
    fn test_learning_rate_getter_setter() {
        let mut opt = AMacP::new(AMacPConfig::default());
        assert_eq!(opt.learning_rate(), 1e-3);
        opt.set_learning_rate(2e-3);
        assert_eq!(opt.learning_rate(), 2e-3);
    }

    #[test]
    fn test_warmup_lr_calculation() {
        let config =
            AMacPConfig { learning_rate: 1e-3, warmup_steps: 1000, ..Default::default() };
        let mut opt = AMacP::new(config);
        opt.current_step = 500;
        // Halfway through warmup => half the base learning rate.
        assert!((opt.get_warmup_lr() - 5e-4).abs() < 1e-6);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let stats = AMacP::new(AMacPConfig::default()).memory_usage();
        assert_eq!(stats.momentum_elements, 0);
        assert_eq!(stats.variance_elements, 0);
        assert_eq!(stats.num_parameters, 0);
    }

    #[test]
    fn test_stats_generation() {
        let stats = AMacP::new(AMacPConfig::default()).get_stats();
        assert_eq!(stats.current_step, 0);
        assert_eq!(stats.total_parameters, 0);
        assert_eq!(stats.warmup_progress, 0.0);
        assert_eq!(stats.dual_momentum_norm, 0.0);
    }

    #[test]
    fn test_reset_functionality() {
        let mut opt = AMacP::new(AMacPConfig::default());
        opt.current_step = 100;
        opt.reset();
        assert_eq!(opt.current_step, 0);
        assert!(opt.dual_momentum.is_empty());
        assert!(opt.previous_params.is_empty());
    }

    #[test]
    fn test_state_dict_operations() {
        let state = AMacP::new(AMacPConfig::default())
            .state_dict()
            .expect("state_dict should succeed on a fresh optimizer");
        assert!(state.contains_key("step"));
    }

    #[test]
    fn test_config_serialization() {
        let config = AMacPConfig::for_transformers();
        let json = serde_json::to_string(&config).expect("config should serialize");
        let round_trip: AMacPConfig =
            serde_json::from_str(&json).expect("config should deserialize");
        assert_eq!(round_trip.learning_rate, 6e-4);
    }
}