trustformers_optim/adamax_plus.rs

//! # AdaMax+ Optimizer
//!
//! Implementation of AdaMax+ (Enhanced AdaMax with Momentum Scheduling), an advanced
//! variant of the AdaMax optimizer that incorporates adaptive momentum scheduling,
//! improved numerical stability, and enhanced convergence properties.
//!
//! ## Key Features
//!
//! - **Adaptive Momentum Scheduling**: Dynamic β₁ adjustment based on training progress
//! - **Enhanced Numerical Stability**: Improved handling of extreme gradient values
//! - **Convergence Acceleration**: Advanced momentum scheduling for faster convergence
//! - **Variance-Aware Updates**: Optional variance tracking for better adaptation
//!
//! ## Algorithm Description
//!
//! AdaMax+ extends the standard AdaMax algorithm with:
//! 1. Adaptive momentum scheduling based on gradient variance
//! 2. Enhanced infinity norm computation with outlier handling
//! 3. Learning rate warm-up and scheduling capabilities
//! 4. Optional bias correction improvements
//!
//! The AdaMax+ update rule:
//! ```text
//! # Adaptive momentum parameter
//! β₁_t = β₁_base * (1 - α * variance_factor)
//!
//! # First moment estimation
//! m_t = β₁_t * m_{t-1} + (1 - β₁_t) * g_t
//!
//! # Infinity norm with outlier handling
//! u_t = max(β₂ * u_{t-1}, |g_t|_∞)
//!
//! # Enhanced bias correction
//! m̂_t = m_t / (1 - β₁_t^t)
//!
//! # Parameter update with warm-up
//! lr_t = lr * min(1, t / warmup_steps)
//! θ_t = θ_{t-1} - (lr_t / (u_t + ε)) * m̂_t
//! ```
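//!
//! To make these formulas concrete, here is a minimal scalar sketch of one
//! AdaMax+ step (illustrative only; `adamax_plus_scalar_step` is a hypothetical
//! helper and not part of this crate's API):
//!
//! ```rust
//! // One AdaMax+ update for a single scalar parameter, following the rule above.
//! fn adamax_plus_scalar_step(
//!     theta: f32, g: f32, m: f32, u: f32,
//!     lr: f32, beta1_t: f32, beta2: f32, eps: f32, t: u32,
//! ) -> (f32, f32, f32) {
//!     let m_new = beta1_t * m + (1.0 - beta1_t) * g;        // first moment
//!     let u_new = (beta2 * u).max(g.abs());                 // infinity norm
//!     let m_hat = m_new / (1.0 - beta1_t.powi(t as i32));   // bias correction
//!     let theta_new = theta - (lr / (u_new + eps)) * m_hat; // parameter update
//!     (theta_new, m_new, u_new)
//! }
//!
//! let (theta, m, u) = adamax_plus_scalar_step(1.0, 0.1, 0.0, 0.0, 1e-3, 0.9, 0.999, 1e-8, 1);
//! assert!(theta < 1.0 && m > 0.0 && u > 0.0);
//! ```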
//!
//! ## Usage Example
//!
//! ```rust,no_run
//! use trustformers_optim::{AdaMaxPlus, AdaMaxPlusConfig};
//! use trustformers_core::traits::Optimizer;
//!
//! // Create AdaMax+ optimizer with default settings
//! let mut optimizer = AdaMaxPlus::new(
//!     1e-3,      // learning rate
//!     (0.9, 0.999), // (β₁, β₂)
//!     1e-8,      // epsilon
//!     0.01,      // weight decay
//! );
//!
//! // Or create with advanced configuration
//! let config = AdaMaxPlusConfig::new()
//!     .learning_rate(0.002)
//!     .betas((0.95, 0.999))
//!     .enable_adaptive_momentum(true)
//!     .warmup_steps(1000)
//!     .variance_tracking(true);
//!
//! let mut optimizer = AdaMaxPlus::from_config(config);
//! ```
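//!
//! Preset constructors tuned for common scenarios are also provided (see
//! `for_large_models`, `for_fast_training`, and `for_stable_training` below);
//! a brief sketch:
//!
//! ```rust,no_run
//! use trustformers_optim::AdaMaxPlus;
//!
//! let llm_optimizer = AdaMaxPlus::for_large_models();       // long warm-up, weight decay
//! let fast_optimizer = AdaMaxPlus::for_fast_training();     // higher LR, short warm-up
//! let stable_optimizer = AdaMaxPlus::for_stable_training(); // conservative settings
//! ```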
//!
//! ## Research Foundation
//!
//! This implementation builds on:
//! - Original AdaMax algorithm (Kingma & Ba, 2014)
//! - Adaptive momentum scheduling techniques
//! - Recent advances in optimization stability and convergence

use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::{tensor::Tensor, traits::Optimizer};

/// Configuration for AdaMax+ optimizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusConfig {
    /// Learning rate
    pub learning_rate: f32,
    /// Exponential decay rates for moment estimates (β₁, β₂)
    pub betas: (f32, f32),
    /// Small constant for numerical stability
    pub epsilon: f32,
    /// Weight decay (L2 regularization)
    pub weight_decay: f32,
    /// Enable adaptive momentum scheduling
    pub adaptive_momentum: bool,
    /// Momentum adaptation strength
    pub momentum_adaptation_strength: f32,
    /// Number of warm-up steps
    pub warmup_steps: usize,
    /// Enable variance tracking for momentum adaptation
    pub variance_tracking: bool,
    /// Bias correction improvement factor
    pub bias_correction_factor: f32,
    /// Outlier handling threshold for infinity norm
    pub outlier_threshold: f32,
}

impl Default for AdaMaxPlusConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            betas: (0.9, 0.999),
            epsilon: 1e-8,
            weight_decay: 0.0,
            adaptive_momentum: true,
            momentum_adaptation_strength: 0.1,
            warmup_steps: 0,
            variance_tracking: true,
            bias_correction_factor: 1.0,
            outlier_threshold: 10.0,
        }
    }
}

impl AdaMaxPlusConfig {
    /// Create a new configuration with default values
    pub fn new() -> Self {
        Self::default()
    }

    /// Set learning rate
    pub fn learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set beta parameters
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.betas = betas;
        self
    }

    /// Set epsilon
    pub fn epsilon(mut self, eps: f32) -> Self {
        self.epsilon = eps;
        self
    }

    /// Set weight decay
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Enable/disable adaptive momentum
    pub fn enable_adaptive_momentum(mut self, enabled: bool) -> Self {
        self.adaptive_momentum = enabled;
        self
    }

    /// Set momentum adaptation strength
    pub fn momentum_adaptation_strength(mut self, strength: f32) -> Self {
        self.momentum_adaptation_strength = strength;
        self
    }

    /// Set warmup steps
    pub fn warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Enable/disable variance tracking
    pub fn variance_tracking(mut self, enabled: bool) -> Self {
        self.variance_tracking = enabled;
        self
    }

    /// Set bias correction factor
    pub fn bias_correction_factor(mut self, factor: f32) -> Self {
        self.bias_correction_factor = factor;
        self
    }

    /// Set outlier threshold
    pub fn outlier_threshold(mut self, threshold: f32) -> Self {
        self.outlier_threshold = threshold;
        self
    }
}

/// State for a single parameter in AdaMax+ optimizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusParameterState {
    /// First moment estimate (momentum)
    pub momentum: Vec<f32>,
    /// Infinity norm estimate
    pub inf_norm: f32,
    /// Gradient variance (if variance tracking is enabled)
    pub gradient_variance: f32,
    /// Step count for this parameter
    pub step_count: usize,
    /// Exponential moving average of gradients for variance computation
    pub grad_ema: Option<Vec<f32>>,
    /// Exponential moving average of squared gradients for variance computation
    pub grad_sq_ema: Option<Vec<f32>>,
}

/// AdaMax+ optimizer state
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusState {
    /// Common optimizer state (momentum, variance, etc.)
    pub state: OptimizerState,
    /// Configuration
    pub config: AdaMaxPlusConfig,
    /// Global step count
    pub step_count: usize,
    /// Parameter-specific infinity norms
    pub inf_norms: HashMap<String, f32>,
    /// Gradient variances (if tracking enabled)
    pub gradient_variances: HashMap<String, f32>,
    /// Parameter step counts
    pub param_step_counts: HashMap<String, usize>,
}

impl AdaMaxPlusState {
    /// Create new optimizer state
    pub fn new(config: AdaMaxPlusConfig) -> Self {
        Self {
            state: OptimizerState::new(),
            config,
            step_count: 0,
            inf_norms: HashMap::new(),
            gradient_variances: HashMap::new(),
            param_step_counts: HashMap::new(),
        }
    }

    /// Get memory usage in bytes
    pub fn memory_usage(&self) -> usize {
        // Calculate approximate memory usage
        let momentum_size = self.state.momentum.values().map(|v| v.len() * 4).sum::<usize>(); // 4 bytes per f32
        let variance_size = self.state.variance.values().map(|v| v.len() * 4).sum::<usize>();
        let inf_norms_size = self.inf_norms.len() * 4; // 4 bytes per f32
        let gradient_variances_size = self.gradient_variances.len() * 4;
        let param_step_counts_size = self.param_step_counts.len() * 8; // 8 bytes per usize

        momentum_size
            + variance_size
            + inf_norms_size
            + gradient_variances_size
            + param_step_counts_size
    }
}

/// AdaMax+ optimizer implementation
pub struct AdaMaxPlus {
    state: AdaMaxPlusState,
}

impl AdaMaxPlus {
    /// Create a new AdaMax+ optimizer with basic parameters
    pub fn new(learning_rate: f32, betas: (f32, f32), epsilon: f32, weight_decay: f32) -> Self {
        let config = AdaMaxPlusConfig {
            learning_rate,
            betas,
            epsilon,
            weight_decay,
            ..Default::default()
        };

        Self {
            state: AdaMaxPlusState::new(config),
        }
    }

    /// Create AdaMax+ optimizer from configuration
    pub fn from_config(config: AdaMaxPlusConfig) -> Self {
        Self {
            state: AdaMaxPlusState::new(config),
        }
    }

    /// Create AdaMax+ optimized for large language models
    pub fn for_large_models() -> Self {
        let config = AdaMaxPlusConfig::new()
            .learning_rate(0.0002)
            .betas((0.9, 0.999))
            .enable_adaptive_momentum(true)
            .warmup_steps(10000)
            .variance_tracking(true)
            .weight_decay(0.1);

        Self::from_config(config)
    }

    /// Create AdaMax+ optimized for fast training
    pub fn for_fast_training() -> Self {
        let config = AdaMaxPlusConfig::new()
            .learning_rate(0.003)
            .betas((0.95, 0.999))
            .enable_adaptive_momentum(true)
            .momentum_adaptation_strength(0.2)
            .warmup_steps(500);

        Self::from_config(config)
    }

    /// Create AdaMax+ optimized for stable training
    pub fn for_stable_training() -> Self {
        let config = AdaMaxPlusConfig::new()
            .learning_rate(0.001)
            .betas((0.9, 0.999))
            .enable_adaptive_momentum(false)
            .variance_tracking(false)
            .bias_correction_factor(1.2)
            .outlier_threshold(5.0);

        Self::from_config(config)
    }

    /// Compute adaptive momentum parameter
    fn compute_adaptive_momentum(&self, param_id: String) -> f32 {
        if !self.state.config.adaptive_momentum {
            return self.state.config.betas.0;
        }

        let base_beta1 = self.state.config.betas.0;
        let adaptation_strength = self.state.config.momentum_adaptation_strength;

        // Use gradient variance to adapt momentum
        let variance_factor = if self.state.config.variance_tracking {
            self.state.gradient_variances.get(&param_id).copied().unwrap_or(0.0).min(1.0)
        } else {
            0.0
        };

        // Adaptive momentum: higher variance -> lower momentum for better adaptation
        let adaptive_beta1 = base_beta1 * (1.0 - adaptation_strength * variance_factor);
        adaptive_beta1.clamp(0.1, 0.99) // Clamp to reasonable range
    }

    /// Compute learning rate with warm-up
    fn compute_effective_learning_rate(&self) -> f32 {
        let base_lr = self.state.config.learning_rate;

        if self.state.config.warmup_steps == 0 {
            return base_lr;
        }

        let warmup_factor = if self.state.step_count <= self.state.config.warmup_steps {
            self.state.step_count as f32 / self.state.config.warmup_steps as f32
        } else {
            1.0
        };

        base_lr * warmup_factor
    }

    /// Update gradient variance tracking
    fn update_gradient_variance(&mut self, param_id: String, gradient: &Tensor) -> Result<()> {
        if !self.state.config.variance_tracking {
            return Ok(());
        }

        let beta1 = self.state.config.betas.0;
        let beta2 = self.state.config.betas.1;

        let gradient_data = gradient.data()?;
        let param_size = gradient_data.len();

        // Get or initialize variance tracking buffers
        let grad_ema = self
            .state
            .state
            .get_or_create_momentum(format!("{}_grad_ema", param_id), param_size)
            .clone();
        let grad_sq_ema = self
            .state
            .state
            .get_or_create_variance(format!("{}_grad_sq_ema", param_id), param_size)
            .clone();

        // Update gradient EMA: m = β₁ * m + (1 - β₁) * g
        let updated_grad_ema: Vec<f32> = grad_ema
            .iter()
            .zip(gradient_data.iter())
            .map(|(&m, &g)| beta1 * m + (1.0 - beta1) * g)
            .collect();

        // Update squared gradient EMA: v = β₂ * v + (1 - β₂) * g²
        let updated_grad_sq_ema: Vec<f32> = grad_sq_ema
            .iter()
            .zip(gradient_data.iter())
            .map(|(&v, &g)| beta2 * v + (1.0 - beta2) * g * g)
            .collect();

        // Compute variance: Var[g] = E[g²] - E[g]²
        let variance: f32 = updated_grad_sq_ema
            .iter()
            .zip(updated_grad_ema.iter())
            .map(|(&sq_ema, &ema)| sq_ema - ema * ema)
            .sum::<f32>()
            / param_size as f32;
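        // NOTE: the per-element variance estimates are averaged into a single
        // scalar per parameter; that scalar is what compute_adaptive_momentum reads.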

        // Store updated values
        self.state
            .state
            .momentum
            .insert(format!("{}_grad_ema", param_id), updated_grad_ema);
        self.state
            .state
            .variance
            .insert(format!("{}_grad_sq_ema", param_id), updated_grad_sq_ema);
        self.state.gradient_variances.insert(param_id, variance);

        Ok(())
    }
}

impl Optimizer for AdaMaxPlus {
    fn step(&mut self) {
        // Default step implementation - parameters are updated via update() calls
    }

    fn zero_grad(&mut self) {
        // Clear gradients - implementation specific to the framework
    }

    fn update(&mut self, parameter: &mut Tensor, gradient: &Tensor) -> Result<()> {
        // Get parameter data (bind to variable before taking pointer)
        let param_data = parameter.data()?;
        let param_id = format!("{:p}", param_data.as_ptr());
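        // NOTE: optimizer state is keyed by the address of the parameter's data
        // buffer, so lookups rely on that address staying stable across steps.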
        let param_size = param_data.len();
        self.state.step_count += 1;

        // Get or initialize momentum using OptimizerState methods
        let momentum_data = {
            let momentum_buffer =
                self.state.state.get_or_create_momentum(param_id.clone(), param_size);
            momentum_buffer.clone()
        };

        // Update gradient variance if enabled
        if self.state.config.variance_tracking {
            self.update_gradient_variance(param_id.clone(), gradient)?;
        }

        // Apply weight decay to gradient if specified
        let effective_gradient = if self.state.config.weight_decay > 0.0 {
            gradient.add(&parameter.mul_scalar(self.state.config.weight_decay)?)?
        } else {
            gradient.clone()
        };

        // Get adaptive momentum parameter
        let adaptive_beta1 = self.compute_adaptive_momentum(param_id.clone());
        let beta2 = self.state.config.betas.1;

        // Update momentum: m_t = β₁_adaptive * m_{t-1} + (1 - β₁_adaptive) * g_t
        let gradient_data = effective_gradient.data()?;
        let updated_momentum: Vec<f32> = momentum_data
            .iter()
            .zip(gradient_data.iter())
            .map(|(&m, &g)| adaptive_beta1 * m + (1.0 - adaptive_beta1) * g)
            .collect();

        // Update infinity norm with outlier handling
        let grad_inf_norm = gradient_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let clamped_grad_norm = grad_inf_norm.min(self.state.config.outlier_threshold);
        let current_inf_norm = self.state.inf_norms.get(&param_id).copied().unwrap_or(0.0);
        let new_inf_norm = (beta2 * current_inf_norm).max(clamped_grad_norm);
        self.state.inf_norms.insert(param_id.clone(), new_inf_norm);

        // Get parameter step count
        let step_count = self.state.param_step_counts.entry(param_id.clone()).or_insert(0);
        *step_count += 1;

        // Enhanced bias correction
        let bias_correction = 1.0 - adaptive_beta1.powi(*step_count as i32);
        let bias_corrected_momentum: Vec<f32> = updated_momentum
            .iter()
            .map(|&m| m / (bias_correction * self.state.config.bias_correction_factor))
            .collect();

        // Compute effective learning rate with warm-up
        let effective_lr = self.compute_effective_learning_rate();

        // Compute step size with numerical stability
        let step_size = effective_lr / (new_inf_norm + self.state.config.epsilon);

        // Update parameters: θ_t = θ_{t-1} - step_size * m̂_t
        let param_data = parameter.data()?;
        let updated_params: Vec<f32> = param_data
            .iter()
            .zip(bias_corrected_momentum.iter())
            .map(|(&p, &m)| p - step_size * m)
            .collect();

        *parameter = Tensor::new(updated_params)?;

        // Store updated momentum
        self.state.state.momentum.insert(param_id, updated_momentum);

        Ok(())
    }

    fn set_lr(&mut self, lr: f32) {
        self.state.config.learning_rate = lr;
    }

    fn get_lr(&self) -> f32 {
        self.state.config.learning_rate
    }
}

impl StatefulOptimizer for AdaMaxPlus {
    type Config = AdaMaxPlusConfig;
    type State = AdaMaxPlusState;

    fn config(&self) -> &Self::Config {
        &self.state.config
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state_dict = HashMap::new();

        // Convert momentum buffers to tensors
        for (key, buffer) in &self.state.state.momentum {
            let tensor = Tensor::new(buffer.clone())?;
            state_dict.insert(format!("{}_momentum", key), tensor);
        }

        // Convert variance buffers to tensors (if any)
        for (key, buffer) in &self.state.state.variance {
            let tensor = Tensor::new(buffer.clone())?;
            state_dict.insert(format!("{}_variance", key), tensor);
        }

        // Add infinity norms
        for (key, &inf_norm) in &self.state.inf_norms {
            let tensor = Tensor::new(vec![inf_norm])?;
            state_dict.insert(format!("{}_inf_norm", key), tensor);
        }

        // Add gradient variances
        for (key, &variance) in &self.state.gradient_variances {
            let tensor = Tensor::new(vec![variance])?;
            state_dict.insert(format!("{}_gradient_variance", key), tensor);
        }

        // Add parameter step counts
        for (key, &step_count) in &self.state.param_step_counts {
            let tensor = Tensor::new(vec![step_count as f32])?;
            state_dict.insert(format!("{}_step_count", key), tensor);
        }

        // Add global step count
        let step_tensor = Tensor::new(vec![self.state.step_count as f32])?;
        state_dict.insert("step_count".to_string(), step_tensor);

        Ok(state_dict)
    }

    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
        for (key, tensor) in state_dict {
            let data = tensor.data()?;

            if key == "step_count" {
                if let Some(&count) = data.first() {
                    self.state.step_count = count as usize;
                }
            } else if let Some(param_id) = key.strip_suffix("_momentum") {
                self.state.state.momentum.insert(param_id.to_string(), data.clone());
            } else if let Some(param_id) = key.strip_suffix("_gradient_variance") {
                // Match the more specific "_gradient_variance" suffix before
                // "_variance" so these entries are not misrouted.
                if let Some(&variance) = data.first() {
                    self.state.gradient_variances.insert(param_id.to_string(), variance);
                }
            } else if let Some(param_id) = key.strip_suffix("_variance") {
                self.state.state.variance.insert(param_id.to_string(), data.clone());
            } else if let Some(param_id) = key.strip_suffix("_inf_norm") {
                if let Some(&inf_norm) = data.first() {
                    self.state.inf_norms.insert(param_id.to_string(), inf_norm);
                }
            } else if let Some(param_id) = key.strip_suffix("_step_count") {
                if let Some(&step_count) = data.first() {
                    self.state.param_step_counts.insert(param_id.to_string(), step_count as usize);
                }
            }
        }

        Ok(())
    }

    fn memory_usage(&self) -> StateMemoryStats {
        StateMemoryStats {
            momentum_elements: self.state.state.momentum.values().map(|v| v.len()).sum::<usize>(),
            variance_elements: self.state.state.variance.values().map(|v| v.len()).sum::<usize>(),
            third_moment_elements: 0, // AdaMax+ doesn't use third moments
            total_bytes: self.state.memory_usage(),
            num_parameters: self.state.state.momentum.len(),
        }
    }

    fn reset_state(&mut self) {
        self.state.state.clear();
        self.state.step_count = 0;
        self.state.inf_norms.clear();
        self.state.gradient_variances.clear();
        self.state.param_step_counts.clear();
    }

    fn num_parameters(&self) -> usize {
        self.state.state.momentum.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::tensor::Tensor;

    #[test]
    fn test_adamax_plus_creation() {
        let optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 0.001);
        assert_eq!(optimizer.state.config.betas, (0.9, 0.999));
        assert_eq!(optimizer.state.config.epsilon, 1e-8);
        assert_eq!(optimizer.state.config.weight_decay, 0.01);
    }

    #[test]
    fn test_adamax_plus_config() {
        let config = AdaMaxPlusConfig::new()
            .learning_rate(0.002)
            .betas((0.95, 0.999))
            .enable_adaptive_momentum(true)
            .warmup_steps(1000);

        let optimizer = AdaMaxPlus::from_config(config);
        assert_eq!(optimizer.get_lr(), 0.002);
        assert_eq!(optimizer.state.config.betas, (0.95, 0.999));
        assert!(optimizer.state.config.adaptive_momentum);
        assert_eq!(optimizer.state.config.warmup_steps, 1000);
    }

    #[test]
    fn test_adamax_plus_presets() {
        let llm_optimizer = AdaMaxPlus::for_large_models();
        assert_eq!(llm_optimizer.get_lr(), 0.0002);
        assert_eq!(llm_optimizer.state.config.warmup_steps, 10000);
        assert!(llm_optimizer.state.config.adaptive_momentum);

        let fast_optimizer = AdaMaxPlus::for_fast_training();
        assert_eq!(fast_optimizer.get_lr(), 0.003);
        assert_eq!(
            fast_optimizer.state.config.momentum_adaptation_strength,
            0.2
        );

        let stable_optimizer = AdaMaxPlus::for_stable_training();
        assert!(!stable_optimizer.state.config.adaptive_momentum);
        assert!(!stable_optimizer.state.config.variance_tracking);
    }

    #[test]
    fn test_adamax_plus_step() -> Result<()> {
        let mut optimizer = AdaMaxPlus::new(0.01, (0.9, 0.999), 1e-8, 0.0);

        // Create test parameters and gradients directly
        let mut param = Tensor::ones(&[2, 2])?;
        let grad = Tensor::new(vec![0.1, 0.2, 0.3, 0.4])?;

        // Store original parameters
        let original_data = param.data()?.clone();

        // Perform optimization step
        optimizer.update(&mut param, &grad)?;

        // Check that parameters were updated
        let param_data = param.data()?;
        assert!(param_data.iter().zip(original_data.iter()).all(|(&new, &orig)| new != orig)); // Parameters should change

        Ok(())
    }

    #[test]
    fn test_warmup_learning_rate() {
        let mut optimizer =
            AdaMaxPlus::from_config(AdaMaxPlusConfig::new().learning_rate(0.001).warmup_steps(100));

        // At step 0, effective LR should be 0
        assert_eq!(optimizer.compute_effective_learning_rate(), 0.0);

        // At step 50, effective LR should be 50% of base LR
        optimizer.state.step_count = 50;
        assert!((optimizer.compute_effective_learning_rate() - 0.0005).abs() < 1e-9);

        // At step 100, effective LR should be 100% of base LR
        optimizer.state.step_count = 100;
        assert!((optimizer.compute_effective_learning_rate() - 0.001).abs() < 1e-9);

        // Beyond warmup, should remain at base LR
        optimizer.state.step_count = 200;
        assert!((optimizer.compute_effective_learning_rate() - 0.001).abs() < 1e-9);
    }

    #[test]
    fn test_adaptive_momentum() {
        let optimizer = AdaMaxPlus::from_config(
            AdaMaxPlusConfig::new()
                .enable_adaptive_momentum(true)
                .momentum_adaptation_strength(0.2),
        );

        // Test with low variance (should use higher momentum)
        let param_id = "test_param".to_string();

        // Simulate low variance by setting it in the optimizer's gradient_variances
        let mut test_optimizer = optimizer;
        test_optimizer.state.gradient_variances.insert(param_id.clone(), 0.1);

        let adaptive_beta1 = test_optimizer.compute_adaptive_momentum(param_id.clone());
        assert!(adaptive_beta1 > 0.85); // Should be close to base beta1 (0.9)

        // Test with high variance (should use lower momentum)
        test_optimizer.state.gradient_variances.insert(param_id.clone(), 0.8);

        let adaptive_beta1_high = test_optimizer.compute_adaptive_momentum(param_id);
        assert!(adaptive_beta1_high < adaptive_beta1); // Should be lower than low variance case
    }

    #[test]
    fn test_state_dict_save_load() -> Result<()> {
        let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.01);

        // Create and process some parameters directly
        let mut param = Tensor::ones(&[2])?;
        let grad = Tensor::new(vec![0.1, 0.2])?;
        optimizer.update(&mut param, &grad)?;

        // Save state
        let state_dict = optimizer.state_dict()?;
        assert!(!state_dict.is_empty());

        // Create new optimizer and load state
        let mut new_optimizer = AdaMaxPlus::new(0.002, (0.8, 0.99), 1e-7, 0.02);
        new_optimizer.load_state_dict(state_dict)?;

        // Check that state was loaded correctly (config doesn't change during load)
        assert_eq!(new_optimizer.get_lr(), 0.002); // Should keep new config
        assert_eq!(new_optimizer.state.config.betas, (0.8, 0.99));
        assert!(new_optimizer.state.step_count > 0);

        Ok(())
    }

    #[test]
    fn test_zero_grad() -> Result<()> {
        let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);

        // Test that zero_grad doesn't crash (implementation depends on framework)
        optimizer.zero_grad();

        // Since gradient tracking isn't implemented yet, we just ensure the method exists
        // and can be called without errors
        assert_eq!(optimizer.get_lr(), 0.001);

        Ok(())
    }

    #[test]
    fn test_memory_usage_tracking() {
        let optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        let memory_usage = optimizer.memory_usage();
        assert_eq!(memory_usage.total_bytes, 0); // Should start at 0
    }

    #[test]
    fn test_lr_get_set() {
        let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        assert_eq!(optimizer.get_lr(), 0.001);

        optimizer.set_lr(0.002);
        assert_eq!(optimizer.get_lr(), 0.002);
    }
}