use crate::{
    common::{BiasCorrection, OptimizerState, ParameterUpdate, StateMemoryStats},
    traits::StatefulOptimizer,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};

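/// Configuration for the [`AMacP`] optimizer: Adam-style moment decay plus
/// the dual-momentum, heterogeneity-scaling, and adaptive step-size knobs.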
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AMacPConfig {
    /// Base learning rate.
    pub learning_rate: f32,
    /// Exponential decay rate for the first-moment (momentum) estimate.
    pub beta1: f32,
    /// Exponential decay rate for the second-moment (variance) estimate.
    pub beta2: f32,
    /// Decay rate for the dual-momentum EMA and for parameter averaging.
    pub gamma: f32,
    /// Mixing weight between the first moment and the RMS of the second
    /// moment when forming the dual momentum.
    pub alpha: f32,
    /// Strength of the gradient-heterogeneity learning-rate scaling.
    pub eta: f32,
    /// Small constant added to denominators for numerical stability.
    pub epsilon: f32,
    /// Weight-decay coefficient.
    pub weight_decay: f32,
    /// Optional global gradient-norm clipping threshold.
    pub max_grad_norm: Option<f32>,
    /// Whether to adapt per-parameter step sizes from relative change.
    pub adaptive_step_size: bool,
    /// Number of linear learning-rate warmup steps.
    pub warmup_steps: usize,
}

impl Default for AMacPConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            gamma: 0.95,
            alpha: 0.5,
            eta: 0.1,
            epsilon: 1e-8,
            weight_decay: 0.0,
            max_grad_norm: Some(1.0),
            adaptive_step_size: true,
            warmup_steps: 1000,
        }
    }
}

impl AMacPConfig {
    /// Preset tuned for transformer training.
    pub fn for_transformers() -> Self {
        Self {
            learning_rate: 6e-4,
            beta1: 0.9,
            beta2: 0.95,
            gamma: 0.98,
            alpha: 0.6,
            eta: 0.15,
            epsilon: 1e-8,
            weight_decay: 1e-2,
            max_grad_norm: Some(1.0),
            adaptive_step_size: true,
            warmup_steps: 4000,
        }
    }

    /// Preset tuned for vision models.
    pub fn for_vision() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            gamma: 0.92,
            alpha: 0.4,
            eta: 0.08,
            epsilon: 1e-8,
            weight_decay: 5e-4,
            max_grad_norm: Some(0.5),
            adaptive_step_size: true,
            warmup_steps: 500,
        }
    }

    /// Preset tuned for large language models.
    pub fn for_large_language_models() -> Self {
        Self {
            learning_rate: 3e-4,
            beta1: 0.9,
            beta2: 0.95,
            gamma: 0.99,
            alpha: 0.7,
            eta: 0.2,
            epsilon: 1e-8,
            weight_decay: 1e-1,
            max_grad_norm: Some(1.0),
            adaptive_step_size: true,
            warmup_steps: 10000,
        }
    }
}

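/// AMacP optimizer: Adam-style first/second moments augmented with a
/// dual-momentum buffer, per-parameter gradient-heterogeneity tracking,
/// adaptive step-size factors, and linear learning-rate warmup.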
#[derive(Debug)]
pub struct AMacP {
    config: AMacPConfig,
    state: OptimizerState,
    /// Previous averaged state per parameter, used for relative-change
    /// tracking (refreshed to the dual momentum after each step).
    previous_params: HashMap<String, Vec<f32>>,
    /// Dual-momentum buffer per parameter.
    dual_momentum: HashMap<String, Vec<f32>>,
    /// EMA of gradient heterogeneity (std / L2 norm) per parameter.
    gradient_heterogeneity: HashMap<String, f32>,
    /// Smoothed adaptive step-size factor per parameter.
    step_size_factors: HashMap<String, f32>,
    current_step: usize,
}

impl AMacP {
    /// Creates a new optimizer with the given configuration.
    pub fn new(config: AMacPConfig) -> Self {
        Self {
            config,
            state: OptimizerState::new(),
            previous_params: HashMap::new(),
            dual_momentum: HashMap::new(),
            gradient_heterogeneity: HashMap::new(),
            step_size_factors: HashMap::new(),
            current_step: 0,
        }
    }

    /// Convenience constructor using the transformer preset.
    pub fn for_transformers() -> Self {
        Self::new(AMacPConfig::for_transformers())
    }

    /// Convenience constructor using the vision preset.
    pub fn for_vision() -> Self {
        Self::new(AMacPConfig::for_vision())
    }

    /// Convenience constructor using the LLM preset.
    pub fn for_large_language_models() -> Self {
        Self::new(AMacPConfig::for_large_language_models())
    }

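    /// Dual momentum blends the bias-corrected first moment with the RMS of
    /// the second moment: `alpha * m_hat + (1 - alpha) * sqrt(v_hat)`.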
    fn compute_dual_momentum(&self, m_hat: f32, v_hat: f32) -> f32 {
        self.config.alpha * m_hat + (1.0 - self.config.alpha) * v_hat.sqrt()
    }

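    /// Tracks how "spread out" a gradient is: heterogeneity is the ratio of
    /// the element-wise standard deviation to the L2 norm, smoothed with an
    /// EMA (0.9 old / 0.1 new).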
    fn update_gradient_heterogeneity(&mut self, param_id: &str, gradient: &[f32]) {
        let grad_norm: f32 = gradient.iter().map(|g| g * g).sum::<f32>().sqrt();
        let grad_mean = gradient.iter().sum::<f32>() / gradient.len() as f32;
        let grad_std = (gradient.iter().map(|g| (g - grad_mean) * (g - grad_mean)).sum::<f32>()
            / gradient.len() as f32)
            .sqrt();

        let heterogeneity = if grad_norm > 1e-8 { grad_std / grad_norm } else { 0.0 };

        let entry = self.gradient_heterogeneity.entry(param_id.to_string()).or_insert(0.0);
        *entry = 0.9 * *entry + 0.1 * heterogeneity;
    }

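    /// Stateless form of the adaptive step-size rule used in `step_batch`:
    /// shrink the factor (0.5) when the relative parameter change exceeds
    /// 10%, grow it (1.5) below 1%, then smooth against the stored factor.
    /// Kept for reference; currently unused.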
    #[allow(dead_code)]
    fn compute_adaptive_step_size_static(
        config: &AMacPConfig,
        current_params: &[f32],
        prev_params: &[f32],
        stored_factor: f32,
    ) -> f32 {
        if !config.adaptive_step_size {
            return 1.0;
        }

        let param_change_norm: f32 = current_params
            .iter()
            .zip(prev_params.iter())
            .map(|(curr, prev)| (curr - prev) * (curr - prev))
            .sum::<f32>()
            .sqrt();

        let param_norm: f32 = current_params.iter().map(|p| p * p).sum::<f32>().sqrt();

        let relative_change = if param_norm > 1e-8 { param_change_norm / param_norm } else { 0.0 };

        let step_factor = if relative_change > 0.1 {
            0.5
        } else if relative_change < 0.01 {
            1.5
        } else {
            1.0
        };

        0.9 * stored_factor + 0.1 * step_factor
    }

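    /// Learning rate scaled linearly from zero over the first `warmup_steps`
    /// steps, then held at the configured value.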
    fn get_warmup_lr(&self) -> f32 {
        if self.current_step < self.config.warmup_steps {
            let warmup_factor = (self.current_step as f32) / (self.config.warmup_steps as f32);
            self.config.learning_rate * warmup_factor
        } else {
            self.config.learning_rate
        }
    }

    /// Returns the configured base learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.config.learning_rate
    }

    /// Sets the base learning rate.
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

impl Optimizer for AMacP {
    fn update(&mut self, _parameter: &mut Tensor, _gradient: &Tensor) -> Result<()> {
        // Per-tensor updates are not applied through this entry point; AMacP
        // consumes named gradients via `step_batch` instead.
        Ok(())
    }

    fn step(&mut self) {
        self.current_step += 1;
        self.state.step();
    }

    fn zero_grad(&mut self) {
        // Gradients are owned by the caller; there is no internal gradient
        // state to reset.
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

impl AMacP {
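    /// Applies one AMacP step to a batch of named gradients: per-tensor
    /// gradient clipping, Adam-style moment updates with bias correction,
    /// a dual-momentum update, and an adaptively scaled learning rate.
    ///
    /// A minimal usage sketch (the parameter name and gradient values are
    /// illustrative only):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut optimizer = AMacP::new(AMacPConfig::default());
    /// let mut gradients = HashMap::new();
    /// gradients.insert(
    ///     "layer0.weight".to_string(),
    ///     Tensor::new(vec![0.1_f32, -0.2, 0.05]).unwrap(),
    /// );
    /// optimizer.step_batch(&gradients).unwrap();
    /// ```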
    pub fn step_batch(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        let warmup_lr = self.get_warmup_lr();
        let current_step = self.current_step + 1;

        for (param_name, gradient) in gradients.iter() {
            let grad_data = gradient.data()?;
            if grad_data.is_empty() {
                continue;
            }

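            // Rescale the gradient so its L2 norm does not exceed
            // `max_grad_norm`, if clipping is configured.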
            let mut clipped_grad = grad_data.clone();
            if let Some(max_norm) = self.config.max_grad_norm {
                let grad_norm: f32 = clipped_grad.iter().map(|g| g * g).sum::<f32>().sqrt();
                if grad_norm > max_norm {
                    let scale = max_norm / grad_norm;
                    for g in clipped_grad.iter_mut() {
                        *g *= scale;
                    }
                }
            }

            self.update_gradient_heterogeneity(param_name, &clipped_grad);

            let param_size = clipped_grad.len();

            let momentum =
                self.state.get_or_create_momentum(param_name.clone(), param_size).clone();
            let variance =
                self.state.get_or_create_variance(param_name.clone(), param_size).clone();

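            // Adam-style bias corrections (1 - beta^t) for the first and
            // second moments.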
            let (bias_correction1, bias_correction2) = BiasCorrection::compute_adam_corrections(
                self.config.beta1,
                self.config.beta2,
                current_step,
            );

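            // EMA updates of the first and second moments with beta1/beta2.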
            let mut updated_momentum = momentum;
            let mut updated_variance = variance;
            for i in 0..param_size {
                ParameterUpdate::update_ema(
                    &mut updated_momentum[i],
                    clipped_grad[i],
                    self.config.beta1,
                );
                ParameterUpdate::update_ema(
                    &mut updated_variance[i],
                    clipped_grad[i] * clipped_grad[i],
                    self.config.beta2,
                );
            }

            let m_hat: Vec<f32> = updated_momentum.iter().map(|m| m / bias_correction1).collect();
            let v_hat: Vec<f32> = updated_variance.iter().map(|v| v / bias_correction2).collect();

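            // Fetch (or initialize) this parameter's dual-momentum buffer and
            // blend in `alpha * m_hat + (1 - alpha) * sqrt(v_hat)` at rate
            // `gamma`.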
            let mut dual_momentum = self
                .dual_momentum
                .entry(param_name.clone())
                .or_insert_with(|| vec![0.0; param_size])
                .clone();

            for i in 0..param_size {
                let dual_momentum_value = self.compute_dual_momentum(m_hat[i], v_hat[i]);
                ParameterUpdate::update_ema(
                    &mut dual_momentum[i],
                    dual_momentum_value,
                    self.config.gamma,
                );
            }

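            // Once a previous state exists, derive a step-size factor from
            // the relative change of the dual momentum, scale by the smoothed
            // gradient heterogeneity, and form the candidate update.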
            if let Some(prev_params) = self.previous_params.get(param_name).cloned() {
                let step_factor = {
                    if !self.config.adaptive_step_size {
                        1.0
                    } else {
                        let param_change_norm: f32 = dual_momentum
                            .iter()
                            .zip(prev_params.iter())
                            .map(|(curr, prev)| (curr - prev) * (curr - prev))
                            .sum::<f32>()
                            .sqrt();

                        let param_norm: f32 =
                            dual_momentum.iter().map(|p| p * p).sum::<f32>().sqrt();

                        let relative_change =
                            if param_norm > 1e-8 { param_change_norm / param_norm } else { 0.0 };

                        let step_factor = if relative_change > 0.1 {
                            0.5
                        } else if relative_change < 0.01 {
                            1.5
                        } else {
                            1.0
                        };

                        let entry =
                            self.step_size_factors.entry(param_name.clone()).or_insert(1.0);
                        *entry = 0.9 * *entry + 0.1 * step_factor;
                        *entry
                    }
                };

                let heterogeneity_factor = 1.0
                    + self.config.eta * self.gradient_heterogeneity.get(param_name).unwrap_or(&0.0);

                let effective_lr = warmup_lr * step_factor * heterogeneity_factor;

                for i in 0..param_size {
                    let averaged_param = self.config.gamma * prev_params[i]
                        + (1.0 - self.config.gamma) * dual_momentum[i];

                    // The update magnitude is computed but not applied here:
                    // this method only receives gradients, so applying the
                    // step to the parameter tensors is left to the caller.
                    let _update =
                        effective_lr * averaged_param / (v_hat[i].sqrt() + self.config.epsilon);
                }
            }

            self.state.momentum.insert(param_name.clone(), updated_momentum);
            self.state.variance.insert(param_name.clone(), updated_variance);
            self.dual_momentum.insert(param_name.clone(), dual_momentum.clone());
            // The dual momentum doubles as the "previous" state compared
            // against on the next step.
            self.previous_params.insert(param_name.clone(), dual_momentum);
        }

        self.current_step = current_step;
        self.state.step = current_step;

        Ok(())
    }
}

impl StatefulOptimizer for AMacP {
    type Config = AMacPConfig;
    type State = OptimizerState;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        state.insert(
            "step".to_string(),
            Tensor::new(vec![self.current_step as f32])?,
        );

        for (name, momentum) in &self.state.momentum {
            let shape = vec![momentum.len()];
            state.insert(
                format!("momentum_{}", name),
                Tensor::from_vec(momentum.clone(), &shape)?,
            );
        }
        for (name, variance) in &self.state.variance {
            let shape = vec![variance.len()];
            state.insert(
                format!("variance_{}", name),
                Tensor::from_vec(variance.clone(), &shape)?,
            );
        }

        for (name, dual_mom) in &self.dual_momentum {
            let shape = vec![dual_mom.len()];
            state.insert(
                format!("dual_momentum_{}", name),
                Tensor::from_vec(dual_mom.clone(), &shape)?,
            );
        }
        for (name, prev_params) in &self.previous_params {
            let shape = vec![prev_params.len()];
            state.insert(
                format!("prev_params_{}", name),
                Tensor::from_vec(prev_params.clone(), &shape)?,
            );
        }
        for (name, heterogeneity) in &self.gradient_heterogeneity {
            state.insert(
                format!("heterogeneity_{}", name),
                Tensor::new(vec![*heterogeneity])?,
            );
        }
        for (name, factor) in &self.step_size_factors {
            state.insert(format!("step_factor_{}", name), Tensor::new(vec![*factor])?);
        }

        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(step_tensor) = state.get("step") {
            if let Ok(step_data) = step_tensor.data() {
                if !step_data.is_empty() {
                    self.current_step = step_data[0] as usize;
                    self.state.step = self.current_step;
                }
            }
        }

        for (key, tensor) in &state {
            if let Some(name) = key.strip_prefix("momentum_") {
                if let Ok(data) = tensor.data() {
                    self.state.momentum.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("variance_") {
                if let Ok(data) = tensor.data() {
                    self.state.variance.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("dual_momentum_") {
                if let Ok(data) = tensor.data() {
                    self.dual_momentum.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("prev_params_") {
                if let Ok(data) = tensor.data() {
                    self.previous_params.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("heterogeneity_") {
                if let Ok(data) = tensor.data() {
                    if !data.is_empty() {
                        self.gradient_heterogeneity.insert(name.to_string(), data[0]);
                    }
                }
            } else if let Some(name) = key.strip_prefix("step_factor_") {
                if let Ok(data) = tensor.data() {
                    if !data.is_empty() {
                        self.step_size_factors.insert(name.to_string(), data[0]);
                    }
                }
            }
        }

        Ok(())
    }

    fn memory_usage(&self) -> StateMemoryStats {
        let base_stats = self.state.memory_usage();

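        // Beyond the base Adam-style state, AMacP stores two per-element
        // buffers (dual momentum, previous state) and two scalars per
        // parameter; the scalars are reported under `third_moment_elements`
        // because `StateMemoryStats` has no dedicated field for them.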
        let dual_momentum_elements: usize = self.dual_momentum.values().map(|v| v.len()).sum();
        let prev_params_elements: usize = self.previous_params.values().map(|v| v.len()).sum();
        let scalar_elements = self.gradient_heterogeneity.len() + self.step_size_factors.len();

        StateMemoryStats {
            momentum_elements: base_stats.momentum_elements
                + dual_momentum_elements
                + prev_params_elements,
            variance_elements: base_stats.variance_elements,
            third_moment_elements: scalar_elements,
            total_bytes: base_stats.total_bytes
                + (dual_momentum_elements + prev_params_elements + scalar_elements)
                    * std::mem::size_of::<f32>(),
            num_parameters: base_stats.num_parameters,
        }
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn reset_state(&mut self) {
        self.state.clear();
        self.previous_params.clear();
        self.dual_momentum.clear();
        self.gradient_heterogeneity.clear();
        self.step_size_factors.clear();
        self.current_step = 0;
    }

    fn num_parameters(&self) -> usize {
        self.state.momentum.len()
    }
}

/// Aggregate diagnostics reported by [`AMacP::get_stats`].
#[derive(Debug, Clone)]
pub struct AMacPStats {
    pub current_step: usize,
    pub average_gradient_heterogeneity: f32,
    pub average_step_size_factor: f32,
    pub total_parameters: usize,
    pub warmup_progress: f32,
    pub dual_momentum_norm: f32,
}

impl AMacP {
    /// Clears all optimizer state and resets the step counter.
    pub fn reset(&mut self) {
        self.reset_state();
    }

    /// Returns aggregate diagnostics about the optimizer's current state.
    pub fn get_stats(&self) -> AMacPStats {
        let avg_heterogeneity = if !self.gradient_heterogeneity.is_empty() {
            self.gradient_heterogeneity.values().sum::<f32>()
                / self.gradient_heterogeneity.len() as f32
        } else {
            0.0
        };

        let avg_step_factor = if !self.step_size_factors.is_empty() {
            self.step_size_factors.values().sum::<f32>() / self.step_size_factors.len() as f32
        } else {
            1.0
        };

        let warmup_progress = if self.config.warmup_steps > 0 {
            (self.current_step as f32 / self.config.warmup_steps as f32).min(1.0)
        } else {
            1.0
        };

        let dual_momentum_norm: f32 = self
            .dual_momentum
            .values()
            .flat_map(|v| v.iter())
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();

        AMacPStats {
            current_step: self.current_step,
            average_gradient_heterogeneity: avg_heterogeneity,
            average_step_size_factor: avg_step_factor,
            total_parameters: self.num_parameters(),
            warmup_progress,
            dual_momentum_norm,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_amacp_creation() {
        let optimizer = AMacP::new(AMacPConfig::default());
        assert_eq!(optimizer.learning_rate(), 1e-3);
        assert_eq!(optimizer.config.beta1, 0.9);
        assert_eq!(optimizer.config.beta2, 0.999);
        assert_eq!(optimizer.config.gamma, 0.95);
    }

    #[test]
    fn test_amacp_presets() {
        let transformer_opt = AMacP::for_transformers();
        assert_eq!(transformer_opt.config.learning_rate, 6e-4);
        assert_eq!(transformer_opt.config.warmup_steps, 4000);

        let vision_opt = AMacP::for_vision();
        assert_eq!(vision_opt.config.learning_rate, 1e-3);
        assert_eq!(vision_opt.config.warmup_steps, 500);

        let llm_opt = AMacP::for_large_language_models();
        assert_eq!(llm_opt.config.learning_rate, 3e-4);
        assert_eq!(llm_opt.config.warmup_steps, 10000);
    }

    #[test]
    fn test_dual_momentum_computation() {
        let optimizer = AMacP::new(AMacPConfig::default());
        let m_hat = 0.1;
        let v_hat = 0.01;
        let dual_momentum = optimizer.compute_dual_momentum(m_hat, v_hat);

        let expected = 0.5 * 0.1 + 0.5 * 0.01_f32.sqrt();
        assert!((dual_momentum - expected).abs() < 1e-6);
    }

    #[test]
    fn test_learning_rate_getter_setter() {
        let mut optimizer = AMacP::new(AMacPConfig::default());
        assert_eq!(optimizer.learning_rate(), 1e-3);

        optimizer.set_learning_rate(2e-3);
        assert_eq!(optimizer.learning_rate(), 2e-3);
    }

    #[test]
    fn test_warmup_lr_calculation() {
        let mut optimizer = AMacP::new(AMacPConfig {
            learning_rate: 1e-3,
            warmup_steps: 1000,
            ..Default::default()
        });

        optimizer.current_step = 500;
        let warmup_lr = optimizer.get_warmup_lr();
        assert!((warmup_lr - 5e-4).abs() < 1e-6);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let optimizer = AMacP::new(AMacPConfig::default());
        let memory_stats = optimizer.memory_usage();

        assert_eq!(memory_stats.momentum_elements, 0);
        assert_eq!(memory_stats.variance_elements, 0);
        assert_eq!(memory_stats.num_parameters, 0);
    }

    #[test]
    fn test_stats_generation() {
        let optimizer = AMacP::new(AMacPConfig::default());
        let stats = optimizer.get_stats();

        assert_eq!(stats.current_step, 0);
        assert_eq!(stats.total_parameters, 0);
        assert_eq!(stats.warmup_progress, 0.0);
        assert_eq!(stats.dual_momentum_norm, 0.0);
    }

    #[test]
    fn test_reset_functionality() {
        let mut optimizer = AMacP::new(AMacPConfig::default());
        optimizer.current_step = 100;

        optimizer.reset();
        assert_eq!(optimizer.current_step, 0);
        assert!(optimizer.dual_momentum.is_empty());
        assert!(optimizer.previous_params.is_empty());
    }

    #[test]
    fn test_state_dict_operations() {
        let optimizer = AMacP::new(AMacPConfig::default());
        let state_dict = optimizer.state_dict();
        assert!(state_dict.is_ok());

        let state = state_dict.unwrap();
        assert!(state.contains_key("step"));
    }

    #[test]
    fn test_config_serialization() {
        let config = AMacPConfig::for_transformers();
        let serialized = serde_json::to_string(&config);
        assert!(serialized.is_ok());

        let deserialized: std::result::Result<AMacPConfig, _> =
            serde_json::from_str(&serialized.unwrap());
        assert!(deserialized.is_ok());
        assert_eq!(deserialized.unwrap().learning_rate, 6e-4);
    }
}