tensorlogic_train/
optimizer.rs

1//! Optimizer wrappers around SciRS2 optimizers.
2
3use crate::{TrainError, TrainResult};
4use scirs2_core::ndarray::{Array, Ix2};
5use std::collections::HashMap;
6
7/// Compute the global L2 norm of all gradients.
8///
9/// # Arguments
10/// * `gradients` - Gradients for all parameters
11///
12/// # Returns
13/// The L2 norm of all gradients combined
14fn compute_gradient_norm(gradients: &HashMap<String, Array<f64, Ix2>>) -> f64 {
15    let mut total_norm_sq = 0.0;
16
17    for grad in gradients.values() {
18        for &g in grad.iter() {
19            total_norm_sq += g * g;
20        }
21    }
22
23    total_norm_sq.sqrt()
24}
25
26/// Gradient clipping mode.
27#[derive(Debug, Clone, Copy, PartialEq)]
28pub enum GradClipMode {
29    /// Clip by value (element-wise).
30    Value,
31    /// Clip by global L2 norm.
32    Norm,
33}
34
35/// Configuration for optimizers.
36#[derive(Debug, Clone)]
37pub struct OptimizerConfig {
38    /// Learning rate.
39    pub learning_rate: f64,
40    /// Momentum (for SGD).
41    pub momentum: f64,
42    /// Beta1 (for Adam/AdamW).
43    pub beta1: f64,
44    /// Beta2 (for Adam/AdamW).
45    pub beta2: f64,
46    /// Epsilon for numerical stability.
47    pub epsilon: f64,
48    /// Weight decay (for AdamW).
49    pub weight_decay: f64,
50    /// Gradient clipping threshold (None = no clipping).
51    pub grad_clip: Option<f64>,
52    /// Gradient clipping mode.
53    pub grad_clip_mode: GradClipMode,
54}
55
56impl Default for OptimizerConfig {
57    fn default() -> Self {
58        Self {
59            learning_rate: 0.001,
60            momentum: 0.9,
61            beta1: 0.9,
62            beta2: 0.999,
63            epsilon: 1e-8,
64            weight_decay: 0.01,
65            grad_clip: None,
66            grad_clip_mode: GradClipMode::Value,
67        }
68    }
69}
70
71/// Trait for optimizers.
72pub trait Optimizer {
73    /// Update parameters with computed gradients.
74    fn step(
75        &mut self,
76        parameters: &mut HashMap<String, Array<f64, Ix2>>,
77        gradients: &HashMap<String, Array<f64, Ix2>>,
78    ) -> TrainResult<()>;
79
80    /// Zero all gradients.
81    fn zero_grad(&mut self);
82
83    /// Get current learning rate.
84    fn get_lr(&self) -> f64;
85
86    /// Set learning rate.
87    fn set_lr(&mut self, lr: f64);
88
89    /// Get optimizer state for checkpointing.
90    fn state_dict(&self) -> HashMap<String, Vec<f64>>;
91
92    /// Load optimizer state from checkpoint.
93    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>);
94}
95
96/// SGD optimizer with momentum.
97#[derive(Debug)]
98pub struct SgdOptimizer {
99    config: OptimizerConfig,
100    /// Momentum buffers for each parameter.
101    velocity: HashMap<String, Array<f64, Ix2>>,
102}
103
104impl SgdOptimizer {
105    /// Create a new SGD optimizer.
106    pub fn new(config: OptimizerConfig) -> Self {
107        Self {
108            config,
109            velocity: HashMap::new(),
110        }
111    }
112
113    /// Apply gradient clipping if configured.
114    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
115        if let Some(clip_value) = self.config.grad_clip {
116            match self.config.grad_clip_mode {
117                GradClipMode::Value => {
118                    // Clip by value (element-wise)
119                    for grad in gradients.values_mut() {
120                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
121                    }
122                }
123                GradClipMode::Norm => {
124                    // Clip by global L2 norm
125                    let total_norm = compute_gradient_norm(gradients);
126
127                    if total_norm > clip_value {
128                        let scale = clip_value / total_norm;
129                        for grad in gradients.values_mut() {
130                            grad.mapv_inplace(|g| g * scale);
131                        }
132                    }
133                }
134            }
135        }
136    }
137}
138
139impl Optimizer for SgdOptimizer {
140    fn step(
141        &mut self,
142        parameters: &mut HashMap<String, Array<f64, Ix2>>,
143        gradients: &HashMap<String, Array<f64, Ix2>>,
144    ) -> TrainResult<()> {
145        let mut clipped_gradients = gradients.clone();
146        self.clip_gradients(&mut clipped_gradients);
147
148        for (name, param) in parameters.iter_mut() {
149            let grad = clipped_gradients.get(name).ok_or_else(|| {
150                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
151            })?;
152
153            // Initialize velocity if not present
154            if !self.velocity.contains_key(name) {
155                self.velocity
156                    .insert(name.clone(), Array::zeros(param.raw_dim()));
157            }
158
159            let velocity = self.velocity.get_mut(name).unwrap();
160
161            // Update velocity: v = momentum * v + lr * grad
162            velocity.mapv_inplace(|v| self.config.momentum * v);
163            *velocity = &*velocity + &(grad * self.config.learning_rate);
164
165            // Update parameter: param = param - velocity
166            *param = &*param - &*velocity;
167        }
168
169        Ok(())
170    }
171
172    fn zero_grad(&mut self) {
173        // Gradients are managed externally, nothing to do here
174    }
175
176    fn get_lr(&self) -> f64 {
177        self.config.learning_rate
178    }
179
180    fn set_lr(&mut self, lr: f64) {
181        self.config.learning_rate = lr;
182    }
183
184    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
185        let mut state = HashMap::new();
186        for (name, velocity) in &self.velocity {
187            state.insert(
188                format!("velocity_{}", name),
189                velocity.iter().copied().collect(),
190            );
191        }
192        state
193    }
194
195    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
196        for (key, values) in state {
197            if let Some(name) = key.strip_prefix("velocity_") {
198                // Reconstruct array from values (assumes correct shape)
199                if let Some(velocity) = self.velocity.get(name) {
200                    let shape = velocity.raw_dim();
201                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
202                        self.velocity.insert(name.to_string(), arr);
203                    }
204                }
205            }
206        }
207    }
208}
209
210/// Adam optimizer.
211#[derive(Debug)]
212pub struct AdamOptimizer {
213    config: OptimizerConfig,
214    /// First moment estimates (exponential moving average of gradients).
215    m: HashMap<String, Array<f64, Ix2>>,
216    /// Second moment estimates (exponential moving average of squared gradients).
217    v: HashMap<String, Array<f64, Ix2>>,
218    /// Timestep counter.
219    t: usize,
220}
221
222impl AdamOptimizer {
223    /// Create a new Adam optimizer.
224    pub fn new(config: OptimizerConfig) -> Self {
225        Self {
226            config,
227            m: HashMap::new(),
228            v: HashMap::new(),
229            t: 0,
230        }
231    }
232
233    /// Apply gradient clipping if configured.
234    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
235        if let Some(clip_value) = self.config.grad_clip {
236            match self.config.grad_clip_mode {
237                GradClipMode::Value => {
238                    // Clip by value (element-wise)
239                    for grad in gradients.values_mut() {
240                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
241                    }
242                }
243                GradClipMode::Norm => {
244                    // Clip by global L2 norm
245                    let total_norm = compute_gradient_norm(gradients);
246
247                    if total_norm > clip_value {
248                        let scale = clip_value / total_norm;
249                        for grad in gradients.values_mut() {
250                            grad.mapv_inplace(|g| g * scale);
251                        }
252                    }
253                }
254            }
255        }
256    }
257}
258
259impl Optimizer for AdamOptimizer {
260    fn step(
261        &mut self,
262        parameters: &mut HashMap<String, Array<f64, Ix2>>,
263        gradients: &HashMap<String, Array<f64, Ix2>>,
264    ) -> TrainResult<()> {
265        let mut clipped_gradients = gradients.clone();
266        self.clip_gradients(&mut clipped_gradients);
267
268        self.t += 1;
269        let lr = self.config.learning_rate;
270        let beta1 = self.config.beta1;
271        let beta2 = self.config.beta2;
272        let eps = self.config.epsilon;
273
274        // Bias correction
275        let lr_t =
276            lr * ((1.0 - beta2.powi(self.t as i32)).sqrt()) / (1.0 - beta1.powi(self.t as i32));
277
278        for (name, param) in parameters.iter_mut() {
279            let grad = clipped_gradients.get(name).ok_or_else(|| {
280                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
281            })?;
282
283            // Initialize moments if not present
284            if !self.m.contains_key(name) {
285                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
286                self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
287            }
288
289            let m = self.m.get_mut(name).unwrap();
290            let v = self.v.get_mut(name).unwrap();
291
292            // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * grad
293            *m = &*m * beta1 + &(grad * (1.0 - beta1));
294
295            // Update biased second raw moment estimate: v = beta2 * v + (1 - beta2) * grad^2
296            let grad_squared = grad.mapv(|g| g * g);
297            *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
298
299            // Update parameter: param = param - lr_t * m / (sqrt(v) + eps)
300            let update = m.mapv(|m_val| m_val * lr_t) / &v.mapv(|v_val| v_val.sqrt() + eps);
301            *param = &*param - &update;
302        }
303
304        Ok(())
305    }
306
307    fn zero_grad(&mut self) {
308        // Gradients are managed externally, nothing to do here
309    }
310
311    fn get_lr(&self) -> f64 {
312        self.config.learning_rate
313    }
314
315    fn set_lr(&mut self, lr: f64) {
316        self.config.learning_rate = lr;
317    }
318
319    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
320        let mut state = HashMap::new();
321        state.insert("t".to_string(), vec![self.t as f64]);
322
323        for (name, m_val) in &self.m {
324            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
325        }
326        for (name, v_val) in &self.v {
327            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
328        }
329        state
330    }
331
332    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
333        if let Some(t_vals) = state.get("t") {
334            self.t = t_vals[0] as usize;
335        }
336
337        for (key, values) in state {
338            if let Some(name) = key.strip_prefix("m_") {
339                if let Some(m) = self.m.get(name) {
340                    let shape = m.raw_dim();
341                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
342                        self.m.insert(name.to_string(), arr);
343                    }
344                }
345            } else if let Some(name) = key.strip_prefix("v_") {
346                if let Some(v) = self.v.get(name) {
347                    let shape = v.raw_dim();
348                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
349                        self.v.insert(name.to_string(), arr);
350                    }
351                }
352            }
353        }
354    }
355}
356
357/// AdamW optimizer (Adam with decoupled weight decay).
358#[derive(Debug)]
359pub struct AdamWOptimizer {
360    config: OptimizerConfig,
361    /// First moment estimates.
362    m: HashMap<String, Array<f64, Ix2>>,
363    /// Second moment estimates.
364    v: HashMap<String, Array<f64, Ix2>>,
365    /// Timestep counter.
366    t: usize,
367}
368
369impl AdamWOptimizer {
370    /// Create a new AdamW optimizer.
371    pub fn new(config: OptimizerConfig) -> Self {
372        Self {
373            config,
374            m: HashMap::new(),
375            v: HashMap::new(),
376            t: 0,
377        }
378    }
379
380    /// Apply gradient clipping if configured.
381    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
382        if let Some(clip_value) = self.config.grad_clip {
383            match self.config.grad_clip_mode {
384                GradClipMode::Value => {
385                    // Clip by value (element-wise)
386                    for grad in gradients.values_mut() {
387                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
388                    }
389                }
390                GradClipMode::Norm => {
391                    // Clip by global L2 norm
392                    let total_norm = compute_gradient_norm(gradients);
393
394                    if total_norm > clip_value {
395                        let scale = clip_value / total_norm;
396                        for grad in gradients.values_mut() {
397                            grad.mapv_inplace(|g| g * scale);
398                        }
399                    }
400                }
401            }
402        }
403    }
404}
405
406impl Optimizer for AdamWOptimizer {
407    fn step(
408        &mut self,
409        parameters: &mut HashMap<String, Array<f64, Ix2>>,
410        gradients: &HashMap<String, Array<f64, Ix2>>,
411    ) -> TrainResult<()> {
412        let mut clipped_gradients = gradients.clone();
413        self.clip_gradients(&mut clipped_gradients);
414
415        self.t += 1;
416        let lr = self.config.learning_rate;
417        let beta1 = self.config.beta1;
418        let beta2 = self.config.beta2;
419        let eps = self.config.epsilon;
420        let weight_decay = self.config.weight_decay;
421
422        // Bias correction
423        let lr_t =
424            lr * ((1.0 - beta2.powi(self.t as i32)).sqrt()) / (1.0 - beta1.powi(self.t as i32));
425
426        for (name, param) in parameters.iter_mut() {
427            let grad = clipped_gradients.get(name).ok_or_else(|| {
428                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
429            })?;
430
431            // Initialize moments if not present
432            if !self.m.contains_key(name) {
433                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
434                self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
435            }
436
437            let m = self.m.get_mut(name).unwrap();
438            let v = self.v.get_mut(name).unwrap();
439
440            // Update biased first moment estimate
441            *m = &*m * beta1 + &(grad * (1.0 - beta1));
442
443            // Update biased second raw moment estimate
444            let grad_squared = grad.mapv(|g| g * g);
445            *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
446
447            // Compute Adam update
448            let update = m.mapv(|m_val| m_val * lr_t) / &v.mapv(|v_val| v_val.sqrt() + eps);
449
450            // Apply weight decay (decoupled from gradient)
451            let decay = param.mapv(|p| p * lr * weight_decay);
452
453            // Update parameter: param = param - update - decay
454            *param = &*param - &update - &decay;
455        }
456
457        Ok(())
458    }
459
460    fn zero_grad(&mut self) {
461        // Gradients are managed externally
462    }
463
464    fn get_lr(&self) -> f64 {
465        self.config.learning_rate
466    }
467
468    fn set_lr(&mut self, lr: f64) {
469        self.config.learning_rate = lr;
470    }
471
472    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
473        let mut state = HashMap::new();
474        state.insert("t".to_string(), vec![self.t as f64]);
475
476        for (name, m_val) in &self.m {
477            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
478        }
479        for (name, v_val) in &self.v {
480            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
481        }
482        state
483    }
484
485    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
486        if let Some(t_vals) = state.get("t") {
487            self.t = t_vals[0] as usize;
488        }
489
490        for (key, values) in state {
491            if let Some(name) = key.strip_prefix("m_") {
492                if let Some(m) = self.m.get(name) {
493                    let shape = m.raw_dim();
494                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
495                        self.m.insert(name.to_string(), arr);
496                    }
497                }
498            } else if let Some(name) = key.strip_prefix("v_") {
499                if let Some(v) = self.v.get(name) {
500                    let shape = v.raw_dim();
501                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
502                        self.v.insert(name.to_string(), arr);
503                    }
504                }
505            }
506        }
507    }
508}
509
510/// RMSprop optimizer (Root Mean Square Propagation).
511#[derive(Debug)]
512pub struct RMSpropOptimizer {
513    config: OptimizerConfig,
514    /// Moving average of squared gradients.
515    v: HashMap<String, Array<f64, Ix2>>,
516}
517
518impl RMSpropOptimizer {
519    /// Create a new RMSprop optimizer.
520    pub fn new(config: OptimizerConfig) -> Self {
521        Self {
522            config,
523            v: HashMap::new(),
524        }
525    }
526
527    /// Apply gradient clipping if configured.
528    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
529        if let Some(clip_value) = self.config.grad_clip {
530            match self.config.grad_clip_mode {
531                GradClipMode::Value => {
532                    for grad in gradients.values_mut() {
533                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
534                    }
535                }
536                GradClipMode::Norm => {
537                    let total_norm = compute_gradient_norm(gradients);
538                    if total_norm > clip_value {
539                        let scale = clip_value / total_norm;
540                        for grad in gradients.values_mut() {
541                            grad.mapv_inplace(|g| g * scale);
542                        }
543                    }
544                }
545            }
546        }
547    }
548}
549
550impl Optimizer for RMSpropOptimizer {
551    fn step(
552        &mut self,
553        parameters: &mut HashMap<String, Array<f64, Ix2>>,
554        gradients: &HashMap<String, Array<f64, Ix2>>,
555    ) -> TrainResult<()> {
556        let mut clipped_gradients = gradients.clone();
557        self.clip_gradients(&mut clipped_gradients);
558
559        let lr = self.config.learning_rate;
560        let alpha = self.config.beta2; // Use beta2 as decay rate for RMSprop
561        let eps = self.config.epsilon;
562
563        for (name, param) in parameters.iter_mut() {
564            let grad = clipped_gradients.get(name).ok_or_else(|| {
565                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
566            })?;
567
568            // Initialize moving average if not present
569            if !self.v.contains_key(name) {
570                self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
571            }
572
573            let v = self.v.get_mut(name).unwrap();
574
575            // Update moving average: v = alpha * v + (1 - alpha) * grad^2
576            let grad_squared = grad.mapv(|g| g * g);
577            *v = &*v * alpha + &(grad_squared * (1.0 - alpha));
578
579            // Update parameter: param = param - lr * grad / (sqrt(v) + eps)
580            let update = grad / &v.mapv(|v_val| v_val.sqrt() + eps);
581            *param = &*param - &(update * lr);
582        }
583
584        Ok(())
585    }
586
587    fn zero_grad(&mut self) {}
588
589    fn get_lr(&self) -> f64 {
590        self.config.learning_rate
591    }
592
593    fn set_lr(&mut self, lr: f64) {
594        self.config.learning_rate = lr;
595    }
596
597    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
598        let mut state = HashMap::new();
599        for (name, v_val) in &self.v {
600            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
601        }
602        state
603    }
604
605    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
606        for (key, values) in state {
607            if let Some(name) = key.strip_prefix("v_") {
608                if let Some(v) = self.v.get(name) {
609                    let shape = v.raw_dim();
610                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
611                        self.v.insert(name.to_string(), arr);
612                    }
613                }
614            }
615        }
616    }
617}
618
619/// Adagrad optimizer (Adaptive Gradient).
620#[derive(Debug)]
621pub struct AdagradOptimizer {
622    config: OptimizerConfig,
623    /// Accumulated sum of squared gradients.
624    sum_squared_grads: HashMap<String, Array<f64, Ix2>>,
625}
626
627impl AdagradOptimizer {
628    /// Create a new Adagrad optimizer.
629    pub fn new(config: OptimizerConfig) -> Self {
630        Self {
631            config,
632            sum_squared_grads: HashMap::new(),
633        }
634    }
635
636    /// Apply gradient clipping if configured.
637    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
638        if let Some(clip_value) = self.config.grad_clip {
639            match self.config.grad_clip_mode {
640                GradClipMode::Value => {
641                    for grad in gradients.values_mut() {
642                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
643                    }
644                }
645                GradClipMode::Norm => {
646                    let total_norm = compute_gradient_norm(gradients);
647                    if total_norm > clip_value {
648                        let scale = clip_value / total_norm;
649                        for grad in gradients.values_mut() {
650                            grad.mapv_inplace(|g| g * scale);
651                        }
652                    }
653                }
654            }
655        }
656    }
657}
658
659impl Optimizer for AdagradOptimizer {
660    fn step(
661        &mut self,
662        parameters: &mut HashMap<String, Array<f64, Ix2>>,
663        gradients: &HashMap<String, Array<f64, Ix2>>,
664    ) -> TrainResult<()> {
665        let mut clipped_gradients = gradients.clone();
666        self.clip_gradients(&mut clipped_gradients);
667
668        let lr = self.config.learning_rate;
669        let eps = self.config.epsilon;
670
671        for (name, param) in parameters.iter_mut() {
672            let grad = clipped_gradients.get(name).ok_or_else(|| {
673                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
674            })?;
675
676            // Initialize accumulated sum if not present
677            if !self.sum_squared_grads.contains_key(name) {
678                self.sum_squared_grads
679                    .insert(name.clone(), Array::zeros(param.raw_dim()));
680            }
681
682            let sum_sq = self.sum_squared_grads.get_mut(name).unwrap();
683
684            // Accumulate squared gradients: sum_sq = sum_sq + grad^2
685            let grad_squared = grad.mapv(|g| g * g);
686            *sum_sq = &*sum_sq + &grad_squared;
687
688            // Update parameter: param = param - lr * grad / (sqrt(sum_sq) + eps)
689            let update = grad / &sum_sq.mapv(|s| s.sqrt() + eps);
690            *param = &*param - &(update * lr);
691        }
692
693        Ok(())
694    }
695
696    fn zero_grad(&mut self) {}
697
698    fn get_lr(&self) -> f64 {
699        self.config.learning_rate
700    }
701
702    fn set_lr(&mut self, lr: f64) {
703        self.config.learning_rate = lr;
704    }
705
706    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
707        let mut state = HashMap::new();
708        for (name, sum_sq) in &self.sum_squared_grads {
709            state.insert(
710                format!("sum_squared_grads_{}", name),
711                sum_sq.iter().copied().collect(),
712            );
713        }
714        state
715    }
716
717    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
718        for (key, values) in state {
719            if let Some(name) = key.strip_prefix("sum_squared_grads_") {
720                if let Some(sum_sq) = self.sum_squared_grads.get(name) {
721                    let shape = sum_sq.raw_dim();
722                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
723                        self.sum_squared_grads.insert(name.to_string(), arr);
724                    }
725                }
726            }
727        }
728    }
729}
730
731/// NAdam optimizer (Nesterov-accelerated Adam).
732#[derive(Debug)]
733pub struct NAdamOptimizer {
734    config: OptimizerConfig,
735    /// First moment estimates.
736    m: HashMap<String, Array<f64, Ix2>>,
737    /// Second moment estimates.
738    v: HashMap<String, Array<f64, Ix2>>,
739    /// Timestep counter.
740    t: usize,
741}
742
743impl NAdamOptimizer {
744    /// Create a new NAdam optimizer.
745    pub fn new(config: OptimizerConfig) -> Self {
746        Self {
747            config,
748            m: HashMap::new(),
749            v: HashMap::new(),
750            t: 0,
751        }
752    }
753
754    /// Apply gradient clipping if configured.
755    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
756        if let Some(clip_value) = self.config.grad_clip {
757            match self.config.grad_clip_mode {
758                GradClipMode::Value => {
759                    for grad in gradients.values_mut() {
760                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
761                    }
762                }
763                GradClipMode::Norm => {
764                    let total_norm = compute_gradient_norm(gradients);
765                    if total_norm > clip_value {
766                        let scale = clip_value / total_norm;
767                        for grad in gradients.values_mut() {
768                            grad.mapv_inplace(|g| g * scale);
769                        }
770                    }
771                }
772            }
773        }
774    }
775}
776
777impl Optimizer for NAdamOptimizer {
778    fn step(
779        &mut self,
780        parameters: &mut HashMap<String, Array<f64, Ix2>>,
781        gradients: &HashMap<String, Array<f64, Ix2>>,
782    ) -> TrainResult<()> {
783        let mut clipped_gradients = gradients.clone();
784        self.clip_gradients(&mut clipped_gradients);
785
786        self.t += 1;
787        let lr = self.config.learning_rate;
788        let beta1 = self.config.beta1;
789        let beta2 = self.config.beta2;
790        let eps = self.config.epsilon;
791
792        // Momentum schedule (schedule multiplier for beta1)
793        let mu_t = beta1 * (1.0 - 0.5 * 0.96_f64.powi(self.t as i32));
794        let mu_t_next = beta1 * (1.0 - 0.5 * 0.96_f64.powi((self.t + 1) as i32));
795
796        for (name, param) in parameters.iter_mut() {
797            let grad = clipped_gradients.get(name).ok_or_else(|| {
798                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
799            })?;
800
801            // Initialize moments if not present
802            if !self.m.contains_key(name) {
803                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
804                self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
805            }
806
807            let m = self.m.get_mut(name).unwrap();
808            let v = self.v.get_mut(name).unwrap();
809
810            // Update biased first moment estimate
811            *m = &*m * beta1 + &(grad * (1.0 - beta1));
812
813            // Update biased second moment estimate
814            let grad_squared = grad.mapv(|g| g * g);
815            *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
816
817            // Bias correction
818            let m_hat = &*m / (1.0 - beta1.powi(self.t as i32));
819            let v_hat = &*v / (1.0 - beta2.powi(self.t as i32));
820
821            // Nesterov momentum
822            let m_bar =
823                &m_hat * mu_t_next / (1.0 - mu_t_next) + &(grad * (1.0 - mu_t) / (1.0 - mu_t_next));
824
825            // Update parameter
826            let update = m_bar / &v_hat.mapv(|v_val| v_val.sqrt() + eps);
827            *param = &*param - &(update * lr);
828        }
829
830        Ok(())
831    }
832
833    fn zero_grad(&mut self) {}
834
835    fn get_lr(&self) -> f64 {
836        self.config.learning_rate
837    }
838
839    fn set_lr(&mut self, lr: f64) {
840        self.config.learning_rate = lr;
841    }
842
843    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
844        let mut state = HashMap::new();
845        state.insert("t".to_string(), vec![self.t as f64]);
846
847        for (name, m_val) in &self.m {
848            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
849        }
850        for (name, v_val) in &self.v {
851            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
852        }
853        state
854    }
855
856    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
857        if let Some(t_vals) = state.get("t") {
858            self.t = t_vals[0] as usize;
859        }
860
861        for (key, values) in state {
862            if let Some(name) = key.strip_prefix("m_") {
863                if let Some(m) = self.m.get(name) {
864                    let shape = m.raw_dim();
865                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
866                        self.m.insert(name.to_string(), arr);
867                    }
868                }
869            } else if let Some(name) = key.strip_prefix("v_") {
870                if let Some(v) = self.v.get(name) {
871                    let shape = v.raw_dim();
872                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
873                        self.v.insert(name.to_string(), arr);
874                    }
875                }
876            }
877        }
878    }
879}
880
881/// LAMB optimizer (Layer-wise Adaptive Moments optimizer for Batch training).
882/// Designed for large batch training, uses layer-wise adaptation.
883#[derive(Debug)]
884pub struct LambOptimizer {
885    config: OptimizerConfig,
886    /// First moment estimates.
887    m: HashMap<String, Array<f64, Ix2>>,
888    /// Second moment estimates.
889    v: HashMap<String, Array<f64, Ix2>>,
890    /// Timestep counter.
891    t: usize,
892}
893
894impl LambOptimizer {
895    /// Create a new LAMB optimizer.
896    pub fn new(config: OptimizerConfig) -> Self {
897        Self {
898            config,
899            m: HashMap::new(),
900            v: HashMap::new(),
901            t: 0,
902        }
903    }
904
905    /// Apply gradient clipping if configured.
906    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
907        if let Some(clip_value) = self.config.grad_clip {
908            match self.config.grad_clip_mode {
909                GradClipMode::Value => {
910                    for grad in gradients.values_mut() {
911                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
912                    }
913                }
914                GradClipMode::Norm => {
915                    let total_norm = compute_gradient_norm(gradients);
916                    if total_norm > clip_value {
917                        let scale = clip_value / total_norm;
918                        for grad in gradients.values_mut() {
919                            grad.mapv_inplace(|g| g * scale);
920                        }
921                    }
922                }
923            }
924        }
925    }
926
927    /// Compute L2 norm of an array.
928    fn compute_norm(arr: &Array<f64, Ix2>) -> f64 {
929        arr.iter().map(|&x| x * x).sum::<f64>().sqrt()
930    }
931}
932
933impl Optimizer for LambOptimizer {
934    fn step(
935        &mut self,
936        parameters: &mut HashMap<String, Array<f64, Ix2>>,
937        gradients: &HashMap<String, Array<f64, Ix2>>,
938    ) -> TrainResult<()> {
939        let mut clipped_gradients = gradients.clone();
940        self.clip_gradients(&mut clipped_gradients);
941
942        self.t += 1;
943        let lr = self.config.learning_rate;
944        let beta1 = self.config.beta1;
945        let beta2 = self.config.beta2;
946        let eps = self.config.epsilon;
947        let weight_decay = self.config.weight_decay;
948
949        for (name, param) in parameters.iter_mut() {
950            let grad = clipped_gradients.get(name).ok_or_else(|| {
951                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
952            })?;
953
954            // Initialize moments if not present
955            if !self.m.contains_key(name) {
956                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
957                self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
958            }
959
960            let m = self.m.get_mut(name).unwrap();
961            let v = self.v.get_mut(name).unwrap();
962
963            // Update biased first moment estimate
964            *m = &*m * beta1 + &(grad * (1.0 - beta1));
965
966            // Update biased second moment estimate
967            let grad_squared = grad.mapv(|g| g * g);
968            *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
969
970            // Bias correction
971            let m_hat = &*m / (1.0 - beta1.powi(self.t as i32));
972            let v_hat = &*v / (1.0 - beta2.powi(self.t as i32));
973
974            // Compute Adam step (without weight decay)
975            let adam_step = &m_hat / &v_hat.mapv(|v_val| v_val.sqrt() + eps);
976
977            // Add weight decay
978            let update = &adam_step + &param.mapv(|p| p * weight_decay);
979
980            // Layer-wise adaptation: compute trust ratio
981            let param_norm = Self::compute_norm(param);
982            let update_norm = Self::compute_norm(&update);
983
984            let trust_ratio = if param_norm > 0.0 && update_norm > 0.0 {
985                param_norm / update_norm
986            } else {
987                1.0
988            };
989
990            // Update parameter with layer-wise adapted learning rate
991            *param = &*param - &(update * (lr * trust_ratio));
992        }
993
994        Ok(())
995    }
996
997    fn zero_grad(&mut self) {}
998
999    fn get_lr(&self) -> f64 {
1000        self.config.learning_rate
1001    }
1002
1003    fn set_lr(&mut self, lr: f64) {
1004        self.config.learning_rate = lr;
1005    }
1006
1007    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1008        let mut state = HashMap::new();
1009        state.insert("t".to_string(), vec![self.t as f64]);
1010
1011        for (name, m_val) in &self.m {
1012            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
1013        }
1014        for (name, v_val) in &self.v {
1015            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
1016        }
1017        state
1018    }
1019
1020    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1021        if let Some(t_vals) = state.get("t") {
1022            self.t = t_vals[0] as usize;
1023        }
1024
1025        for (key, values) in state {
1026            if let Some(name) = key.strip_prefix("m_") {
1027                if let Some(m) = self.m.get(name) {
1028                    let shape = m.raw_dim();
1029                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1030                        self.m.insert(name.to_string(), arr);
1031                    }
1032                }
1033            } else if let Some(name) = key.strip_prefix("v_") {
1034                if let Some(v) = self.v.get(name) {
1035                    let shape = v.raw_dim();
1036                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1037                        self.v.insert(name.to_string(), arr);
1038                    }
1039                }
1040            }
1041        }
1042    }
1043}
1044
1045/// AdaMax optimizer (variant of Adam with infinity norm).
1046///
1047/// Uses the infinity norm of gradients instead of L2 norm, making it more robust
1048/// to large gradients and outliers.
1049///
1050/// Reference: Kingma & Ba, "Adam: A Method for Stochastic Optimization", ICLR 2015
1051#[derive(Debug)]
1052pub struct AdaMaxOptimizer {
1053    config: OptimizerConfig,
1054    /// First moment estimates (exponential moving average of gradients).
1055    m: HashMap<String, Array<f64, Ix2>>,
1056    /// Exponentially weighted infinity norm.
1057    u: HashMap<String, Array<f64, Ix2>>,
1058    /// Timestep counter.
1059    t: usize,
1060}
1061
1062impl AdaMaxOptimizer {
1063    /// Create a new AdaMax optimizer.
1064    pub fn new(config: OptimizerConfig) -> Self {
1065        Self {
1066            config,
1067            m: HashMap::new(),
1068            u: HashMap::new(),
1069            t: 0,
1070        }
1071    }
1072
1073    /// Apply gradient clipping if configured.
1074    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
1075        if let Some(clip_value) = self.config.grad_clip {
1076            match self.config.grad_clip_mode {
1077                GradClipMode::Value => {
1078                    for grad in gradients.values_mut() {
1079                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
1080                    }
1081                }
1082                GradClipMode::Norm => {
1083                    let total_norm = compute_gradient_norm(gradients);
1084                    if total_norm > clip_value {
1085                        let scale = clip_value / total_norm;
1086                        for grad in gradients.values_mut() {
1087                            grad.mapv_inplace(|g| g * scale);
1088                        }
1089                    }
1090                }
1091            }
1092        }
1093    }
1094}
1095
1096impl Optimizer for AdaMaxOptimizer {
1097    fn step(
1098        &mut self,
1099        parameters: &mut HashMap<String, Array<f64, Ix2>>,
1100        gradients: &HashMap<String, Array<f64, Ix2>>,
1101    ) -> TrainResult<()> {
1102        let mut clipped_gradients = gradients.clone();
1103        self.clip_gradients(&mut clipped_gradients);
1104
1105        self.t += 1;
1106        let lr = self.config.learning_rate;
1107        let beta1 = self.config.beta1;
1108        let beta2 = self.config.beta2;
1109
1110        for (name, param) in parameters.iter_mut() {
1111            let grad = clipped_gradients.get(name).ok_or_else(|| {
1112                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1113            })?;
1114
1115            // Initialize moments if not present
1116            if !self.m.contains_key(name) {
1117                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
1118                self.u.insert(name.clone(), Array::zeros(param.raw_dim()));
1119            }
1120
1121            let m = self.m.get_mut(name).unwrap();
1122            let u = self.u.get_mut(name).unwrap();
1123
1124            // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * grad
1125            *m = &*m * beta1 + &(grad * (1.0 - beta1));
1126
1127            // Update exponentially weighted infinity norm: u = max(beta2 * u, |grad|)
1128            for i in 0..u.nrows() {
1129                for j in 0..u.ncols() {
1130                    u[[i, j]] = (beta2 * u[[i, j]]).max(grad[[i, j]].abs());
1131                }
1132            }
1133
1134            // Bias correction for first moment
1135            let bias_correction = 1.0 - beta1.powi(self.t as i32);
1136            let lr_t = lr / bias_correction;
1137
1138            // Update parameter: param = param - lr_t * m / u
1139            for i in 0..param.nrows() {
1140                for j in 0..param.ncols() {
1141                    let update = lr_t * m[[i, j]] / (u[[i, j]] + self.config.epsilon);
1142                    param[[i, j]] -= update;
1143                }
1144            }
1145        }
1146
1147        Ok(())
1148    }
1149
1150    fn zero_grad(&mut self) {}
1151
1152    fn get_lr(&self) -> f64 {
1153        self.config.learning_rate
1154    }
1155
1156    fn set_lr(&mut self, lr: f64) {
1157        self.config.learning_rate = lr;
1158    }
1159
1160    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1161        let mut state = HashMap::new();
1162        state.insert("t".to_string(), vec![self.t as f64]);
1163
1164        for (name, m_val) in &self.m {
1165            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
1166        }
1167        for (name, u_val) in &self.u {
1168            state.insert(format!("u_{}", name), u_val.iter().copied().collect());
1169        }
1170        state
1171    }
1172
1173    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1174        if let Some(t_vals) = state.get("t") {
1175            self.t = t_vals[0] as usize;
1176        }
1177
1178        for (key, values) in state {
1179            if let Some(name) = key.strip_prefix("m_") {
1180                if let Some(m) = self.m.get(name) {
1181                    let shape = m.raw_dim();
1182                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1183                        self.m.insert(name.to_string(), arr);
1184                    }
1185                }
1186            } else if let Some(name) = key.strip_prefix("u_") {
1187                if let Some(u) = self.u.get(name) {
1188                    let shape = u.raw_dim();
1189                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1190                        self.u.insert(name.to_string(), arr);
1191                    }
1192                }
1193            }
1194        }
1195    }
1196}
1197
1198/// Lookahead optimizer (wrapper that uses slow and fast weights).
1199///
1200/// Maintains two sets of weights: fast weights updated by an inner optimizer,
1201/// and slow weights that are periodically updated as an exponential moving average.
1202///
1203/// Reference: Zhang et al., "Lookahead Optimizer: k steps forward, 1 step back", NeurIPS 2019
1204#[derive(Debug)]
1205pub struct LookaheadOptimizer<O: Optimizer> {
1206    /// Inner optimizer for fast weights.
1207    inner_optimizer: O,
1208    /// Slow weights (maintained separately).
1209    slow_weights: HashMap<String, Array<f64, Ix2>>,
1210    /// Interpolation coefficient (typically 0.5).
1211    alpha: f64,
1212    /// Number of inner optimizer steps before synchronization.
1213    k: usize,
1214    /// Current step counter.
1215    step_counter: usize,
1216}
1217
1218impl<O: Optimizer> LookaheadOptimizer<O> {
1219    /// Create a new Lookahead optimizer.
1220    ///
1221    /// # Arguments
1222    /// * `inner_optimizer` - The inner optimizer (e.g., Adam, SGD)
1223    /// * `alpha` - Interpolation coefficient for slow weight update (typically 0.5)
1224    /// * `k` - Number of fast updates before slow weight synchronization (typically 5-10)
1225    pub fn new(inner_optimizer: O, alpha: f64, k: usize) -> TrainResult<Self> {
1226        if !(0.0..=1.0).contains(&alpha) {
1227            return Err(TrainError::InvalidParameter(
1228                "alpha must be in [0, 1]".to_string(),
1229            ));
1230        }
1231        if k == 0 {
1232            return Err(TrainError::InvalidParameter(
1233                "k must be at least 1".to_string(),
1234            ));
1235        }
1236
1237        Ok(Self {
1238            inner_optimizer,
1239            slow_weights: HashMap::new(),
1240            alpha,
1241            k,
1242            step_counter: 0,
1243        })
1244    }
1245
1246    /// Initialize slow weights from current parameters.
1247    fn initialize_slow_weights(&mut self, parameters: &HashMap<String, Array<f64, Ix2>>) {
1248        if self.slow_weights.is_empty() {
1249            for (name, param) in parameters {
1250                self.slow_weights.insert(name.clone(), param.clone());
1251            }
1252        }
1253    }
1254
1255    /// Synchronize slow weights with fast weights.
1256    fn synchronize_weights(&mut self, parameters: &mut HashMap<String, Array<f64, Ix2>>) {
1257        for (name, param) in parameters.iter_mut() {
1258            if let Some(slow_weight) = self.slow_weights.get_mut(name) {
1259                // Update slow weights: slow = slow + alpha * (fast - slow)
1260                *slow_weight = &*slow_weight + &((&*param - &*slow_weight) * self.alpha);
1261
1262                // Update fast weights to slow weights
1263                *param = slow_weight.clone();
1264            }
1265        }
1266    }
1267}
1268
1269impl<O: Optimizer> Optimizer for LookaheadOptimizer<O> {
1270    fn step(
1271        &mut self,
1272        parameters: &mut HashMap<String, Array<f64, Ix2>>,
1273        gradients: &HashMap<String, Array<f64, Ix2>>,
1274    ) -> TrainResult<()> {
1275        // Initialize slow weights on first step
1276        self.initialize_slow_weights(parameters);
1277
1278        // Perform fast weight update using inner optimizer
1279        self.inner_optimizer.step(parameters, gradients)?;
1280
1281        self.step_counter += 1;
1282
1283        // Synchronize every k steps
1284        if self.step_counter.is_multiple_of(self.k) {
1285            self.synchronize_weights(parameters);
1286        }
1287
1288        Ok(())
1289    }
1290
1291    fn zero_grad(&mut self) {
1292        self.inner_optimizer.zero_grad();
1293    }
1294
1295    fn get_lr(&self) -> f64 {
1296        self.inner_optimizer.get_lr()
1297    }
1298
1299    fn set_lr(&mut self, lr: f64) {
1300        self.inner_optimizer.set_lr(lr);
1301    }
1302
1303    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1304        let mut state = self.inner_optimizer.state_dict();
1305
1306        // Add lookahead-specific state
1307        state.insert("step_counter".to_string(), vec![self.step_counter as f64]);
1308        state.insert("alpha".to_string(), vec![self.alpha]);
1309        state.insert("k".to_string(), vec![self.k as f64]);
1310
1311        for (name, slow_weight) in &self.slow_weights {
1312            state.insert(
1313                format!("slow_{}", name),
1314                slow_weight.iter().copied().collect(),
1315            );
1316        }
1317
1318        state
1319    }
1320
1321    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1322        // Load inner optimizer state
1323        self.inner_optimizer.load_state_dict(state.clone());
1324
1325        // Load lookahead-specific state
1326        if let Some(counter) = state.get("step_counter") {
1327            self.step_counter = counter[0] as usize;
1328        }
1329        if let Some(alpha_val) = state.get("alpha") {
1330            self.alpha = alpha_val[0];
1331        }
1332        if let Some(k_val) = state.get("k") {
1333            self.k = k_val[0] as usize;
1334        }
1335
1336        // Load slow weights
1337        for (key, values) in state {
1338            if let Some(name) = key.strip_prefix("slow_") {
1339                if let Some(slow_weight) = self.slow_weights.get(name) {
1340                    let shape = slow_weight.raw_dim();
1341                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1342                        self.slow_weights.insert(name.to_string(), arr);
1343                    }
1344                }
1345            }
1346        }
1347    }
1348}
1349
1350/// AdaBelief optimizer (NeurIPS 2020).
1351///
1352/// AdaBelief adapts the step size according to the "belief" in the gradient direction.
1353/// It uses the variance of gradients (belief) to adapt the learning rate, which can
1354/// achieve faster convergence and better generalization than Adam/AdamW.
1355///
1356/// Reference: Zhuang et al. "AdaBelief Optimizer: Adapting Stepsizes by the Belief
1357/// in Observed Gradients" (NeurIPS 2020)
1358#[derive(Debug)]
1359pub struct AdaBeliefOptimizer {
1360    config: OptimizerConfig,
1361    /// First moment estimates (exponential moving average of gradients).
1362    m: HashMap<String, Array<f64, Ix2>>,
1363    /// Second moment estimates (variance of gradients).
1364    s: HashMap<String, Array<f64, Ix2>>,
1365    /// Timestep counter.
1366    t: usize,
1367}
1368
1369impl AdaBeliefOptimizer {
1370    /// Create a new AdaBelief optimizer.
1371    pub fn new(config: OptimizerConfig) -> Self {
1372        Self {
1373            config,
1374            m: HashMap::new(),
1375            s: HashMap::new(),
1376            t: 0,
1377        }
1378    }
1379
1380    /// Apply gradient clipping if configured.
1381    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
1382        if let Some(clip_value) = self.config.grad_clip {
1383            match self.config.grad_clip_mode {
1384                GradClipMode::Value => {
1385                    for grad in gradients.values_mut() {
1386                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
1387                    }
1388                }
1389                GradClipMode::Norm => {
1390                    let total_norm = compute_gradient_norm(gradients);
1391                    if total_norm > clip_value {
1392                        let scale = clip_value / total_norm;
1393                        for grad in gradients.values_mut() {
1394                            grad.mapv_inplace(|g| g * scale);
1395                        }
1396                    }
1397                }
1398            }
1399        }
1400    }
1401}
1402
1403impl Optimizer for AdaBeliefOptimizer {
1404    fn step(
1405        &mut self,
1406        parameters: &mut HashMap<String, Array<f64, Ix2>>,
1407        gradients: &HashMap<String, Array<f64, Ix2>>,
1408    ) -> TrainResult<()> {
1409        let mut clipped_gradients = gradients.clone();
1410        self.clip_gradients(&mut clipped_gradients);
1411
1412        self.t += 1;
1413        let lr = self.config.learning_rate;
1414        let beta1 = self.config.beta1;
1415        let beta2 = self.config.beta2;
1416        let eps = self.config.epsilon;
1417        let weight_decay = self.config.weight_decay;
1418
1419        // Bias correction
1420        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
1421        let bias_correction2 = 1.0 - beta2.powi(self.t as i32);
1422
1423        for (name, param) in parameters.iter_mut() {
1424            let grad = clipped_gradients.get(name).ok_or_else(|| {
1425                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1426            })?;
1427
1428            // Initialize moments if not present
1429            if !self.m.contains_key(name) {
1430                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
1431                self.s.insert(name.clone(), Array::zeros(param.raw_dim()));
1432            }
1433
1434            let m = self.m.get_mut(name).unwrap();
1435            let s = self.s.get_mut(name).unwrap();
1436
1437            // Update first moment: m = beta1 * m + (1 - beta1) * grad
1438            *m = &*m * beta1 + &(grad * (1.0 - beta1));
1439
1440            // Compute gradient prediction error: (grad - m)
1441            let grad_diff = grad - &*m;
1442
1443            // Update second moment (variance): s = beta2 * s + (1 - beta2) * (grad - m)^2
1444            let grad_diff_squared = grad_diff.mapv(|g| g * g);
1445            *s = &*s * beta2 + &(grad_diff_squared * (1.0 - beta2));
1446
1447            // Bias-corrected moments
1448            let m_hat = &*m / bias_correction1;
1449            let s_hat = &*s / bias_correction2;
1450
1451            // Weight decay (AdamW-style decoupled weight decay)
1452            if weight_decay > 0.0 {
1453                param.mapv_inplace(|p| p * (1.0 - lr * weight_decay));
1454            }
1455
1456            // Update parameter: param = param - lr * m_hat / (sqrt(s_hat) + eps)
1457            let update = m_hat / (s_hat.mapv(|v| v.sqrt()) + eps);
1458            *param = &*param - &(update * lr);
1459        }
1460
1461        Ok(())
1462    }
1463
1464    fn zero_grad(&mut self) {
1465        // Gradients are managed externally
1466    }
1467
1468    fn get_lr(&self) -> f64 {
1469        self.config.learning_rate
1470    }
1471
1472    fn set_lr(&mut self, lr: f64) {
1473        self.config.learning_rate = lr;
1474    }
1475
1476    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1477        let mut state = HashMap::new();
1478        state.insert("t".to_string(), vec![self.t as f64]);
1479
1480        for (name, m_val) in &self.m {
1481            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
1482        }
1483        for (name, s_val) in &self.s {
1484            state.insert(format!("s_{}", name), s_val.iter().copied().collect());
1485        }
1486
1487        state
1488    }
1489
1490    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1491        if let Some(t_val) = state.get("t") {
1492            self.t = t_val[0] as usize;
1493        }
1494
1495        for (key, values) in state {
1496            if let Some(name) = key.strip_prefix("m_") {
1497                if let Some(m_array) = self.m.get(name) {
1498                    let shape = m_array.raw_dim();
1499                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1500                        self.m.insert(name.to_string(), arr);
1501                    }
1502                }
1503            } else if let Some(name) = key.strip_prefix("s_") {
1504                if let Some(s_array) = self.s.get(name) {
1505                    let shape = s_array.raw_dim();
1506                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1507                        self.s.insert(name.to_string(), arr);
1508                    }
1509                }
1510            }
1511        }
1512    }
1513}
1514
1515/// RAdam optimizer (Rectified Adam) with variance warmup (ICLR 2020).
1516///
1517/// RAdam addresses the bad convergence problem of Adam in the early stages
1518/// by rectifying the variance of the adaptive learning rate. It provides
1519/// a variance warmup mechanism that stabilizes training.
1520///
1521/// Reference: Liu et al. "On the Variance of the Adaptive Learning Rate and Beyond" (ICLR 2020)
1522#[derive(Debug)]
1523pub struct RAdamOptimizer {
1524    config: OptimizerConfig,
1525    /// First moment estimates.
1526    m: HashMap<String, Array<f64, Ix2>>,
1527    /// Second moment estimates.
1528    v: HashMap<String, Array<f64, Ix2>>,
1529    /// Timestep counter.
1530    t: usize,
1531}
1532
1533impl RAdamOptimizer {
1534    /// Create a new RAdam optimizer.
1535    pub fn new(config: OptimizerConfig) -> Self {
1536        Self {
1537            config,
1538            m: HashMap::new(),
1539            v: HashMap::new(),
1540            t: 0,
1541        }
1542    }
1543
1544    /// Apply gradient clipping if configured.
1545    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
1546        if let Some(clip_value) = self.config.grad_clip {
1547            match self.config.grad_clip_mode {
1548                GradClipMode::Value => {
1549                    for grad in gradients.values_mut() {
1550                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
1551                    }
1552                }
1553                GradClipMode::Norm => {
1554                    let total_norm = compute_gradient_norm(gradients);
1555                    if total_norm > clip_value {
1556                        let scale = clip_value / total_norm;
1557                        for grad in gradients.values_mut() {
1558                            grad.mapv_inplace(|g| g * scale);
1559                        }
1560                    }
1561                }
1562            }
1563        }
1564    }
1565
1566    /// Compute the variance rectification term.
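    ///
    /// Returns `(true, rect)` once enough steps have accumulated for the variance to be
    /// tractable (the adaptive update is then scaled by `rect`), and `(false, 0.0)` when
    /// the caller should fall back to the non-adaptive update.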
1567    fn compute_rectification(&self) -> (bool, f64) {
1568        let beta2 = self.config.beta2;
1569        let t = self.t as f64;
1570
1571        // Maximum length of the approximated SMA (Simple Moving Average)
1572        let rho_inf = 2.0 / (1.0 - beta2) - 1.0;
1573
1574        // Length of the approximated SMA at timestep t
1575        let rho_t = rho_inf - 2.0 * t * beta2.powf(t) / (1.0 - beta2.powf(t));
1576
1577        // Check if variance is tractable (the rectification term requires rho_t > 4; 5 adds a margin)
1578        if rho_t > 5.0 {
1579            // Compute rectification term
1580            let rect = ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
1581                / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t);
1582            (true, rect.sqrt())
1583        } else {
1584            // Variance is not tractable; fall back to the non-adaptive update (SGD with momentum)
1585            (false, 0.0)
1586        }
1587    }
1588}
1589
1590impl Optimizer for RAdamOptimizer {
1591    fn step(
1592        &mut self,
1593        parameters: &mut HashMap<String, Array<f64, Ix2>>,
1594        gradients: &HashMap<String, Array<f64, Ix2>>,
1595    ) -> TrainResult<()> {
1596        let mut clipped_gradients = gradients.clone();
1597        self.clip_gradients(&mut clipped_gradients);
1598
1599        self.t += 1;
1600        let lr = self.config.learning_rate;
1601        let beta1 = self.config.beta1;
1602        let beta2 = self.config.beta2;
1603        let eps = self.config.epsilon;
1604
1605        // Bias correction
1606        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
1607
1608        // Compute variance rectification
1609        let (use_adaptive, rect) = self.compute_rectification();
1610
1611        for (name, param) in parameters.iter_mut() {
1612            let grad = clipped_gradients.get(name).ok_or_else(|| {
1613                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1614            })?;
1615
1616            // Initialize moments if not present
1617            if !self.m.contains_key(name) {
1618                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
1619                self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
1620            }
1621
1622            let m = self.m.get_mut(name).unwrap();
1623            let v = self.v.get_mut(name).unwrap();
1624
1625            // Update first moment: m = beta1 * m + (1 - beta1) * grad
1626            *m = &*m * beta1 + &(grad * (1.0 - beta1));
1627
1628            // Update second moment: v = beta2 * v + (1 - beta2) * grad^2
1629            let grad_squared = grad.mapv(|g| g * g);
1630            *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
1631
1632            // Bias-corrected first moment
1633            let m_hat = &*m / bias_correction1;
1634
1635            if use_adaptive {
1636                // Use adaptive learning rate with rectification
1637                let bias_correction2 = 1.0 - beta2.powi(self.t as i32);
1638                let v_hat = &*v / bias_correction2;
1639
1640                // Update with rectified variance
1641                let update = m_hat / (v_hat.mapv(|val| val.sqrt()) + eps);
1642                *param = &*param - &(update * (lr * rect));
1643            } else {
1644                // Early phase: use non-adaptive update (SGD with momentum)
1645                *param = &*param - &(m_hat * lr);
1646            }
1647        }
1648
1649        Ok(())
1650    }
1651
1652    fn zero_grad(&mut self) {
1653        // Gradients are managed externally
1654    }
1655
1656    fn get_lr(&self) -> f64 {
1657        self.config.learning_rate
1658    }
1659
1660    fn set_lr(&mut self, lr: f64) {
1661        self.config.learning_rate = lr;
1662    }
1663
1664    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1665        let mut state = HashMap::new();
1666        state.insert("t".to_string(), vec![self.t as f64]);
1667
1668        for (name, m_val) in &self.m {
1669            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
1670        }
1671        for (name, v_val) in &self.v {
1672            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
1673        }
1674
1675        state
1676    }
1677
1678    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1679        if let Some(t_val) = state.get("t") {
1680            self.t = t_val[0] as usize;
1681        }
1682
1683        for (key, values) in state {
1684            if let Some(name) = key.strip_prefix("m_") {
1685                if let Some(m_array) = self.m.get(name) {
1686                    let shape = m_array.raw_dim();
1687                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1688                        self.m.insert(name.to_string(), arr);
1689                    }
1690                }
1691            } else if let Some(name) = key.strip_prefix("v_") {
1692                if let Some(v_array) = self.v.get(name) {
1693                    let shape = v_array.raw_dim();
1694                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1695                        self.v.insert(name.to_string(), arr);
1696                    }
1697                }
1698            }
1699        }
1700    }
1701}
1702
1703/// LARS optimizer (Layer-wise Adaptive Rate Scaling).
1704///
1705/// LARS scales the learning rate for each layer based on the ratio of the parameter norm
1706/// to the gradient norm. This is particularly effective for large batch training.
1707///
1708/// Reference: You et al. "Large Batch Training of Convolutional Networks" (2017)
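///
/// For each layer, the effective learning rate is lr * trust_coef * ||param|| / ||grad||
/// (the base lr is used when either norm is zero, or for excluded bias parameters), and
/// the usual momentum update is applied with that layer-wise rate.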
1709#[derive(Debug)]
1710pub struct LarsOptimizer {
1711    config: OptimizerConfig,
1712    /// Momentum buffers for each parameter.
1713    velocity: HashMap<String, Array<f64, Ix2>>,
1714    /// Trust coefficient for layer-wise LR adaptation (typically 0.001).
1715    trust_coef: f64,
1716    /// Whether to exclude bias parameters from LARS adaptation.
1717    exclude_bias: bool,
1718}
1719
1720impl LarsOptimizer {
1721    /// Create a new LARS optimizer.
1722    ///
1723    /// # Arguments
1724    /// * `config` - Optimizer configuration
1725    /// * `trust_coef` - Trust coefficient for adaptive LR (default: 0.001)
1726    /// * `exclude_bias` - Whether to exclude bias from LARS adaptation (default: true)
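    ///
    /// # Examples
    ///
    /// A construction sketch (marked `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let opt = LarsOptimizer::new(OptimizerConfig::default(), 0.001, true);
    /// ```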
1727    pub fn new(config: OptimizerConfig, trust_coef: f64, exclude_bias: bool) -> Self {
1728        Self {
1729            config,
1730            velocity: HashMap::new(),
1731            trust_coef,
1732            exclude_bias,
1733        }
1734    }
1735
1736    /// Apply gradient clipping if configured.
1737    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
1738        if let Some(clip_value) = self.config.grad_clip {
1739            match self.config.grad_clip_mode {
1740                GradClipMode::Value => {
1741                    for grad in gradients.values_mut() {
1742                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
1743                    }
1744                }
1745                GradClipMode::Norm => {
1746                    let total_norm = compute_gradient_norm(gradients);
1747                    if total_norm > clip_value {
1748                        let scale = clip_value / total_norm;
1749                        for grad in gradients.values_mut() {
1750                            grad.mapv_inplace(|g| g * scale);
1751                        }
1752                    }
1753                }
1754            }
1755        }
1756    }
1757
1758    /// Compute layer-wise adaptive learning rate.
1759    fn compute_adaptive_lr(
1760        &self,
1761        param: &Array<f64, Ix2>,
1762        grad: &Array<f64, Ix2>,
1763        name: &str,
1764    ) -> f64 {
1765        // Skip LARS adaptation for bias parameters (names containing "bias" or exactly "b") if configured
1766        if self.exclude_bias && (name.contains("bias") || name == "b") {
1767            return self.config.learning_rate;
1768        }
1769
1770        // Compute parameter norm
1771        let param_norm: f64 = param.iter().map(|&p| p * p).sum::<f64>().sqrt();
1772
1773        // Compute gradient norm
1774        let grad_norm: f64 = grad.iter().map(|&g| g * g).sum::<f64>().sqrt();
1775
1776        // Avoid division by zero
1777        if param_norm == 0.0 || grad_norm == 0.0 {
1778            return self.config.learning_rate;
1779        }
1780
1781        // Compute layer-wise LR: trust_coef * ||param|| / ||grad||
1782        let local_lr = self.trust_coef * param_norm / grad_norm;
1783
1784        // Return base LR * local LR
1785        self.config.learning_rate * local_lr
1786    }
1787}
1788
1789impl Optimizer for LarsOptimizer {
1790    fn step(
1791        &mut self,
1792        parameters: &mut HashMap<String, Array<f64, Ix2>>,
1793        gradients: &HashMap<String, Array<f64, Ix2>>,
1794    ) -> TrainResult<()> {
1795        let mut clipped_gradients = gradients.clone();
1796        self.clip_gradients(&mut clipped_gradients);
1797
1798        for (name, param) in parameters.iter_mut() {
1799            let grad = clipped_gradients.get(name).ok_or_else(|| {
1800                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1801            })?;
1802
1803            // Compute layer-wise adaptive learning rate (before borrowing velocity)
1804            let adaptive_lr = self.compute_adaptive_lr(param, grad, name);
1805
1806            // Weight decay
1807            let mut effective_grad = grad.clone();
1808            if self.config.weight_decay > 0.0 {
1809                effective_grad += &(&*param * self.config.weight_decay);
1810            }
1811
1812            // Initialize velocity if not present
1813            if !self.velocity.contains_key(name) {
1814                self.velocity
1815                    .insert(name.clone(), Array::zeros(param.raw_dim()));
1816            }
1817
1818            let velocity = self.velocity.get_mut(name).unwrap();
1819
1820            // Update velocity with LARS-adapted LR: v = momentum * v + adaptive_lr * grad
1821            velocity.mapv_inplace(|v| self.config.momentum * v);
1822            *velocity = &*velocity + &(effective_grad * adaptive_lr);
1823
1824            // Update parameter: param = param - velocity
1825            *param = &*param - &*velocity;
1826        }
1827
1828        Ok(())
1829    }
1830
1831    fn zero_grad(&mut self) {
1832        // Gradients are managed externally
1833    }
1834
1835    fn get_lr(&self) -> f64 {
1836        self.config.learning_rate
1837    }
1838
1839    fn set_lr(&mut self, lr: f64) {
1840        self.config.learning_rate = lr;
1841    }
1842
1843    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1844        let mut state = HashMap::new();
1845        state.insert("trust_coef".to_string(), vec![self.trust_coef]);
1846        state.insert(
1847            "exclude_bias".to_string(),
1848            vec![if self.exclude_bias { 1.0 } else { 0.0 }],
1849        );
1850
1851        for (name, velocity) in &self.velocity {
1852            state.insert(
1853                format!("velocity_{}", name),
1854                velocity.iter().copied().collect(),
1855            );
1856        }
1857
1858        state
1859    }
1860
1861    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1862        if let Some(trust) = state.get("trust_coef") {
1863            self.trust_coef = trust[0];
1864        }
1865        if let Some(exclude) = state.get("exclude_bias") {
1866            self.exclude_bias = exclude[0] > 0.5;
1867        }
1868
1869        for (key, values) in state {
1870            if let Some(name) = key.strip_prefix("velocity_") {
1871                if let Some(velocity) = self.velocity.get(name) {
1872                    let shape = velocity.raw_dim();
1873                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
1874                        self.velocity.insert(name.to_string(), arr);
1875                    }
1876                }
1877            }
1878        }
1879    }
1880}
1881
1882/// SAM optimizer (Sharpness Aware Minimization).
1883///
1884/// SAM seeks parameters that lie in neighborhoods having uniformly low loss,
1885/// improving model generalization. It requires two forward-backward passes per step:
1886/// one to compute the adversarial perturbation, and one to compute the gradient at the perturbed point.
1887///
1888/// Reference: Foret et al. "Sharpness-Aware Minimization for Efficiently Improving Generalization" (ICLR 2021)
1889///
1890/// Note: This is a wrapper optimizer. SAM requires special handling in the training loop
1891/// to perform two gradient computations per step. The typical usage (see the sketch below) is:
1892/// 1. Compute gradients at the current parameters
1893/// 2. Call `first_step` to apply the adversarial perturbation
1894/// 3. Recompute gradients at the perturbed parameters
1895/// 4. Call `second_step` to undo the perturbation and update with those gradients
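///
/// A minimal two-pass sketch (marked `ignore` so it is not compiled as a doctest;
/// `compute_gradients` stands in for whatever produces gradients in the surrounding
/// training loop, and `params` is assumed to already hold the model parameters):
///
/// ```ignore
/// let mut sam = SamOptimizer::new(SgdOptimizer::new(OptimizerConfig::default()), 0.05)?;
///
/// // Pass 1: gradients at the current parameters define the perturbation.
/// let grads = compute_gradients(&params);
/// sam.first_step(&mut params, &grads)?;
///
/// // Pass 2: gradients at the perturbed parameters drive the actual update.
/// let grads_at_perturbed = compute_gradients(&params);
/// sam.second_step(&mut params, &grads_at_perturbed)?;
/// ```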
1896#[derive(Debug)]
1897pub struct SamOptimizer<O: Optimizer> {
1898    /// Base optimizer (e.g., SGD, Adam).
1899    base_optimizer: O,
1900    /// Perturbation radius (rho).
1901    rho: f64,
1902    /// Stored perturbations for each parameter.
1903    perturbations: HashMap<String, Array<f64, Ix2>>,
1904}
1905
1906impl<O: Optimizer> SamOptimizer<O> {
1907    /// Create a new SAM optimizer.
1908    ///
1909    /// # Arguments
1910    /// * `base_optimizer` - The base optimizer to use (SGD, Adam, etc.)
1911    /// * `rho` - Perturbation radius (typically 0.05)
1912    pub fn new(base_optimizer: O, rho: f64) -> TrainResult<Self> {
1913        if rho <= 0.0 {
1914            return Err(TrainError::OptimizerError(
1915                "SAM rho must be positive".to_string(),
1916            ));
1917        }
1918
1919        Ok(Self {
1920            base_optimizer,
1921            rho,
1922            perturbations: HashMap::new(),
1923        })
1924    }
1925
1926    /// Compute adversarial perturbations.
1927    ///
1928    /// This should be called with the first set of gradients to compute
1929    /// the perturbation direction.
1930    pub fn first_step(
1931        &mut self,
1932        parameters: &mut HashMap<String, Array<f64, Ix2>>,
1933        gradients: &HashMap<String, Array<f64, Ix2>>,
1934    ) -> TrainResult<()> {
1935        // Compute gradient norm
1936        let grad_norm = compute_gradient_norm(gradients);
1937
1938        if grad_norm == 0.0 {
1939            return Ok(());
1940        }
1941
1942        // Compute and apply perturbations: e = rho * grad / ||grad||
1943        for (name, param) in parameters.iter_mut() {
1944            let grad = gradients.get(name).ok_or_else(|| {
1945                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1946            })?;
1947
1948            // Compute perturbation
1949            let perturbation = grad.mapv(|g| self.rho * g / grad_norm);
1950
1951            // Apply perturbation: param = param + e
1952            *param = &*param + &perturbation;
1953
1954            // Store perturbation for later removal
1955            self.perturbations.insert(name.clone(), perturbation);
1956        }
1957
1958        Ok(())
1959    }
1960
1961    /// Perform the actual optimization step.
1962    ///
1963    /// This should be called with the second set of gradients (computed at the perturbed parameters).
1964    /// It will remove the perturbations and update the parameters using the base optimizer.
1965    pub fn second_step(
1966        &mut self,
1967        parameters: &mut HashMap<String, Array<f64, Ix2>>,
1968        gradients: &HashMap<String, Array<f64, Ix2>>,
1969    ) -> TrainResult<()> {
1970        // Remove perturbations first: param = param - e
1971        for (name, param) in parameters.iter_mut() {
1972            if let Some(perturbation) = self.perturbations.get(name) {
1973                *param = &*param - perturbation;
1974            }
1975        }
1976
1977        // Clear perturbations
1978        self.perturbations.clear();
1979
1980        // Perform base optimizer step with the gradients at perturbed point
1981        self.base_optimizer.step(parameters, gradients)
1982    }
1983}
1984
1985impl<O: Optimizer> Optimizer for SamOptimizer<O> {
1986    fn step(
1987        &mut self,
1988        parameters: &mut HashMap<String, Array<f64, Ix2>>,
1989        gradients: &HashMap<String, Array<f64, Ix2>>,
1990    ) -> TrainResult<()> {
1991        // For the trait implementation, we just do the second step
1992        // In practice, users should call first_step() and second_step() explicitly
1993        self.second_step(parameters, gradients)
1994    }
1995
1996    fn zero_grad(&mut self) {
1997        self.base_optimizer.zero_grad();
1998    }
1999
2000    fn get_lr(&self) -> f64 {
2001        self.base_optimizer.get_lr()
2002    }
2003
2004    fn set_lr(&mut self, lr: f64) {
2005        self.base_optimizer.set_lr(lr);
2006    }
2007
2008    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
2009        let mut state = self.base_optimizer.state_dict();
2010        state.insert("rho".to_string(), vec![self.rho]);
2011
2012        for (name, perturbation) in &self.perturbations {
2013            state.insert(
2014                format!("perturbation_{}", name),
2015                perturbation.iter().copied().collect(),
2016            );
2017        }
2018
2019        state
2020    }
2021
2022    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
2023        if let Some(rho_val) = state.get("rho") {
2024            self.rho = rho_val[0];
2025        }
2026
2027        // Load base optimizer state
2028        self.base_optimizer.load_state_dict(state.clone());
2029
2030        // Load perturbations
2031        for (key, values) in state {
2032            if let Some(name) = key.strip_prefix("perturbation_") {
2033                if let Some(pert) = self.perturbations.get(name) {
2034                    let shape = pert.raw_dim();
2035                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
2036                        self.perturbations.insert(name.to_string(), arr);
2037                    }
2038                }
2039            }
2040        }
2041    }
2042}
2043
2044#[cfg(test)]
2045mod tests {
2046    use super::*;
2047    use scirs2_core::ndarray::array;
2048
2049    #[test]
2050    fn test_sgd_optimizer() {
2051        let config = OptimizerConfig {
2052            learning_rate: 0.1,
2053            momentum: 0.9,
2054            ..Default::default()
2055        };
2056        let mut optimizer = SgdOptimizer::new(config);
2057
2058        let mut params = HashMap::new();
2059        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2060
2061        let mut grads = HashMap::new();
2062        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);
2063
2064        optimizer.step(&mut params, &grads).unwrap();
2065
2066        let w = params.get("w").unwrap();
2067        assert!(w[[0, 0]] < 1.0);
2068        assert!(w[[0, 1]] < 2.0);
2069    }
2070
2071    #[test]
2072    fn test_adam_optimizer() {
2073        let config = OptimizerConfig {
2074            learning_rate: 0.001,
2075            ..Default::default()
2076        };
2077        let mut optimizer = AdamOptimizer::new(config);
2078
2079        let mut params = HashMap::new();
2080        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2081
2082        let mut grads = HashMap::new();
2083        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);
2084
2085        optimizer.step(&mut params, &grads).unwrap();
2086
2087        let w = params.get("w").unwrap();
2088        assert!(w[[0, 0]] < 1.0);
2089    }
2090
2091    #[test]
2092    fn test_adamw_optimizer() {
2093        let config = OptimizerConfig {
2094            learning_rate: 0.001,
2095            weight_decay: 0.01,
2096            ..Default::default()
2097        };
2098        let mut optimizer = AdamWOptimizer::new(config);
2099
2100        let mut params = HashMap::new();
2101        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2102
2103        let mut grads = HashMap::new();
2104        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);
2105
2106        optimizer.step(&mut params, &grads).unwrap();
2107
2108        let w = params.get("w").unwrap();
2109        assert!(w[[0, 0]] < 1.0);
2110    }
2111
2112    #[test]
2113    fn test_gradient_clipping() {
2114        let config = OptimizerConfig {
2115            learning_rate: 0.1,
2116            grad_clip: Some(0.05),
2117            ..Default::default()
2118        };
2119        let mut optimizer = SgdOptimizer::new(config);
2120
2121        let mut params = HashMap::new();
2122        params.insert("w".to_string(), array![[1.0]]);
2123
2124        let mut grads = HashMap::new();
2125        grads.insert("w".to_string(), array![[10.0]]); // Large gradient
2126
2127        optimizer.step(&mut params, &grads).unwrap();
2128
2129        let w = params.get("w").unwrap();
2130        // With clipping, gradient should be capped at 0.05
2131        assert!((w[[0, 0]] - (1.0 - 0.1 * 0.05)).abs() < 1e-6);
2132    }
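
    // A companion check for GradClipMode::Norm, added as a sketch: with a single scalar
    // gradient the global L2 norm equals its absolute value, so clipping to 0.05 should
    // yield the same first-step update as the value-mode test above (this assumes SGD
    // applies the clipped gradient directly on the first step, as that test already does).
    #[test]
    fn test_gradient_clipping_norm_mode() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            grad_clip: Some(0.05),
            grad_clip_mode: GradClipMode::Norm,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[10.0]]); // Large gradient

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        // Norm-clipped gradient: 10.0 * (0.05 / 10.0) = 0.05
        assert!((w[[0, 0]] - (1.0 - 0.1 * 0.05)).abs() < 1e-6);
    }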
2133
2134    #[test]
2135    fn test_rmsprop_optimizer() {
2136        let config = OptimizerConfig {
2137            learning_rate: 0.01,
2138            ..Default::default()
2139        };
2140        let mut optimizer = RMSpropOptimizer::new(config);
2141
2142        let mut params = HashMap::new();
2143        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2144
2145        let mut grads = HashMap::new();
2146        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);
2147
2148        optimizer.step(&mut params, &grads).unwrap();
2149
2150        let w = params.get("w").unwrap();
2151        assert!(w[[0, 0]] < 1.0);
2152    }
2153
2154    #[test]
2155    fn test_adagrad_optimizer() {
2156        let config = OptimizerConfig {
2157            learning_rate: 0.1,
2158            ..Default::default()
2159        };
2160        let mut optimizer = AdagradOptimizer::new(config);
2161
2162        let mut params = HashMap::new();
2163        params.insert("w".to_string(), array![[1.0, 2.0]]);
2164
2165        let mut grads = HashMap::new();
2166        grads.insert("w".to_string(), array![[0.1, 0.2]]);
2167
2168        optimizer.step(&mut params, &grads).unwrap();
2169
2170        let w = params.get("w").unwrap();
2171        assert!(w[[0, 0]] < 1.0);
2172        assert!(w[[0, 1]] < 2.0);
2173    }
2174
2175    #[test]
2176    fn test_nadam_optimizer() {
2177        let config = OptimizerConfig {
2178            learning_rate: 0.002,
2179            ..Default::default()
2180        };
2181        let mut optimizer = NAdamOptimizer::new(config);
2182
2183        let mut params = HashMap::new();
2184        params.insert("w".to_string(), array![[1.0, 2.0]]);
2185
2186        let mut grads = HashMap::new();
2187        grads.insert("w".to_string(), array![[0.1, 0.1]]);
2188
2189        optimizer.step(&mut params, &grads).unwrap();
2190
2191        let w = params.get("w").unwrap();
2192        assert!(w[[0, 0]] < 1.0);
2193    }
2194
2195    #[test]
2196    fn test_lamb_optimizer() {
2197        let config = OptimizerConfig {
2198            learning_rate: 0.001,
2199            weight_decay: 0.01,
2200            ..Default::default()
2201        };
2202        let mut optimizer = LambOptimizer::new(config);
2203
2204        let mut params = HashMap::new();
2205        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2206
2207        let mut grads = HashMap::new();
2208        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);
2209
2210        optimizer.step(&mut params, &grads).unwrap();
2211
2212        let w = params.get("w").unwrap();
2213        assert!(w[[0, 0]] < 1.0);
2214    }
2215
2216    #[test]
2217    fn test_adamax_optimizer() {
2218        let config = OptimizerConfig {
2219            learning_rate: 0.002,
2220            ..Default::default()
2221        };
2222        let mut optimizer = AdaMaxOptimizer::new(config);
2223
2224        let mut params = HashMap::new();
2225        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2226
2227        let mut grads = HashMap::new();
2228        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);
2229
2230        // Perform multiple steps to test infinity norm tracking
2231        for _ in 0..3 {
2232            optimizer.step(&mut params, &grads).unwrap();
2233        }
2234
2235        let w = params.get("w").unwrap();
2236        // Parameters should decrease
2237        assert!(w[[0, 0]] < 1.0);
2238        assert!(w[[0, 1]] < 2.0);
2239        assert!(w[[1, 0]] < 3.0);
2240        assert!(w[[1, 1]] < 4.0);
2241
2242        // Test state dict
2243        let state = optimizer.state_dict();
2244        assert!(state.contains_key("t"));
2245        assert!(state.contains_key("m_w"));
2246        assert!(state.contains_key("u_w"));
2247    }
2248
2249    #[test]
2250    fn test_lookahead_optimizer() {
2251        let inner_config = OptimizerConfig {
2252            learning_rate: 0.01,
2253            ..Default::default()
2254        };
2255        let inner_optimizer = AdamOptimizer::new(inner_config);
2256
2257        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 5).unwrap();
2258
2259        let mut params = HashMap::new();
2260        params.insert("w".to_string(), array![[1.0, 2.0]]);
2261
2262        let mut grads = HashMap::new();
2263        grads.insert("w".to_string(), array![[0.1, 0.1]]);
2264
2265        // Step several times
2266        for _ in 0..10 {
2267            optimizer.step(&mut params, &grads).unwrap();
2268        }
2269
2270        let w = params.get("w").unwrap();
2271        // Parameters should decrease
2272        assert!(w[[0, 0]] < 1.0);
2273        assert!(w[[0, 1]] < 2.0);
2274
2275        // Test learning rate access
2276        assert_eq!(optimizer.get_lr(), 0.01);
2277
2278        optimizer.set_lr(0.02);
2279        assert_eq!(optimizer.get_lr(), 0.02);
2280    }
2281
2282    #[test]
2283    fn test_lookahead_invalid_alpha() {
2284        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
2285
2286        let result = LookaheadOptimizer::new(inner_optimizer, 1.5, 5);
2287        assert!(result.is_err());
2288
2289        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
2290        let result = LookaheadOptimizer::new(inner_optimizer, -0.1, 5);
2291        assert!(result.is_err());
2292    }
2293
2294    #[test]
2295    fn test_lookahead_invalid_k() {
2296        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
2297
2298        let result = LookaheadOptimizer::new(inner_optimizer, 0.5, 0);
2299        assert!(result.is_err());
2300    }
2301
2302    #[test]
2303    fn test_lookahead_synchronization() {
2304        let inner_config = OptimizerConfig {
2305            learning_rate: 0.1,
2306            ..Default::default()
2307        };
2308        let inner_optimizer = SgdOptimizer::new(inner_config);
2309
2310        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 3).unwrap();
2311
2312        let mut params = HashMap::new();
2313        params.insert("w".to_string(), array![[1.0]]);
2314
2315        let mut grads = HashMap::new();
2316        grads.insert("w".to_string(), array![[0.1]]);
2317
2318        let initial_w = params.get("w").unwrap()[[0, 0]];
2319
2320        // Step 3 times to trigger synchronization
2321        for _ in 0..3 {
2322            optimizer.step(&mut params, &grads).unwrap();
2323        }
2324
2325        let w_after_sync = params.get("w").unwrap()[[0, 0]];
2326
2327        // Parameters should have changed after synchronization
2328        assert_ne!(w_after_sync, initial_w);
2329        assert!(w_after_sync < initial_w);
2330    }
2331
2332    #[test]
2333    fn test_adabelief_optimizer() {
2334        let config = OptimizerConfig {
2335            learning_rate: 0.001,
2336            weight_decay: 0.01,
2337            ..Default::default()
2338        };
2339        let mut optimizer = AdaBeliefOptimizer::new(config);
2340
2341        let mut params = HashMap::new();
2342        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2343
2344        let mut grads = HashMap::new();
2345        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);
2346
2347        // Perform multiple steps
2348        for _ in 0..5 {
2349            optimizer.step(&mut params, &grads).unwrap();
2350        }
2351
2352        let w = params.get("w").unwrap();
2353        // Parameters should decrease
2354        assert!(w[[0, 0]] < 1.0);
2355        assert!(w[[1, 1]] < 4.0);
2356
2357        // Test state dict
2358        let state = optimizer.state_dict();
2359        assert!(state.contains_key("t"));
2360        assert!(state.contains_key("m_w"));
2361        assert!(state.contains_key("s_w"));
2362    }
2363
2364    #[test]
2365    fn test_radam_optimizer() {
2366        let config = OptimizerConfig {
2367            learning_rate: 0.001,
2368            ..Default::default()
2369        };
2370        let mut optimizer = RAdamOptimizer::new(config);
2371
2372        let mut params = HashMap::new();
2373        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2374
2375        let mut grads = HashMap::new();
2376        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);
2377
2378        // Perform multiple steps (RAdam needs warmup)
2379        for _ in 0..10 {
2380            optimizer.step(&mut params, &grads).unwrap();
2381        }
2382
2383        let w = params.get("w").unwrap();
2384        // Parameters should decrease
2385        assert!(w[[0, 0]] < 1.0);
2386        assert!(w[[0, 1]] < 2.0);
2387
2388        // Test state dict
2389        let state = optimizer.state_dict();
2390        assert!(state.contains_key("t"));
2391        assert!(state.contains_key("m_w"));
2392        assert!(state.contains_key("v_w"));
2393    }
2394
2395    #[test]
2396    fn test_lars_optimizer() {
2397        let config = OptimizerConfig {
2398            learning_rate: 0.1,
2399            momentum: 0.9,
2400            weight_decay: 0.0001,
2401            ..Default::default()
2402        };
2403        let mut optimizer = LarsOptimizer::new(config, 0.001, true);
2404
2405        let mut params = HashMap::new();
2406        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
2407
2408        let mut grads = HashMap::new();
2409        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);
2410
2411        optimizer.step(&mut params, &grads).unwrap();
2412
2413        let w = params.get("w").unwrap();
2414        // Parameters should decrease
2415        assert!(w[[0, 0]] < 1.0);
2416        assert!(w[[1, 1]] < 4.0);
2417
2418        // Test state dict
2419        let state = optimizer.state_dict();
2420        assert!(state.contains_key("trust_coef"));
2421        assert!(state.contains_key("exclude_bias"));
2422        assert!(state.contains_key("velocity_w"));
2423    }
2424
2425    #[test]
2426    fn test_lars_bias_exclusion() {
2427        let config = OptimizerConfig {
2428            learning_rate: 0.1,
2429            momentum: 0.9,
2430            ..Default::default()
2431        };
2432
2433        // Test with exclude_bias = true
2434        let mut optimizer = LarsOptimizer::new(config.clone(), 0.001, true);
2435
2436        let mut params = HashMap::new();
2437        params.insert("weights".to_string(), array![[1.0, 2.0]]);
2438        params.insert("bias".to_string(), array![[1.0, 2.0]]);
2439
2440        let mut grads = HashMap::new();
2441        grads.insert("weights".to_string(), array![[0.1, 0.1]]);
2442        grads.insert("bias".to_string(), array![[0.1, 0.1]]);
2443
2444        optimizer.step(&mut params, &grads).unwrap();
2445
2446        // Both should decrease, but bias should use base LR
2447        let weights = params.get("weights").unwrap();
2448        let bias = params.get("bias").unwrap();
2449        assert!(weights[[0, 0]] < 1.0);
2450        assert!(bias[[0, 0]] < 1.0);
2451    }
2452
2453    #[test]
2454    fn test_sam_optimizer() {
2455        let inner_config = OptimizerConfig {
2456            learning_rate: 0.01,
2457            ..Default::default()
2458        };
2459        let inner_optimizer = SgdOptimizer::new(inner_config);
2460
2461        let mut optimizer = SamOptimizer::new(inner_optimizer, 0.05).unwrap();
2462
2463        let mut params = HashMap::new();
2464        params.insert("w".to_string(), array![[1.0, 2.0]]);
2465
2466        let mut grads = HashMap::new();
2467        grads.insert("w".to_string(), array![[0.1, 0.1]]);
2468
2469        // First step: compute perturbation
2470        let original_w = params.get("w").unwrap().clone();
2471        optimizer.first_step(&mut params, &grads).unwrap();
2472
2473        // Parameters should be perturbed
2474        let perturbed_w = params.get("w").unwrap();
2475        assert_ne!(perturbed_w[[0, 0]], original_w[[0, 0]]);
2476
2477        // Second step: update with gradients at perturbed point
2478        optimizer.second_step(&mut params, &grads).unwrap();
2479
2480        // Parameters should be updated from original position
2481        let final_w = params.get("w").unwrap();
2482        assert!(final_w[[0, 0]] < original_w[[0, 0]]);
2483
2484        // Test state dict
2485        let state = optimizer.state_dict();
2486        assert!(state.contains_key("rho"));
2487    }
2488
2489    #[test]
2490    fn test_sam_invalid_rho() {
2491        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
2492
2493        let result = SamOptimizer::new(inner_optimizer, 0.0);
2494        assert!(result.is_err());
2495
2496        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
2497        let result = SamOptimizer::new(inner_optimizer, -0.1);
2498        assert!(result.is_err());
2499    }
2500}