use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::{tensor::Tensor, traits::Optimizer};

/// Configuration for the AdaMaxPlus optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusConfig {
    /// Base learning rate.
    pub learning_rate: f32,
    /// Exponential decay rates: beta1 for the momentum EMA, beta2 for the
    /// infinity-norm decay and squared-gradient EMA.
    pub betas: (f32, f32),
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f32,
    /// L2 weight decay coefficient (added to the gradient).
    pub weight_decay: f32,
    /// Whether to adapt beta1 based on the tracked gradient variance.
    pub adaptive_momentum: bool,
    /// Strength of the adaptive momentum adjustment.
    pub momentum_adaptation_strength: f32,
    /// Number of linear warmup steps for the learning rate (0 disables warmup).
    pub warmup_steps: usize,
    /// Whether to track per-parameter gradient variance statistics.
    pub variance_tracking: bool,
    /// Extra scaling factor applied to the bias-correction denominator.
    pub bias_correction_factor: f32,
    /// Cap on the per-step gradient infinity norm (outlier clipping).
    pub outlier_threshold: f32,
}

impl Default for AdaMaxPlusConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            betas: (0.9, 0.999),
            epsilon: 1e-8,
            weight_decay: 0.0,
            adaptive_momentum: true,
            momentum_adaptation_strength: 0.1,
            warmup_steps: 0,
            variance_tracking: true,
            bias_correction_factor: 1.0,
            outlier_threshold: 10.0,
        }
    }
}

impl AdaMaxPlusConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.betas = betas;
        self
    }

    pub fn epsilon(mut self, eps: f32) -> Self {
        self.epsilon = eps;
        self
    }

    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn enable_adaptive_momentum(mut self, enabled: bool) -> Self {
        self.adaptive_momentum = enabled;
        self
    }

    pub fn momentum_adaptation_strength(mut self, strength: f32) -> Self {
        self.momentum_adaptation_strength = strength;
        self
    }

    pub fn warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    pub fn variance_tracking(mut self, enabled: bool) -> Self {
        self.variance_tracking = enabled;
        self
    }

    pub fn bias_correction_factor(mut self, factor: f32) -> Self {
        self.bias_correction_factor = factor;
        self
    }

    pub fn outlier_threshold(mut self, threshold: f32) -> Self {
        self.outlier_threshold = threshold;
        self
    }
}

/// Per-parameter state tracked by AdaMaxPlus: momentum buffer, infinity norm,
/// gradient-variance estimate, and optional gradient EMAs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusParameterState {
    pub momentum: Vec<f32>,
    pub inf_norm: f32,
    pub gradient_variance: f32,
    pub step_count: usize,
    pub grad_ema: Option<Vec<f32>>,
    pub grad_sq_ema: Option<Vec<f32>>,
}

/// Complete optimizer state: shared momentum/variance buffers plus the
/// AdaMaxPlus-specific per-parameter tracking maps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusState {
    pub state: OptimizerState,
    pub config: AdaMaxPlusConfig,
    pub step_count: usize,
    pub inf_norms: HashMap<String, f32>,
    pub gradient_variances: HashMap<String, f32>,
    pub param_step_counts: HashMap<String, usize>,
}

impl AdaMaxPlusState {
    pub fn new(config: AdaMaxPlusConfig) -> Self {
        Self {
            state: OptimizerState::new(),
            config,
            step_count: 0,
            inf_norms: HashMap::new(),
            gradient_variances: HashMap::new(),
            param_step_counts: HashMap::new(),
        }
    }

    /// Approximate memory used by the optimizer state, in bytes
    /// (f32 buffers at 4 bytes per element, usize counters at 8 bytes).
    pub fn memory_usage(&self) -> usize {
        let momentum_size = self.state.momentum.values().map(|v| v.len() * 4).sum::<usize>();
        let variance_size = self.state.variance.values().map(|v| v.len() * 4).sum::<usize>();
        let inf_norms_size = self.inf_norms.len() * 4;
        let gradient_variances_size = self.gradient_variances.len() * 4;
        let param_step_counts_size = self.param_step_counts.len() * 8;

        momentum_size
            + variance_size
            + inf_norms_size
            + gradient_variances_size
            + param_step_counts_size
    }
}

/// AdaMax-style optimizer with adaptive momentum, learning-rate warmup,
/// gradient-variance tracking, and infinity-norm outlier clipping.
pub struct AdaMaxPlus {
    state: AdaMaxPlusState,
}

impl AdaMaxPlus {
    pub fn new(learning_rate: f32, betas: (f32, f32), epsilon: f32, weight_decay: f32) -> Self {
        let config = AdaMaxPlusConfig {
            learning_rate,
            betas,
            epsilon,
            weight_decay,
            ..Default::default()
        };

        Self {
            state: AdaMaxPlusState::new(config),
        }
    }

    pub fn from_config(config: AdaMaxPlusConfig) -> Self {
        Self {
            state: AdaMaxPlusState::new(config),
        }
    }

    /// Preset for large models: low learning rate, long warmup, weight decay.
    pub fn for_large_models() -> Self {
        let config = AdaMaxPlusConfig::new()
            .learning_rate(0.0002)
            .betas((0.9, 0.999))
            .enable_adaptive_momentum(true)
            .warmup_steps(10000)
            .variance_tracking(true)
            .weight_decay(0.1);

        Self::from_config(config)
    }

    /// Preset for fast training: higher learning rate, stronger momentum
    /// adaptation, short warmup.
    pub fn for_fast_training() -> Self {
        let config = AdaMaxPlusConfig::new()
            .learning_rate(0.003)
            .betas((0.95, 0.999))
            .enable_adaptive_momentum(true)
            .momentum_adaptation_strength(0.2)
            .warmup_steps(500);

        Self::from_config(config)
    }

    /// Preset for stable training: adaptive features disabled, stronger bias
    /// correction, tighter outlier clipping.
    pub fn for_stable_training() -> Self {
        let config = AdaMaxPlusConfig::new()
            .learning_rate(0.001)
            .betas((0.9, 0.999))
            .enable_adaptive_momentum(false)
            .variance_tracking(false)
            .bias_correction_factor(1.2)
            .outlier_threshold(5.0);

        Self::from_config(config)
    }

    /// Computes the (possibly adapted) beta1 for a parameter. When adaptive
    /// momentum is enabled, beta1 is reduced in proportion to the tracked
    /// gradient variance and clamped to [0.1, 0.99].
    fn compute_adaptive_momentum(&self, param_id: String) -> f32 {
        if !self.state.config.adaptive_momentum {
            return self.state.config.betas.0;
        }

        let base_beta1 = self.state.config.betas.0;
        let adaptation_strength = self.state.config.momentum_adaptation_strength;

        let variance_factor = if self.state.config.variance_tracking {
            self.state.gradient_variances.get(&param_id).copied().unwrap_or(0.0).min(1.0)
        } else {
            0.0
        };

        let adaptive_beta1 = base_beta1 * (1.0 - adaptation_strength * variance_factor);
        adaptive_beta1.clamp(0.1, 0.99)
    }

    /// Returns the learning rate after linear warmup scaling.
    fn compute_effective_learning_rate(&self) -> f32 {
        let base_lr = self.state.config.learning_rate;

        if self.state.config.warmup_steps == 0 {
            return base_lr;
        }

        let warmup_factor = if self.state.step_count <= self.state.config.warmup_steps {
            self.state.step_count as f32 / self.state.config.warmup_steps as f32
        } else {
            1.0
        };

        base_lr * warmup_factor
    }

    /// Updates the exponential moving averages of the gradient and squared
    /// gradient for a parameter and records the resulting variance estimate.
    fn update_gradient_variance(&mut self, param_id: String, gradient: &Tensor) -> Result<()> {
        if !self.state.config.variance_tracking {
            return Ok(());
        }

        let beta1 = self.state.config.betas.0;
        let beta2 = self.state.config.betas.1;

        let gradient_data = gradient.data()?;
        let param_size = gradient_data.len();

        let grad_ema = self
            .state
            .state
            .get_or_create_momentum(format!("{}_grad_ema", param_id), param_size)
            .clone();
        let grad_sq_ema = self
            .state
            .state
            .get_or_create_variance(format!("{}_grad_sq_ema", param_id), param_size)
            .clone();

        let updated_grad_ema: Vec<f32> = grad_ema
            .iter()
            .zip(gradient_data.iter())
            .map(|(&m, &g)| beta1 * m + (1.0 - beta1) * g)
            .collect();

        let updated_grad_sq_ema: Vec<f32> = grad_sq_ema
            .iter()
            .zip(gradient_data.iter())
            .map(|(&v, &g)| beta2 * v + (1.0 - beta2) * g * g)
            .collect();

        // Variance estimate: mean over elements of E[g^2] - (E[g])^2.
        let variance: f32 = updated_grad_sq_ema
            .iter()
            .zip(updated_grad_ema.iter())
            .map(|(&sq_ema, &ema)| sq_ema - ema * ema)
            .sum::<f32>()
            / param_size as f32;

        self.state
            .state
            .momentum
            .insert(format!("{}_grad_ema", param_id), updated_grad_ema);
        self.state
            .state
            .variance
            .insert(format!("{}_grad_sq_ema", param_id), updated_grad_sq_ema);
        self.state.gradient_variances.insert(param_id, variance);

        Ok(())
    }
}

impl Optimizer for AdaMaxPlus {
    fn step(&mut self) {
        // No-op: the global step counter is advanced in `update`.
    }

    fn zero_grad(&mut self) {
        // No-op: gradients are owned by the caller, not the optimizer.
    }

    fn update(&mut self, parameter: &mut Tensor, gradient: &Tensor) -> Result<()> {
        // Identify the parameter by the address of its data buffer.
        let param_data = parameter.data()?;
        let param_id = format!("{:p}", param_data.as_ptr());
        let param_size = param_data.len();
        self.state.step_count += 1;

        let momentum_data = {
            let momentum_buffer =
                self.state.state.get_or_create_momentum(param_id.clone(), param_size);
            momentum_buffer.clone()
        };

        if self.state.config.variance_tracking {
            self.update_gradient_variance(param_id.clone(), gradient)?;
        }

        // Apply L2 weight decay by adding `weight_decay * parameter` to the gradient.
        let effective_gradient = if self.state.config.weight_decay > 0.0 {
            gradient.add(&parameter.mul_scalar(self.state.config.weight_decay)?)?
        } else {
            gradient.clone()
        };

        let adaptive_beta1 = self.compute_adaptive_momentum(param_id.clone());
        let beta2 = self.state.config.betas.1;

        // First-moment (momentum) exponential moving average.
        let gradient_data = effective_gradient.data()?;
        let updated_momentum: Vec<f32> = momentum_data
            .iter()
            .zip(gradient_data.iter())
            .map(|(&m, &g)| adaptive_beta1 * m + (1.0 - adaptive_beta1) * g)
            .collect();

        // AdaMax-style infinity-norm update, with the gradient norm clipped at
        // `outlier_threshold`: u_t = max(beta2 * u_{t-1}, min(||g||_inf, threshold)).
        let grad_inf_norm = gradient_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let clamped_grad_norm = grad_inf_norm.min(self.state.config.outlier_threshold);
        let current_inf_norm = self.state.inf_norms.get(&param_id).copied().unwrap_or(0.0);
        let new_inf_norm = (beta2 * current_inf_norm).max(clamped_grad_norm);
        self.state.inf_norms.insert(param_id.clone(), new_inf_norm);

        let step_count = self.state.param_step_counts.entry(param_id.clone()).or_insert(0);
        *step_count += 1;

        // Bias-correct the momentum estimate for this parameter's step count.
        let bias_correction = 1.0 - adaptive_beta1.powi(*step_count as i32);
        let bias_corrected_momentum: Vec<f32> = updated_momentum
            .iter()
            .map(|&m| m / (bias_correction * self.state.config.bias_correction_factor))
            .collect();

        let effective_lr = self.compute_effective_learning_rate();

        let step_size = effective_lr / (new_inf_norm + self.state.config.epsilon);

        let param_data = parameter.data()?;
        let updated_params: Vec<f32> = param_data
            .iter()
            .zip(bias_corrected_momentum.iter())
            .map(|(&p, &m)| p - step_size * m)
            .collect();

        *parameter = Tensor::new(updated_params)?;

        self.state.state.momentum.insert(param_id, updated_momentum);

        Ok(())
    }

    fn set_lr(&mut self, lr: f32) {
        self.state.config.learning_rate = lr;
    }

    fn get_lr(&self) -> f32 {
        self.state.config.learning_rate
    }
}

impl StatefulOptimizer for AdaMaxPlus {
    type Config = AdaMaxPlusConfig;
    type State = AdaMaxPlusState;

    fn config(&self) -> &Self::Config {
        &self.state.config
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state_dict = HashMap::new();

        for (key, buffer) in &self.state.state.momentum {
            let tensor = Tensor::new(buffer.clone())?;
            state_dict.insert(format!("{}_momentum", key), tensor);
        }

        for (key, buffer) in &self.state.state.variance {
            let tensor = Tensor::new(buffer.clone())?;
            state_dict.insert(format!("{}_variance", key), tensor);
        }

        for (key, &inf_norm) in &self.state.inf_norms {
            let tensor = Tensor::new(vec![inf_norm])?;
            state_dict.insert(format!("{}_inf_norm", key), tensor);
        }

        for (key, &variance) in &self.state.gradient_variances {
            let tensor = Tensor::new(vec![variance])?;
            state_dict.insert(format!("{}_gradient_variance", key), tensor);
        }

        for (key, &step_count) in &self.state.param_step_counts {
            let tensor = Tensor::new(vec![step_count as f32])?;
            state_dict.insert(format!("{}_step_count", key), tensor);
        }

        let step_tensor = Tensor::new(vec![self.state.step_count as f32])?;
        state_dict.insert("step_count".to_string(), step_tensor);

        Ok(state_dict)
    }

    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
        for (key, tensor) in state_dict {
            let data = tensor.data()?;

            if key == "step_count" {
                if let Some(&count) = data.first() {
                    self.state.step_count = count as usize;
                }
            } else if let Some(param_id) = key.strip_suffix("_momentum") {
                self.state.state.momentum.insert(param_id.to_string(), data.clone());
            } else if let Some(param_id) = key.strip_suffix("_variance") {
                self.state.state.variance.insert(param_id.to_string(), data.clone());
            } else if let Some(param_id) = key.strip_suffix("_inf_norm") {
                if let Some(&inf_norm) = data.first() {
                    self.state.inf_norms.insert(param_id.to_string(), inf_norm);
                }
            } else if let Some(param_id) = key.strip_suffix("_gradient_variance") {
                if let Some(&variance) = data.first() {
                    self.state.gradient_variances.insert(param_id.to_string(), variance);
                }
            } else if let Some(param_id) = key.strip_suffix("_step_count") {
                if let Some(&step_count) = data.first() {
                    self.state.param_step_counts.insert(param_id.to_string(), step_count as usize);
                }
            }
        }

        Ok(())
    }

    fn memory_usage(&self) -> StateMemoryStats {
        StateMemoryStats {
            momentum_elements: self.state.state.momentum.values().map(|v| v.len()).sum::<usize>(),
            variance_elements: self.state.state.variance.values().map(|v| v.len()).sum::<usize>(),
            third_moment_elements: 0,
            total_bytes: self.state.memory_usage(),
            num_parameters: self.state.state.momentum.len(),
        }
    }

    fn reset_state(&mut self) {
        self.state.state.clear();
        self.state.step_count = 0;
        self.state.inf_norms.clear();
        self.state.gradient_variances.clear();
        self.state.param_step_counts.clear();
    }

    fn num_parameters(&self) -> usize {
        self.state.state.momentum.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::tensor::Tensor;

    #[test]
    fn test_adamax_plus_creation() {
        let optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 0.001);
        assert_eq!(optimizer.state.config.betas, (0.9, 0.999));
        assert_eq!(optimizer.state.config.epsilon, 1e-8);
        assert_eq!(optimizer.state.config.weight_decay, 0.01);
    }

    #[test]
    fn test_adamax_plus_config() {
        let config = AdaMaxPlusConfig::new()
            .learning_rate(0.002)
            .betas((0.95, 0.999))
            .enable_adaptive_momentum(true)
            .warmup_steps(1000);

        let optimizer = AdaMaxPlus::from_config(config);
        assert_eq!(optimizer.get_lr(), 0.002);
        assert_eq!(optimizer.state.config.betas, (0.95, 0.999));
        assert!(optimizer.state.config.adaptive_momentum);
        assert_eq!(optimizer.state.config.warmup_steps, 1000);
    }

    #[test]
    fn test_adamax_plus_presets() {
        let llm_optimizer = AdaMaxPlus::for_large_models();
        assert_eq!(llm_optimizer.get_lr(), 0.0002);
        assert_eq!(llm_optimizer.state.config.warmup_steps, 10000);
        assert!(llm_optimizer.state.config.adaptive_momentum);

        let fast_optimizer = AdaMaxPlus::for_fast_training();
        assert_eq!(fast_optimizer.get_lr(), 0.003);
        assert_eq!(fast_optimizer.state.config.momentum_adaptation_strength, 0.2);

        let stable_optimizer = AdaMaxPlus::for_stable_training();
        assert!(!stable_optimizer.state.config.adaptive_momentum);
        assert!(!stable_optimizer.state.config.variance_tracking);
    }

    #[test]
    fn test_adamax_plus_step() -> Result<()> {
        let mut optimizer = AdaMaxPlus::new(0.01, (0.9, 0.999), 1e-8, 0.0);

        let mut param = Tensor::ones(&[2, 2])?;
        let grad = Tensor::new(vec![0.1, 0.2, 0.3, 0.4])?;

        let original_data = param.data()?.clone();

        optimizer.update(&mut param, &grad)?;

        // Every element should have moved away from its original value.
        let param_data = param.data()?;
        assert!(param_data.iter().zip(original_data.iter()).all(|(&new, &orig)| new != orig));

        Ok(())
    }
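
    // Illustrative sketch (additional check, not from the original suite):
    // the tracked infinity norm should be capped at `outlier_threshold`, even
    // for very large gradients. Assumes `Tensor::new(Vec<f32>)` behaves as in
    // the tests above.
    #[test]
    fn test_outlier_threshold_clamps_inf_norm() -> Result<()> {
        let mut optimizer =
            AdaMaxPlus::from_config(AdaMaxPlusConfig::new().outlier_threshold(5.0));
        let mut param = Tensor::new(vec![1.0, 1.0])?;
        let grad = Tensor::new(vec![100.0, -200.0])?;

        optimizer.update(&mut param, &grad)?;

        // Every tracked infinity norm should be at most the configured threshold.
        assert!(optimizer.state.inf_norms.values().all(|&n| n <= 5.0));
        Ok(())
    }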

    #[test]
    fn test_warmup_learning_rate() {
        let mut optimizer =
            AdaMaxPlus::from_config(AdaMaxPlusConfig::new().learning_rate(0.001).warmup_steps(100));

        // No steps taken yet: the warmup factor is 0.
        assert_eq!(optimizer.compute_effective_learning_rate(), 0.0);

        // Halfway through warmup: half the base learning rate.
        optimizer.state.step_count = 50;
        assert!((optimizer.compute_effective_learning_rate() - 0.0005).abs() < 1e-9);

        // Warmup complete: full base learning rate.
        optimizer.state.step_count = 100;
        assert!((optimizer.compute_effective_learning_rate() - 0.001).abs() < 1e-9);

        optimizer.state.step_count = 200;
        assert!((optimizer.compute_effective_learning_rate() - 0.001).abs() < 1e-9);
    }
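
    // Illustrative sketch (additional check, not from the original suite):
    // with warmup disabled (warmup_steps == 0) the effective learning rate
    // equals the base learning rate from the very first step.
    #[test]
    fn test_no_warmup_uses_base_lr() {
        let optimizer =
            AdaMaxPlus::from_config(AdaMaxPlusConfig::new().learning_rate(0.001).warmup_steps(0));
        assert_eq!(optimizer.compute_effective_learning_rate(), 0.001);
    }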

    #[test]
    fn test_adaptive_momentum() {
        let mut optimizer = AdaMaxPlus::from_config(
            AdaMaxPlusConfig::new()
                .enable_adaptive_momentum(true)
                .momentum_adaptation_strength(0.2),
        );

        let param_id = "test_param".to_string();

        // Low gradient variance: beta1 stays close to its base value.
        optimizer.state.gradient_variances.insert(param_id.clone(), 0.1);
        let adaptive_beta1 = optimizer.compute_adaptive_momentum(param_id.clone());
        assert!(adaptive_beta1 > 0.85);

        // Higher variance: beta1 is reduced further.
        optimizer.state.gradient_variances.insert(param_id.clone(), 0.8);
        let adaptive_beta1_high = optimizer.compute_adaptive_momentum(param_id);
        assert!(adaptive_beta1_high < adaptive_beta1);
    }
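
    // Illustrative sketch (additional check, not from the original suite):
    // with adaptive momentum disabled, the base beta1 is returned regardless
    // of any tracked gradient variance.
    #[test]
    fn test_adaptive_momentum_disabled_returns_base_beta1() {
        let mut optimizer =
            AdaMaxPlus::from_config(AdaMaxPlusConfig::new().enable_adaptive_momentum(false));
        optimizer.state.gradient_variances.insert("p".to_string(), 0.9);
        assert_eq!(optimizer.compute_adaptive_momentum("p".to_string()), 0.9);
    }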

    #[test]
    fn test_state_dict_save_load() -> Result<()> {
        let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.01);

        let mut param = Tensor::ones(&[2])?;
        let grad = Tensor::new(vec![0.1, 0.2])?;
        optimizer.update(&mut param, &grad)?;

        let state_dict = optimizer.state_dict()?;
        assert!(!state_dict.is_empty());

        let mut new_optimizer = AdaMaxPlus::new(0.002, (0.8, 0.99), 1e-7, 0.02);
        new_optimizer.load_state_dict(state_dict)?;

        // Loading state restores step counters but leaves the config untouched.
        assert_eq!(new_optimizer.get_lr(), 0.002);
        assert_eq!(new_optimizer.state.config.betas, (0.8, 0.99));
        assert!(new_optimizer.state.step_count > 0);

        Ok(())
    }

    #[test]
    fn test_zero_grad() -> Result<()> {
        let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);

        // zero_grad is a no-op for this optimizer; it must not change the config.
        optimizer.zero_grad();

        assert_eq!(optimizer.get_lr(), 0.001);

        Ok(())
    }

    #[test]
    fn test_memory_usage_tracking() {
        let optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        let memory_usage = optimizer.memory_usage();

        // No parameters have been updated yet, so no state memory is tracked.
        assert_eq!(memory_usage.total_bytes, 0);
    }
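
    // Illustrative sketch (additional check, not from the original suite):
    // once a parameter has been updated, its momentum buffer is tracked and
    // the reported memory usage becomes non-zero. Assumes `Tensor::new` as
    // used in the tests above.
    #[test]
    fn test_memory_usage_after_update() -> Result<()> {
        let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        let mut param = Tensor::new(vec![1.0, 2.0, 3.0])?;
        let grad = Tensor::new(vec![0.1, 0.2, 0.3])?;

        optimizer.update(&mut param, &grad)?;

        let stats = optimizer.memory_usage();
        assert!(stats.momentum_elements > 0);
        assert!(stats.total_bytes > 0);
        Ok(())
    }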

    #[test]
    fn test_lr_get_set() {
        let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        assert_eq!(optimizer.get_lr(), 0.001);

        optimizer.set_lr(0.002);
        assert_eq!(optimizer.get_lr(), 0.002);
    }
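
    // Illustrative sketch (additional check, not from the original suite):
    // reset_state should clear the step counter and the per-parameter
    // tracking maps that AdaMaxPlus maintains alongside the shared state.
    #[test]
    fn test_reset_state_clears_tracking() -> Result<()> {
        let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        let mut param = Tensor::new(vec![1.0, 2.0])?;
        let grad = Tensor::new(vec![0.1, 0.2])?;

        optimizer.update(&mut param, &grad)?;
        assert!(optimizer.state.step_count > 0);
        assert!(!optimizer.state.inf_norms.is_empty());

        optimizer.reset_state();
        assert_eq!(optimizer.state.step_count, 0);
        assert!(optimizer.state.inf_norms.is_empty());
        assert!(optimizer.state.gradient_variances.is_empty());
        assert!(optimizer.state.param_step_counts.is_empty());
        Ok(())
    }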
}