1use crate::{TrainError, TrainResult};
4use scirs2_core::ndarray::{Array, Ix2};
5use std::collections::HashMap;
6
7fn compute_gradient_norm(gradients: &HashMap<String, Array<f64, Ix2>>) -> f64 {
15 let mut total_norm_sq = 0.0;
16
17 for grad in gradients.values() {
18 for &g in grad.iter() {
19 total_norm_sq += g * g;
20 }
21 }
22
23 total_norm_sq.sqrt()
24}
25
/// Strategy applied when gradient clipping is enabled via
/// `OptimizerConfig::grad_clip`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GradClipMode {
    /// Clamp each gradient element to `[-clip, clip]`.
    Value,
    /// Rescale all gradients jointly so their global L2 norm is at most `clip`.
    Norm,
}
34
/// Hyper-parameters shared by every optimizer in this module.
///
/// Not every field is read by every optimizer: `momentum` is used by SGD,
/// `beta1`/`beta2` by the Adam family, `weight_decay` by the decoupled-decay
/// optimizers (AdamW, LAMB, AdaBelief), and RMSprop reuses `beta2` as its
/// smoothing constant.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Base step size.
    pub learning_rate: f64,
    /// Momentum coefficient (SGD).
    pub momentum: f64,
    /// Exponential decay rate for first-moment estimates (Adam family).
    pub beta1: f64,
    /// Exponential decay rate for second-moment estimates (Adam family).
    pub beta2: f64,
    /// Small constant added to denominators for numerical stability.
    pub epsilon: f64,
    /// Weight-decay coefficient (AdamW, LAMB, AdaBelief).
    pub weight_decay: f64,
    /// Optional gradient-clipping threshold; `None` disables clipping.
    pub grad_clip: Option<f64>,
    /// How `grad_clip` is applied (per-element value or global norm).
    pub grad_clip_mode: GradClipMode,
}

impl Default for OptimizerConfig {
    /// Adam-style defaults: lr 1e-3, betas (0.9, 0.999), eps 1e-8,
    /// weight decay 0.01, clipping disabled.
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            momentum: 0.9,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            grad_clip: None,
            grad_clip_mode: GradClipMode::Value,
        }
    }
}
70
/// Common interface implemented by all optimizers in this module.
pub trait Optimizer {
    /// Applies one update step to `parameters` using `gradients`.
    ///
    /// Returns a `TrainError::OptimizerError` if a parameter has no matching
    /// gradient entry.
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()>;

    /// Resets any gradient state held by the optimizer. The implementations
    /// in this module hold none, so for them this is a no-op.
    fn zero_grad(&mut self);

    /// Returns the current learning rate.
    fn get_lr(&self) -> f64;

    /// Replaces the learning rate (e.g. from an external scheduler).
    fn set_lr(&mut self, lr: f64);

    /// Serializes internal state (step counters, moment buffers) into flat
    /// `f64` vectors keyed by name.
    fn state_dict(&self) -> HashMap<String, Vec<f64>>;

    /// Restores state previously produced by [`Optimizer::state_dict`].
    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>);
}
95
96#[derive(Debug)]
98pub struct SgdOptimizer {
99 config: OptimizerConfig,
100 velocity: HashMap<String, Array<f64, Ix2>>,
102}
103
104impl SgdOptimizer {
105 pub fn new(config: OptimizerConfig) -> Self {
107 Self {
108 config,
109 velocity: HashMap::new(),
110 }
111 }
112
113 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
115 if let Some(clip_value) = self.config.grad_clip {
116 match self.config.grad_clip_mode {
117 GradClipMode::Value => {
118 for grad in gradients.values_mut() {
120 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
121 }
122 }
123 GradClipMode::Norm => {
124 let total_norm = compute_gradient_norm(gradients);
126
127 if total_norm > clip_value {
128 let scale = clip_value / total_norm;
129 for grad in gradients.values_mut() {
130 grad.mapv_inplace(|g| g * scale);
131 }
132 }
133 }
134 }
135 }
136 }
137}
138
139impl Optimizer for SgdOptimizer {
140 fn step(
141 &mut self,
142 parameters: &mut HashMap<String, Array<f64, Ix2>>,
143 gradients: &HashMap<String, Array<f64, Ix2>>,
144 ) -> TrainResult<()> {
145 let mut clipped_gradients = gradients.clone();
146 self.clip_gradients(&mut clipped_gradients);
147
148 for (name, param) in parameters.iter_mut() {
149 let grad = clipped_gradients.get(name).ok_or_else(|| {
150 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
151 })?;
152
153 if !self.velocity.contains_key(name) {
155 self.velocity
156 .insert(name.clone(), Array::zeros(param.raw_dim()));
157 }
158
159 let velocity = self.velocity.get_mut(name).unwrap();
160
161 velocity.mapv_inplace(|v| self.config.momentum * v);
163 *velocity = &*velocity + &(grad * self.config.learning_rate);
164
165 *param = &*param - &*velocity;
167 }
168
169 Ok(())
170 }
171
172 fn zero_grad(&mut self) {
173 }
175
176 fn get_lr(&self) -> f64 {
177 self.config.learning_rate
178 }
179
180 fn set_lr(&mut self, lr: f64) {
181 self.config.learning_rate = lr;
182 }
183
184 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
185 let mut state = HashMap::new();
186 for (name, velocity) in &self.velocity {
187 state.insert(
188 format!("velocity_{}", name),
189 velocity.iter().copied().collect(),
190 );
191 }
192 state
193 }
194
195 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
196 for (key, values) in state {
197 if let Some(name) = key.strip_prefix("velocity_") {
198 if let Some(velocity) = self.velocity.get(name) {
200 let shape = velocity.raw_dim();
201 if let Ok(arr) = Array::from_shape_vec(shape, values) {
202 self.velocity.insert(name.to_string(), arr);
203 }
204 }
205 }
206 }
207 }
208}
209
210#[derive(Debug)]
212pub struct AdamOptimizer {
213 config: OptimizerConfig,
214 m: HashMap<String, Array<f64, Ix2>>,
216 v: HashMap<String, Array<f64, Ix2>>,
218 t: usize,
220}
221
222impl AdamOptimizer {
223 pub fn new(config: OptimizerConfig) -> Self {
225 Self {
226 config,
227 m: HashMap::new(),
228 v: HashMap::new(),
229 t: 0,
230 }
231 }
232
233 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
235 if let Some(clip_value) = self.config.grad_clip {
236 match self.config.grad_clip_mode {
237 GradClipMode::Value => {
238 for grad in gradients.values_mut() {
240 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
241 }
242 }
243 GradClipMode::Norm => {
244 let total_norm = compute_gradient_norm(gradients);
246
247 if total_norm > clip_value {
248 let scale = clip_value / total_norm;
249 for grad in gradients.values_mut() {
250 grad.mapv_inplace(|g| g * scale);
251 }
252 }
253 }
254 }
255 }
256 }
257}
258
259impl Optimizer for AdamOptimizer {
260 fn step(
261 &mut self,
262 parameters: &mut HashMap<String, Array<f64, Ix2>>,
263 gradients: &HashMap<String, Array<f64, Ix2>>,
264 ) -> TrainResult<()> {
265 let mut clipped_gradients = gradients.clone();
266 self.clip_gradients(&mut clipped_gradients);
267
268 self.t += 1;
269 let lr = self.config.learning_rate;
270 let beta1 = self.config.beta1;
271 let beta2 = self.config.beta2;
272 let eps = self.config.epsilon;
273
274 let lr_t =
276 lr * ((1.0 - beta2.powi(self.t as i32)).sqrt()) / (1.0 - beta1.powi(self.t as i32));
277
278 for (name, param) in parameters.iter_mut() {
279 let grad = clipped_gradients.get(name).ok_or_else(|| {
280 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
281 })?;
282
283 if !self.m.contains_key(name) {
285 self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
286 self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
287 }
288
289 let m = self.m.get_mut(name).unwrap();
290 let v = self.v.get_mut(name).unwrap();
291
292 *m = &*m * beta1 + &(grad * (1.0 - beta1));
294
295 let grad_squared = grad.mapv(|g| g * g);
297 *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
298
299 let update = m.mapv(|m_val| m_val * lr_t) / &v.mapv(|v_val| v_val.sqrt() + eps);
301 *param = &*param - &update;
302 }
303
304 Ok(())
305 }
306
307 fn zero_grad(&mut self) {
308 }
310
311 fn get_lr(&self) -> f64 {
312 self.config.learning_rate
313 }
314
315 fn set_lr(&mut self, lr: f64) {
316 self.config.learning_rate = lr;
317 }
318
319 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
320 let mut state = HashMap::new();
321 state.insert("t".to_string(), vec![self.t as f64]);
322
323 for (name, m_val) in &self.m {
324 state.insert(format!("m_{}", name), m_val.iter().copied().collect());
325 }
326 for (name, v_val) in &self.v {
327 state.insert(format!("v_{}", name), v_val.iter().copied().collect());
328 }
329 state
330 }
331
332 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
333 if let Some(t_vals) = state.get("t") {
334 self.t = t_vals[0] as usize;
335 }
336
337 for (key, values) in state {
338 if let Some(name) = key.strip_prefix("m_") {
339 if let Some(m) = self.m.get(name) {
340 let shape = m.raw_dim();
341 if let Ok(arr) = Array::from_shape_vec(shape, values) {
342 self.m.insert(name.to_string(), arr);
343 }
344 }
345 } else if let Some(name) = key.strip_prefix("v_") {
346 if let Some(v) = self.v.get(name) {
347 let shape = v.raw_dim();
348 if let Ok(arr) = Array::from_shape_vec(shape, values) {
349 self.v.insert(name.to_string(), arr);
350 }
351 }
352 }
353 }
354 }
355}
356
357#[derive(Debug)]
359pub struct AdamWOptimizer {
360 config: OptimizerConfig,
361 m: HashMap<String, Array<f64, Ix2>>,
363 v: HashMap<String, Array<f64, Ix2>>,
365 t: usize,
367}
368
369impl AdamWOptimizer {
370 pub fn new(config: OptimizerConfig) -> Self {
372 Self {
373 config,
374 m: HashMap::new(),
375 v: HashMap::new(),
376 t: 0,
377 }
378 }
379
380 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
382 if let Some(clip_value) = self.config.grad_clip {
383 match self.config.grad_clip_mode {
384 GradClipMode::Value => {
385 for grad in gradients.values_mut() {
387 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
388 }
389 }
390 GradClipMode::Norm => {
391 let total_norm = compute_gradient_norm(gradients);
393
394 if total_norm > clip_value {
395 let scale = clip_value / total_norm;
396 for grad in gradients.values_mut() {
397 grad.mapv_inplace(|g| g * scale);
398 }
399 }
400 }
401 }
402 }
403 }
404}
405
406impl Optimizer for AdamWOptimizer {
407 fn step(
408 &mut self,
409 parameters: &mut HashMap<String, Array<f64, Ix2>>,
410 gradients: &HashMap<String, Array<f64, Ix2>>,
411 ) -> TrainResult<()> {
412 let mut clipped_gradients = gradients.clone();
413 self.clip_gradients(&mut clipped_gradients);
414
415 self.t += 1;
416 let lr = self.config.learning_rate;
417 let beta1 = self.config.beta1;
418 let beta2 = self.config.beta2;
419 let eps = self.config.epsilon;
420 let weight_decay = self.config.weight_decay;
421
422 let lr_t =
424 lr * ((1.0 - beta2.powi(self.t as i32)).sqrt()) / (1.0 - beta1.powi(self.t as i32));
425
426 for (name, param) in parameters.iter_mut() {
427 let grad = clipped_gradients.get(name).ok_or_else(|| {
428 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
429 })?;
430
431 if !self.m.contains_key(name) {
433 self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
434 self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
435 }
436
437 let m = self.m.get_mut(name).unwrap();
438 let v = self.v.get_mut(name).unwrap();
439
440 *m = &*m * beta1 + &(grad * (1.0 - beta1));
442
443 let grad_squared = grad.mapv(|g| g * g);
445 *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
446
447 let update = m.mapv(|m_val| m_val * lr_t) / &v.mapv(|v_val| v_val.sqrt() + eps);
449
450 let decay = param.mapv(|p| p * lr * weight_decay);
452
453 *param = &*param - &update - &decay;
455 }
456
457 Ok(())
458 }
459
460 fn zero_grad(&mut self) {
461 }
463
464 fn get_lr(&self) -> f64 {
465 self.config.learning_rate
466 }
467
468 fn set_lr(&mut self, lr: f64) {
469 self.config.learning_rate = lr;
470 }
471
472 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
473 let mut state = HashMap::new();
474 state.insert("t".to_string(), vec![self.t as f64]);
475
476 for (name, m_val) in &self.m {
477 state.insert(format!("m_{}", name), m_val.iter().copied().collect());
478 }
479 for (name, v_val) in &self.v {
480 state.insert(format!("v_{}", name), v_val.iter().copied().collect());
481 }
482 state
483 }
484
485 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
486 if let Some(t_vals) = state.get("t") {
487 self.t = t_vals[0] as usize;
488 }
489
490 for (key, values) in state {
491 if let Some(name) = key.strip_prefix("m_") {
492 if let Some(m) = self.m.get(name) {
493 let shape = m.raw_dim();
494 if let Ok(arr) = Array::from_shape_vec(shape, values) {
495 self.m.insert(name.to_string(), arr);
496 }
497 }
498 } else if let Some(name) = key.strip_prefix("v_") {
499 if let Some(v) = self.v.get(name) {
500 let shape = v.raw_dim();
501 if let Ok(arr) = Array::from_shape_vec(shape, values) {
502 self.v.insert(name.to_string(), arr);
503 }
504 }
505 }
506 }
507 }
508}
509
510#[derive(Debug)]
512pub struct RMSpropOptimizer {
513 config: OptimizerConfig,
514 v: HashMap<String, Array<f64, Ix2>>,
516}
517
518impl RMSpropOptimizer {
519 pub fn new(config: OptimizerConfig) -> Self {
521 Self {
522 config,
523 v: HashMap::new(),
524 }
525 }
526
527 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
529 if let Some(clip_value) = self.config.grad_clip {
530 match self.config.grad_clip_mode {
531 GradClipMode::Value => {
532 for grad in gradients.values_mut() {
533 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
534 }
535 }
536 GradClipMode::Norm => {
537 let total_norm = compute_gradient_norm(gradients);
538 if total_norm > clip_value {
539 let scale = clip_value / total_norm;
540 for grad in gradients.values_mut() {
541 grad.mapv_inplace(|g| g * scale);
542 }
543 }
544 }
545 }
546 }
547 }
548}
549
550impl Optimizer for RMSpropOptimizer {
551 fn step(
552 &mut self,
553 parameters: &mut HashMap<String, Array<f64, Ix2>>,
554 gradients: &HashMap<String, Array<f64, Ix2>>,
555 ) -> TrainResult<()> {
556 let mut clipped_gradients = gradients.clone();
557 self.clip_gradients(&mut clipped_gradients);
558
559 let lr = self.config.learning_rate;
560 let alpha = self.config.beta2; let eps = self.config.epsilon;
562
563 for (name, param) in parameters.iter_mut() {
564 let grad = clipped_gradients.get(name).ok_or_else(|| {
565 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
566 })?;
567
568 if !self.v.contains_key(name) {
570 self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
571 }
572
573 let v = self.v.get_mut(name).unwrap();
574
575 let grad_squared = grad.mapv(|g| g * g);
577 *v = &*v * alpha + &(grad_squared * (1.0 - alpha));
578
579 let update = grad / &v.mapv(|v_val| v_val.sqrt() + eps);
581 *param = &*param - &(update * lr);
582 }
583
584 Ok(())
585 }
586
587 fn zero_grad(&mut self) {}
588
589 fn get_lr(&self) -> f64 {
590 self.config.learning_rate
591 }
592
593 fn set_lr(&mut self, lr: f64) {
594 self.config.learning_rate = lr;
595 }
596
597 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
598 let mut state = HashMap::new();
599 for (name, v_val) in &self.v {
600 state.insert(format!("v_{}", name), v_val.iter().copied().collect());
601 }
602 state
603 }
604
605 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
606 for (key, values) in state {
607 if let Some(name) = key.strip_prefix("v_") {
608 if let Some(v) = self.v.get(name) {
609 let shape = v.raw_dim();
610 if let Ok(arr) = Array::from_shape_vec(shape, values) {
611 self.v.insert(name.to_string(), arr);
612 }
613 }
614 }
615 }
616 }
617}
618
619#[derive(Debug)]
621pub struct AdagradOptimizer {
622 config: OptimizerConfig,
623 sum_squared_grads: HashMap<String, Array<f64, Ix2>>,
625}
626
627impl AdagradOptimizer {
628 pub fn new(config: OptimizerConfig) -> Self {
630 Self {
631 config,
632 sum_squared_grads: HashMap::new(),
633 }
634 }
635
636 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
638 if let Some(clip_value) = self.config.grad_clip {
639 match self.config.grad_clip_mode {
640 GradClipMode::Value => {
641 for grad in gradients.values_mut() {
642 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
643 }
644 }
645 GradClipMode::Norm => {
646 let total_norm = compute_gradient_norm(gradients);
647 if total_norm > clip_value {
648 let scale = clip_value / total_norm;
649 for grad in gradients.values_mut() {
650 grad.mapv_inplace(|g| g * scale);
651 }
652 }
653 }
654 }
655 }
656 }
657}
658
659impl Optimizer for AdagradOptimizer {
660 fn step(
661 &mut self,
662 parameters: &mut HashMap<String, Array<f64, Ix2>>,
663 gradients: &HashMap<String, Array<f64, Ix2>>,
664 ) -> TrainResult<()> {
665 let mut clipped_gradients = gradients.clone();
666 self.clip_gradients(&mut clipped_gradients);
667
668 let lr = self.config.learning_rate;
669 let eps = self.config.epsilon;
670
671 for (name, param) in parameters.iter_mut() {
672 let grad = clipped_gradients.get(name).ok_or_else(|| {
673 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
674 })?;
675
676 if !self.sum_squared_grads.contains_key(name) {
678 self.sum_squared_grads
679 .insert(name.clone(), Array::zeros(param.raw_dim()));
680 }
681
682 let sum_sq = self.sum_squared_grads.get_mut(name).unwrap();
683
684 let grad_squared = grad.mapv(|g| g * g);
686 *sum_sq = &*sum_sq + &grad_squared;
687
688 let update = grad / &sum_sq.mapv(|s| s.sqrt() + eps);
690 *param = &*param - &(update * lr);
691 }
692
693 Ok(())
694 }
695
696 fn zero_grad(&mut self) {}
697
698 fn get_lr(&self) -> f64 {
699 self.config.learning_rate
700 }
701
702 fn set_lr(&mut self, lr: f64) {
703 self.config.learning_rate = lr;
704 }
705
706 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
707 let mut state = HashMap::new();
708 for (name, sum_sq) in &self.sum_squared_grads {
709 state.insert(
710 format!("sum_squared_grads_{}", name),
711 sum_sq.iter().copied().collect(),
712 );
713 }
714 state
715 }
716
717 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
718 for (key, values) in state {
719 if let Some(name) = key.strip_prefix("sum_squared_grads_") {
720 if let Some(sum_sq) = self.sum_squared_grads.get(name) {
721 let shape = sum_sq.raw_dim();
722 if let Ok(arr) = Array::from_shape_vec(shape, values) {
723 self.sum_squared_grads.insert(name.to_string(), arr);
724 }
725 }
726 }
727 }
728 }
729}
730
731#[derive(Debug)]
733pub struct NAdamOptimizer {
734 config: OptimizerConfig,
735 m: HashMap<String, Array<f64, Ix2>>,
737 v: HashMap<String, Array<f64, Ix2>>,
739 t: usize,
741}
742
743impl NAdamOptimizer {
744 pub fn new(config: OptimizerConfig) -> Self {
746 Self {
747 config,
748 m: HashMap::new(),
749 v: HashMap::new(),
750 t: 0,
751 }
752 }
753
754 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
756 if let Some(clip_value) = self.config.grad_clip {
757 match self.config.grad_clip_mode {
758 GradClipMode::Value => {
759 for grad in gradients.values_mut() {
760 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
761 }
762 }
763 GradClipMode::Norm => {
764 let total_norm = compute_gradient_norm(gradients);
765 if total_norm > clip_value {
766 let scale = clip_value / total_norm;
767 for grad in gradients.values_mut() {
768 grad.mapv_inplace(|g| g * scale);
769 }
770 }
771 }
772 }
773 }
774 }
775}
776
777impl Optimizer for NAdamOptimizer {
778 fn step(
779 &mut self,
780 parameters: &mut HashMap<String, Array<f64, Ix2>>,
781 gradients: &HashMap<String, Array<f64, Ix2>>,
782 ) -> TrainResult<()> {
783 let mut clipped_gradients = gradients.clone();
784 self.clip_gradients(&mut clipped_gradients);
785
786 self.t += 1;
787 let lr = self.config.learning_rate;
788 let beta1 = self.config.beta1;
789 let beta2 = self.config.beta2;
790 let eps = self.config.epsilon;
791
792 let mu_t = beta1 * (1.0 - 0.5 * 0.96_f64.powi(self.t as i32));
794 let mu_t_next = beta1 * (1.0 - 0.5 * 0.96_f64.powi((self.t + 1) as i32));
795
796 for (name, param) in parameters.iter_mut() {
797 let grad = clipped_gradients.get(name).ok_or_else(|| {
798 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
799 })?;
800
801 if !self.m.contains_key(name) {
803 self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
804 self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
805 }
806
807 let m = self.m.get_mut(name).unwrap();
808 let v = self.v.get_mut(name).unwrap();
809
810 *m = &*m * beta1 + &(grad * (1.0 - beta1));
812
813 let grad_squared = grad.mapv(|g| g * g);
815 *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
816
817 let m_hat = &*m / (1.0 - beta1.powi(self.t as i32));
819 let v_hat = &*v / (1.0 - beta2.powi(self.t as i32));
820
821 let m_bar =
823 &m_hat * mu_t_next / (1.0 - mu_t_next) + &(grad * (1.0 - mu_t) / (1.0 - mu_t_next));
824
825 let update = m_bar / &v_hat.mapv(|v_val| v_val.sqrt() + eps);
827 *param = &*param - &(update * lr);
828 }
829
830 Ok(())
831 }
832
833 fn zero_grad(&mut self) {}
834
835 fn get_lr(&self) -> f64 {
836 self.config.learning_rate
837 }
838
839 fn set_lr(&mut self, lr: f64) {
840 self.config.learning_rate = lr;
841 }
842
843 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
844 let mut state = HashMap::new();
845 state.insert("t".to_string(), vec![self.t as f64]);
846
847 for (name, m_val) in &self.m {
848 state.insert(format!("m_{}", name), m_val.iter().copied().collect());
849 }
850 for (name, v_val) in &self.v {
851 state.insert(format!("v_{}", name), v_val.iter().copied().collect());
852 }
853 state
854 }
855
856 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
857 if let Some(t_vals) = state.get("t") {
858 self.t = t_vals[0] as usize;
859 }
860
861 for (key, values) in state {
862 if let Some(name) = key.strip_prefix("m_") {
863 if let Some(m) = self.m.get(name) {
864 let shape = m.raw_dim();
865 if let Ok(arr) = Array::from_shape_vec(shape, values) {
866 self.m.insert(name.to_string(), arr);
867 }
868 }
869 } else if let Some(name) = key.strip_prefix("v_") {
870 if let Some(v) = self.v.get(name) {
871 let shape = v.raw_dim();
872 if let Ok(arr) = Array::from_shape_vec(shape, values) {
873 self.v.insert(name.to_string(), arr);
874 }
875 }
876 }
877 }
878 }
879}
880
881#[derive(Debug)]
884pub struct LambOptimizer {
885 config: OptimizerConfig,
886 m: HashMap<String, Array<f64, Ix2>>,
888 v: HashMap<String, Array<f64, Ix2>>,
890 t: usize,
892}
893
894impl LambOptimizer {
895 pub fn new(config: OptimizerConfig) -> Self {
897 Self {
898 config,
899 m: HashMap::new(),
900 v: HashMap::new(),
901 t: 0,
902 }
903 }
904
905 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
907 if let Some(clip_value) = self.config.grad_clip {
908 match self.config.grad_clip_mode {
909 GradClipMode::Value => {
910 for grad in gradients.values_mut() {
911 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
912 }
913 }
914 GradClipMode::Norm => {
915 let total_norm = compute_gradient_norm(gradients);
916 if total_norm > clip_value {
917 let scale = clip_value / total_norm;
918 for grad in gradients.values_mut() {
919 grad.mapv_inplace(|g| g * scale);
920 }
921 }
922 }
923 }
924 }
925 }
926
927 fn compute_norm(arr: &Array<f64, Ix2>) -> f64 {
929 arr.iter().map(|&x| x * x).sum::<f64>().sqrt()
930 }
931}
932
933impl Optimizer for LambOptimizer {
934 fn step(
935 &mut self,
936 parameters: &mut HashMap<String, Array<f64, Ix2>>,
937 gradients: &HashMap<String, Array<f64, Ix2>>,
938 ) -> TrainResult<()> {
939 let mut clipped_gradients = gradients.clone();
940 self.clip_gradients(&mut clipped_gradients);
941
942 self.t += 1;
943 let lr = self.config.learning_rate;
944 let beta1 = self.config.beta1;
945 let beta2 = self.config.beta2;
946 let eps = self.config.epsilon;
947 let weight_decay = self.config.weight_decay;
948
949 for (name, param) in parameters.iter_mut() {
950 let grad = clipped_gradients.get(name).ok_or_else(|| {
951 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
952 })?;
953
954 if !self.m.contains_key(name) {
956 self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
957 self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
958 }
959
960 let m = self.m.get_mut(name).unwrap();
961 let v = self.v.get_mut(name).unwrap();
962
963 *m = &*m * beta1 + &(grad * (1.0 - beta1));
965
966 let grad_squared = grad.mapv(|g| g * g);
968 *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
969
970 let m_hat = &*m / (1.0 - beta1.powi(self.t as i32));
972 let v_hat = &*v / (1.0 - beta2.powi(self.t as i32));
973
974 let adam_step = &m_hat / &v_hat.mapv(|v_val| v_val.sqrt() + eps);
976
977 let update = &adam_step + ¶m.mapv(|p| p * weight_decay);
979
980 let param_norm = Self::compute_norm(param);
982 let update_norm = Self::compute_norm(&update);
983
984 let trust_ratio = if param_norm > 0.0 && update_norm > 0.0 {
985 param_norm / update_norm
986 } else {
987 1.0
988 };
989
990 *param = &*param - &(update * (lr * trust_ratio));
992 }
993
994 Ok(())
995 }
996
997 fn zero_grad(&mut self) {}
998
999 fn get_lr(&self) -> f64 {
1000 self.config.learning_rate
1001 }
1002
1003 fn set_lr(&mut self, lr: f64) {
1004 self.config.learning_rate = lr;
1005 }
1006
1007 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1008 let mut state = HashMap::new();
1009 state.insert("t".to_string(), vec![self.t as f64]);
1010
1011 for (name, m_val) in &self.m {
1012 state.insert(format!("m_{}", name), m_val.iter().copied().collect());
1013 }
1014 for (name, v_val) in &self.v {
1015 state.insert(format!("v_{}", name), v_val.iter().copied().collect());
1016 }
1017 state
1018 }
1019
1020 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1021 if let Some(t_vals) = state.get("t") {
1022 self.t = t_vals[0] as usize;
1023 }
1024
1025 for (key, values) in state {
1026 if let Some(name) = key.strip_prefix("m_") {
1027 if let Some(m) = self.m.get(name) {
1028 let shape = m.raw_dim();
1029 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1030 self.m.insert(name.to_string(), arr);
1031 }
1032 }
1033 } else if let Some(name) = key.strip_prefix("v_") {
1034 if let Some(v) = self.v.get(name) {
1035 let shape = v.raw_dim();
1036 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1037 self.v.insert(name.to_string(), arr);
1038 }
1039 }
1040 }
1041 }
1042 }
1043}
1044
1045#[derive(Debug)]
1052pub struct AdaMaxOptimizer {
1053 config: OptimizerConfig,
1054 m: HashMap<String, Array<f64, Ix2>>,
1056 u: HashMap<String, Array<f64, Ix2>>,
1058 t: usize,
1060}
1061
1062impl AdaMaxOptimizer {
1063 pub fn new(config: OptimizerConfig) -> Self {
1065 Self {
1066 config,
1067 m: HashMap::new(),
1068 u: HashMap::new(),
1069 t: 0,
1070 }
1071 }
1072
1073 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
1075 if let Some(clip_value) = self.config.grad_clip {
1076 match self.config.grad_clip_mode {
1077 GradClipMode::Value => {
1078 for grad in gradients.values_mut() {
1079 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
1080 }
1081 }
1082 GradClipMode::Norm => {
1083 let total_norm = compute_gradient_norm(gradients);
1084 if total_norm > clip_value {
1085 let scale = clip_value / total_norm;
1086 for grad in gradients.values_mut() {
1087 grad.mapv_inplace(|g| g * scale);
1088 }
1089 }
1090 }
1091 }
1092 }
1093 }
1094}
1095
1096impl Optimizer for AdaMaxOptimizer {
1097 fn step(
1098 &mut self,
1099 parameters: &mut HashMap<String, Array<f64, Ix2>>,
1100 gradients: &HashMap<String, Array<f64, Ix2>>,
1101 ) -> TrainResult<()> {
1102 let mut clipped_gradients = gradients.clone();
1103 self.clip_gradients(&mut clipped_gradients);
1104
1105 self.t += 1;
1106 let lr = self.config.learning_rate;
1107 let beta1 = self.config.beta1;
1108 let beta2 = self.config.beta2;
1109
1110 for (name, param) in parameters.iter_mut() {
1111 let grad = clipped_gradients.get(name).ok_or_else(|| {
1112 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1113 })?;
1114
1115 if !self.m.contains_key(name) {
1117 self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
1118 self.u.insert(name.clone(), Array::zeros(param.raw_dim()));
1119 }
1120
1121 let m = self.m.get_mut(name).unwrap();
1122 let u = self.u.get_mut(name).unwrap();
1123
1124 *m = &*m * beta1 + &(grad * (1.0 - beta1));
1126
1127 for i in 0..u.nrows() {
1129 for j in 0..u.ncols() {
1130 u[[i, j]] = (beta2 * u[[i, j]]).max(grad[[i, j]].abs());
1131 }
1132 }
1133
1134 let bias_correction = 1.0 - beta1.powi(self.t as i32);
1136 let lr_t = lr / bias_correction;
1137
1138 for i in 0..param.nrows() {
1140 for j in 0..param.ncols() {
1141 let update = lr_t * m[[i, j]] / (u[[i, j]] + self.config.epsilon);
1142 param[[i, j]] -= update;
1143 }
1144 }
1145 }
1146
1147 Ok(())
1148 }
1149
1150 fn zero_grad(&mut self) {}
1151
1152 fn get_lr(&self) -> f64 {
1153 self.config.learning_rate
1154 }
1155
1156 fn set_lr(&mut self, lr: f64) {
1157 self.config.learning_rate = lr;
1158 }
1159
1160 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1161 let mut state = HashMap::new();
1162 state.insert("t".to_string(), vec![self.t as f64]);
1163
1164 for (name, m_val) in &self.m {
1165 state.insert(format!("m_{}", name), m_val.iter().copied().collect());
1166 }
1167 for (name, u_val) in &self.u {
1168 state.insert(format!("u_{}", name), u_val.iter().copied().collect());
1169 }
1170 state
1171 }
1172
1173 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1174 if let Some(t_vals) = state.get("t") {
1175 self.t = t_vals[0] as usize;
1176 }
1177
1178 for (key, values) in state {
1179 if let Some(name) = key.strip_prefix("m_") {
1180 if let Some(m) = self.m.get(name) {
1181 let shape = m.raw_dim();
1182 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1183 self.m.insert(name.to_string(), arr);
1184 }
1185 }
1186 } else if let Some(name) = key.strip_prefix("u_") {
1187 if let Some(u) = self.u.get(name) {
1188 let shape = u.raw_dim();
1189 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1190 self.u.insert(name.to_string(), arr);
1191 }
1192 }
1193 }
1194 }
1195 }
1196}
1197
1198#[derive(Debug)]
1205pub struct LookaheadOptimizer<O: Optimizer> {
1206 inner_optimizer: O,
1208 slow_weights: HashMap<String, Array<f64, Ix2>>,
1210 alpha: f64,
1212 k: usize,
1214 step_counter: usize,
1216}
1217
1218impl<O: Optimizer> LookaheadOptimizer<O> {
1219 pub fn new(inner_optimizer: O, alpha: f64, k: usize) -> TrainResult<Self> {
1226 if !(0.0..=1.0).contains(&alpha) {
1227 return Err(TrainError::InvalidParameter(
1228 "alpha must be in [0, 1]".to_string(),
1229 ));
1230 }
1231 if k == 0 {
1232 return Err(TrainError::InvalidParameter(
1233 "k must be at least 1".to_string(),
1234 ));
1235 }
1236
1237 Ok(Self {
1238 inner_optimizer,
1239 slow_weights: HashMap::new(),
1240 alpha,
1241 k,
1242 step_counter: 0,
1243 })
1244 }
1245
1246 fn initialize_slow_weights(&mut self, parameters: &HashMap<String, Array<f64, Ix2>>) {
1248 if self.slow_weights.is_empty() {
1249 for (name, param) in parameters {
1250 self.slow_weights.insert(name.clone(), param.clone());
1251 }
1252 }
1253 }
1254
1255 fn synchronize_weights(&mut self, parameters: &mut HashMap<String, Array<f64, Ix2>>) {
1257 for (name, param) in parameters.iter_mut() {
1258 if let Some(slow_weight) = self.slow_weights.get_mut(name) {
1259 *slow_weight = &*slow_weight + &((&*param - &*slow_weight) * self.alpha);
1261
1262 *param = slow_weight.clone();
1264 }
1265 }
1266 }
1267}
1268
1269impl<O: Optimizer> Optimizer for LookaheadOptimizer<O> {
1270 fn step(
1271 &mut self,
1272 parameters: &mut HashMap<String, Array<f64, Ix2>>,
1273 gradients: &HashMap<String, Array<f64, Ix2>>,
1274 ) -> TrainResult<()> {
1275 self.initialize_slow_weights(parameters);
1277
1278 self.inner_optimizer.step(parameters, gradients)?;
1280
1281 self.step_counter += 1;
1282
1283 if self.step_counter.is_multiple_of(self.k) {
1285 self.synchronize_weights(parameters);
1286 }
1287
1288 Ok(())
1289 }
1290
1291 fn zero_grad(&mut self) {
1292 self.inner_optimizer.zero_grad();
1293 }
1294
1295 fn get_lr(&self) -> f64 {
1296 self.inner_optimizer.get_lr()
1297 }
1298
1299 fn set_lr(&mut self, lr: f64) {
1300 self.inner_optimizer.set_lr(lr);
1301 }
1302
1303 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1304 let mut state = self.inner_optimizer.state_dict();
1305
1306 state.insert("step_counter".to_string(), vec![self.step_counter as f64]);
1308 state.insert("alpha".to_string(), vec![self.alpha]);
1309 state.insert("k".to_string(), vec![self.k as f64]);
1310
1311 for (name, slow_weight) in &self.slow_weights {
1312 state.insert(
1313 format!("slow_{}", name),
1314 slow_weight.iter().copied().collect(),
1315 );
1316 }
1317
1318 state
1319 }
1320
1321 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1322 self.inner_optimizer.load_state_dict(state.clone());
1324
1325 if let Some(counter) = state.get("step_counter") {
1327 self.step_counter = counter[0] as usize;
1328 }
1329 if let Some(alpha_val) = state.get("alpha") {
1330 self.alpha = alpha_val[0];
1331 }
1332 if let Some(k_val) = state.get("k") {
1333 self.k = k_val[0] as usize;
1334 }
1335
1336 for (key, values) in state {
1338 if let Some(name) = key.strip_prefix("slow_") {
1339 if let Some(slow_weight) = self.slow_weights.get(name) {
1340 let shape = slow_weight.raw_dim();
1341 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1342 self.slow_weights.insert(name.to_string(), arr);
1343 }
1344 }
1345 }
1346 }
1347 }
1348}
1349
/// AdaBelief optimizer state.
///
/// Like Adam, but the second moment tracks the variance of the gradient
/// around its EMA (the "belief" in the gradient) rather than the raw
/// squared gradient.
#[derive(Debug)]
pub struct AdaBeliefOptimizer {
    // Shared hyperparameters (lr, betas, epsilon, weight decay, clipping).
    config: OptimizerConfig,
    // First-moment (mean) EMA per parameter, keyed by parameter name.
    m: HashMap<String, Array<f64, Ix2>>,
    // "Belief" second moment: EMA of (grad - m)^2 per parameter.
    s: HashMap<String, Array<f64, Ix2>>,
    // Step counter used for bias correction.
    t: usize,
}
1368
1369impl AdaBeliefOptimizer {
1370 pub fn new(config: OptimizerConfig) -> Self {
1372 Self {
1373 config,
1374 m: HashMap::new(),
1375 s: HashMap::new(),
1376 t: 0,
1377 }
1378 }
1379
1380 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
1382 if let Some(clip_value) = self.config.grad_clip {
1383 match self.config.grad_clip_mode {
1384 GradClipMode::Value => {
1385 for grad in gradients.values_mut() {
1386 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
1387 }
1388 }
1389 GradClipMode::Norm => {
1390 let total_norm = compute_gradient_norm(gradients);
1391 if total_norm > clip_value {
1392 let scale = clip_value / total_norm;
1393 for grad in gradients.values_mut() {
1394 grad.mapv_inplace(|g| g * scale);
1395 }
1396 }
1397 }
1398 }
1399 }
1400 }
1401}
1402
1403impl Optimizer for AdaBeliefOptimizer {
1404 fn step(
1405 &mut self,
1406 parameters: &mut HashMap<String, Array<f64, Ix2>>,
1407 gradients: &HashMap<String, Array<f64, Ix2>>,
1408 ) -> TrainResult<()> {
1409 let mut clipped_gradients = gradients.clone();
1410 self.clip_gradients(&mut clipped_gradients);
1411
1412 self.t += 1;
1413 let lr = self.config.learning_rate;
1414 let beta1 = self.config.beta1;
1415 let beta2 = self.config.beta2;
1416 let eps = self.config.epsilon;
1417 let weight_decay = self.config.weight_decay;
1418
1419 let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
1421 let bias_correction2 = 1.0 - beta2.powi(self.t as i32);
1422
1423 for (name, param) in parameters.iter_mut() {
1424 let grad = clipped_gradients.get(name).ok_or_else(|| {
1425 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1426 })?;
1427
1428 if !self.m.contains_key(name) {
1430 self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
1431 self.s.insert(name.clone(), Array::zeros(param.raw_dim()));
1432 }
1433
1434 let m = self.m.get_mut(name).unwrap();
1435 let s = self.s.get_mut(name).unwrap();
1436
1437 *m = &*m * beta1 + &(grad * (1.0 - beta1));
1439
1440 let grad_diff = grad - &*m;
1442
1443 let grad_diff_squared = grad_diff.mapv(|g| g * g);
1445 *s = &*s * beta2 + &(grad_diff_squared * (1.0 - beta2));
1446
1447 let m_hat = &*m / bias_correction1;
1449 let s_hat = &*s / bias_correction2;
1450
1451 if weight_decay > 0.0 {
1453 param.mapv_inplace(|p| p * (1.0 - lr * weight_decay));
1454 }
1455
1456 let update = m_hat / (s_hat.mapv(|v| v.sqrt()) + eps);
1458 *param = &*param - &(update * lr);
1459 }
1460
1461 Ok(())
1462 }
1463
1464 fn zero_grad(&mut self) {
1465 }
1467
1468 fn get_lr(&self) -> f64 {
1469 self.config.learning_rate
1470 }
1471
1472 fn set_lr(&mut self, lr: f64) {
1473 self.config.learning_rate = lr;
1474 }
1475
1476 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1477 let mut state = HashMap::new();
1478 state.insert("t".to_string(), vec![self.t as f64]);
1479
1480 for (name, m_val) in &self.m {
1481 state.insert(format!("m_{}", name), m_val.iter().copied().collect());
1482 }
1483 for (name, s_val) in &self.s {
1484 state.insert(format!("s_{}", name), s_val.iter().copied().collect());
1485 }
1486
1487 state
1488 }
1489
1490 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1491 if let Some(t_val) = state.get("t") {
1492 self.t = t_val[0] as usize;
1493 }
1494
1495 for (key, values) in state {
1496 if let Some(name) = key.strip_prefix("m_") {
1497 if let Some(m_array) = self.m.get(name) {
1498 let shape = m_array.raw_dim();
1499 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1500 self.m.insert(name.to_string(), arr);
1501 }
1502 }
1503 } else if let Some(name) = key.strip_prefix("s_") {
1504 if let Some(s_array) = self.s.get(name) {
1505 let shape = s_array.raw_dim();
1506 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1507 self.s.insert(name.to_string(), arr);
1508 }
1509 }
1510 }
1511 }
1512 }
1513}
1514
/// Rectified Adam (RAdam) optimizer state.
///
/// Adam with a variance-rectification term that falls back to a plain
/// momentum update while the adaptive-rate variance estimate is unreliable
/// (early in training).
#[derive(Debug)]
pub struct RAdamOptimizer {
    // Shared hyperparameters (lr, betas, epsilon, clipping).
    config: OptimizerConfig,
    // First-moment (mean) EMA per parameter.
    m: HashMap<String, Array<f64, Ix2>>,
    // Second-moment (uncentered variance) EMA per parameter.
    v: HashMap<String, Array<f64, Ix2>>,
    // Step counter used for bias correction and rectification.
    t: usize,
}
1532
1533impl RAdamOptimizer {
1534 pub fn new(config: OptimizerConfig) -> Self {
1536 Self {
1537 config,
1538 m: HashMap::new(),
1539 v: HashMap::new(),
1540 t: 0,
1541 }
1542 }
1543
1544 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
1546 if let Some(clip_value) = self.config.grad_clip {
1547 match self.config.grad_clip_mode {
1548 GradClipMode::Value => {
1549 for grad in gradients.values_mut() {
1550 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
1551 }
1552 }
1553 GradClipMode::Norm => {
1554 let total_norm = compute_gradient_norm(gradients);
1555 if total_norm > clip_value {
1556 let scale = clip_value / total_norm;
1557 for grad in gradients.values_mut() {
1558 grad.mapv_inplace(|g| g * scale);
1559 }
1560 }
1561 }
1562 }
1563 }
1564 }
1565
1566 fn compute_rectification(&self) -> (bool, f64) {
1568 let beta2 = self.config.beta2;
1569 let t = self.t as f64;
1570
1571 let rho_inf = 2.0 / (1.0 - beta2) - 1.0;
1573
1574 let rho_t = rho_inf - 2.0 * t * beta2.powf(t) / (1.0 - beta2.powf(t));
1576
1577 if rho_t > 5.0 {
1579 let rect = ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
1581 / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t);
1582 (true, rect.sqrt())
1583 } else {
1584 (false, 0.0)
1586 }
1587 }
1588}
1589
1590impl Optimizer for RAdamOptimizer {
1591 fn step(
1592 &mut self,
1593 parameters: &mut HashMap<String, Array<f64, Ix2>>,
1594 gradients: &HashMap<String, Array<f64, Ix2>>,
1595 ) -> TrainResult<()> {
1596 let mut clipped_gradients = gradients.clone();
1597 self.clip_gradients(&mut clipped_gradients);
1598
1599 self.t += 1;
1600 let lr = self.config.learning_rate;
1601 let beta1 = self.config.beta1;
1602 let beta2 = self.config.beta2;
1603 let eps = self.config.epsilon;
1604
1605 let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
1607
1608 let (use_adaptive, rect) = self.compute_rectification();
1610
1611 for (name, param) in parameters.iter_mut() {
1612 let grad = clipped_gradients.get(name).ok_or_else(|| {
1613 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1614 })?;
1615
1616 if !self.m.contains_key(name) {
1618 self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
1619 self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
1620 }
1621
1622 let m = self.m.get_mut(name).unwrap();
1623 let v = self.v.get_mut(name).unwrap();
1624
1625 *m = &*m * beta1 + &(grad * (1.0 - beta1));
1627
1628 let grad_squared = grad.mapv(|g| g * g);
1630 *v = &*v * beta2 + &(grad_squared * (1.0 - beta2));
1631
1632 let m_hat = &*m / bias_correction1;
1634
1635 if use_adaptive {
1636 let bias_correction2 = 1.0 - beta2.powi(self.t as i32);
1638 let v_hat = &*v / bias_correction2;
1639
1640 let update = m_hat / (v_hat.mapv(|val| val.sqrt()) + eps);
1642 *param = &*param - &(update * (lr * rect));
1643 } else {
1644 *param = &*param - &(m_hat * lr);
1646 }
1647 }
1648
1649 Ok(())
1650 }
1651
1652 fn zero_grad(&mut self) {
1653 }
1655
1656 fn get_lr(&self) -> f64 {
1657 self.config.learning_rate
1658 }
1659
1660 fn set_lr(&mut self, lr: f64) {
1661 self.config.learning_rate = lr;
1662 }
1663
1664 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1665 let mut state = HashMap::new();
1666 state.insert("t".to_string(), vec![self.t as f64]);
1667
1668 for (name, m_val) in &self.m {
1669 state.insert(format!("m_{}", name), m_val.iter().copied().collect());
1670 }
1671 for (name, v_val) in &self.v {
1672 state.insert(format!("v_{}", name), v_val.iter().copied().collect());
1673 }
1674
1675 state
1676 }
1677
1678 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1679 if let Some(t_val) = state.get("t") {
1680 self.t = t_val[0] as usize;
1681 }
1682
1683 for (key, values) in state {
1684 if let Some(name) = key.strip_prefix("m_") {
1685 if let Some(m_array) = self.m.get(name) {
1686 let shape = m_array.raw_dim();
1687 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1688 self.m.insert(name.to_string(), arr);
1689 }
1690 }
1691 } else if let Some(name) = key.strip_prefix("v_") {
1692 if let Some(v_array) = self.v.get(name) {
1693 let shape = v_array.raw_dim();
1694 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1695 self.v.insert(name.to_string(), arr);
1696 }
1697 }
1698 }
1699 }
1700 }
1701}
1702
/// LARS (Layer-wise Adaptive Rate Scaling) optimizer state.
///
/// SGD with momentum where each parameter tensor's step is rescaled by the
/// trust ratio `trust_coef * ||w|| / ||g||`.
#[derive(Debug)]
pub struct LarsOptimizer {
    // Shared hyperparameters (lr, momentum, weight decay, clipping).
    config: OptimizerConfig,
    // Momentum buffer per parameter, keyed by parameter name.
    velocity: HashMap<String, Array<f64, Ix2>>,
    // Trust coefficient (eta) scaling the layer-wise ratio.
    trust_coef: f64,
    // When true, bias parameters skip layer-wise scaling.
    exclude_bias: bool,
}
1719
1720impl LarsOptimizer {
1721 pub fn new(config: OptimizerConfig, trust_coef: f64, exclude_bias: bool) -> Self {
1728 Self {
1729 config,
1730 velocity: HashMap::new(),
1731 trust_coef,
1732 exclude_bias,
1733 }
1734 }
1735
1736 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
1738 if let Some(clip_value) = self.config.grad_clip {
1739 match self.config.grad_clip_mode {
1740 GradClipMode::Value => {
1741 for grad in gradients.values_mut() {
1742 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
1743 }
1744 }
1745 GradClipMode::Norm => {
1746 let total_norm = compute_gradient_norm(gradients);
1747 if total_norm > clip_value {
1748 let scale = clip_value / total_norm;
1749 for grad in gradients.values_mut() {
1750 grad.mapv_inplace(|g| g * scale);
1751 }
1752 }
1753 }
1754 }
1755 }
1756 }
1757
1758 fn compute_adaptive_lr(
1760 &self,
1761 param: &Array<f64, Ix2>,
1762 grad: &Array<f64, Ix2>,
1763 name: &str,
1764 ) -> f64 {
1765 if self.exclude_bias && (name.contains("bias") || name.contains("b")) {
1767 return self.config.learning_rate;
1768 }
1769
1770 let param_norm: f64 = param.iter().map(|&p| p * p).sum::<f64>().sqrt();
1772
1773 let grad_norm: f64 = grad.iter().map(|&g| g * g).sum::<f64>().sqrt();
1775
1776 if param_norm == 0.0 || grad_norm == 0.0 {
1778 return self.config.learning_rate;
1779 }
1780
1781 let local_lr = self.trust_coef * param_norm / grad_norm;
1783
1784 self.config.learning_rate * local_lr
1786 }
1787}
1788
1789impl Optimizer for LarsOptimizer {
1790 fn step(
1791 &mut self,
1792 parameters: &mut HashMap<String, Array<f64, Ix2>>,
1793 gradients: &HashMap<String, Array<f64, Ix2>>,
1794 ) -> TrainResult<()> {
1795 let mut clipped_gradients = gradients.clone();
1796 self.clip_gradients(&mut clipped_gradients);
1797
1798 for (name, param) in parameters.iter_mut() {
1799 let grad = clipped_gradients.get(name).ok_or_else(|| {
1800 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1801 })?;
1802
1803 let adaptive_lr = self.compute_adaptive_lr(param, grad, name);
1805
1806 let mut effective_grad = grad.clone();
1808 if self.config.weight_decay > 0.0 {
1809 effective_grad += &(&*param * self.config.weight_decay);
1810 }
1811
1812 if !self.velocity.contains_key(name) {
1814 self.velocity
1815 .insert(name.clone(), Array::zeros(param.raw_dim()));
1816 }
1817
1818 let velocity = self.velocity.get_mut(name).unwrap();
1819
1820 velocity.mapv_inplace(|v| self.config.momentum * v);
1822 *velocity = &*velocity + &(effective_grad * adaptive_lr);
1823
1824 *param = &*param - &*velocity;
1826 }
1827
1828 Ok(())
1829 }
1830
1831 fn zero_grad(&mut self) {
1832 }
1834
1835 fn get_lr(&self) -> f64 {
1836 self.config.learning_rate
1837 }
1838
1839 fn set_lr(&mut self, lr: f64) {
1840 self.config.learning_rate = lr;
1841 }
1842
1843 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
1844 let mut state = HashMap::new();
1845 state.insert("trust_coef".to_string(), vec![self.trust_coef]);
1846 state.insert(
1847 "exclude_bias".to_string(),
1848 vec![if self.exclude_bias { 1.0 } else { 0.0 }],
1849 );
1850
1851 for (name, velocity) in &self.velocity {
1852 state.insert(
1853 format!("velocity_{}", name),
1854 velocity.iter().copied().collect(),
1855 );
1856 }
1857
1858 state
1859 }
1860
1861 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
1862 if let Some(trust) = state.get("trust_coef") {
1863 self.trust_coef = trust[0];
1864 }
1865 if let Some(exclude) = state.get("exclude_bias") {
1866 self.exclude_bias = exclude[0] > 0.5;
1867 }
1868
1869 for (key, values) in state {
1870 if let Some(name) = key.strip_prefix("velocity_") {
1871 if let Some(velocity) = self.velocity.get(name) {
1872 let shape = velocity.raw_dim();
1873 if let Ok(arr) = Array::from_shape_vec(shape, values) {
1874 self.velocity.insert(name.to_string(), arr);
1875 }
1876 }
1877 }
1878 }
1879 }
1880}
1881
/// Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.
///
/// Drives the two-phase SAM update: `first_step` perturbs the weights toward
/// the local worst case, `second_step` restores them and applies the base
/// optimizer with gradients computed at the perturbed point.
#[derive(Debug)]
pub struct SamOptimizer<O: Optimizer> {
    // Optimizer performing the actual parameter update in `second_step`.
    base_optimizer: O,
    // Neighborhood radius for the sharpness-aware perturbation.
    rho: f64,
    // Perturbations applied in `first_step`, undone in `second_step`.
    perturbations: HashMap<String, Array<f64, Ix2>>,
}
1905
1906impl<O: Optimizer> SamOptimizer<O> {
1907 pub fn new(base_optimizer: O, rho: f64) -> TrainResult<Self> {
1913 if rho <= 0.0 {
1914 return Err(TrainError::OptimizerError(
1915 "SAM rho must be positive".to_string(),
1916 ));
1917 }
1918
1919 Ok(Self {
1920 base_optimizer,
1921 rho,
1922 perturbations: HashMap::new(),
1923 })
1924 }
1925
1926 pub fn first_step(
1931 &mut self,
1932 parameters: &mut HashMap<String, Array<f64, Ix2>>,
1933 gradients: &HashMap<String, Array<f64, Ix2>>,
1934 ) -> TrainResult<()> {
1935 let grad_norm = compute_gradient_norm(gradients);
1937
1938 if grad_norm == 0.0 {
1939 return Ok(());
1940 }
1941
1942 for (name, param) in parameters.iter_mut() {
1944 let grad = gradients.get(name).ok_or_else(|| {
1945 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
1946 })?;
1947
1948 let perturbation = grad.mapv(|g| self.rho * g / grad_norm);
1950
1951 *param = &*param + &perturbation;
1953
1954 self.perturbations.insert(name.clone(), perturbation);
1956 }
1957
1958 Ok(())
1959 }
1960
1961 pub fn second_step(
1966 &mut self,
1967 parameters: &mut HashMap<String, Array<f64, Ix2>>,
1968 gradients: &HashMap<String, Array<f64, Ix2>>,
1969 ) -> TrainResult<()> {
1970 for (name, param) in parameters.iter_mut() {
1972 if let Some(perturbation) = self.perturbations.get(name) {
1973 *param = &*param - perturbation;
1974 }
1975 }
1976
1977 self.perturbations.clear();
1979
1980 self.base_optimizer.step(parameters, gradients)
1982 }
1983}
1984
1985impl<O: Optimizer> Optimizer for SamOptimizer<O> {
1986 fn step(
1987 &mut self,
1988 parameters: &mut HashMap<String, Array<f64, Ix2>>,
1989 gradients: &HashMap<String, Array<f64, Ix2>>,
1990 ) -> TrainResult<()> {
1991 self.second_step(parameters, gradients)
1994 }
1995
1996 fn zero_grad(&mut self) {
1997 self.base_optimizer.zero_grad();
1998 }
1999
2000 fn get_lr(&self) -> f64 {
2001 self.base_optimizer.get_lr()
2002 }
2003
2004 fn set_lr(&mut self, lr: f64) {
2005 self.base_optimizer.set_lr(lr);
2006 }
2007
2008 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
2009 let mut state = self.base_optimizer.state_dict();
2010 state.insert("rho".to_string(), vec![self.rho]);
2011
2012 for (name, perturbation) in &self.perturbations {
2013 state.insert(
2014 format!("perturbation_{}", name),
2015 perturbation.iter().copied().collect(),
2016 );
2017 }
2018
2019 state
2020 }
2021
2022 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
2023 if let Some(rho_val) = state.get("rho") {
2024 self.rho = rho_val[0];
2025 }
2026
2027 self.base_optimizer.load_state_dict(state.clone());
2029
2030 for (key, values) in state {
2032 if let Some(name) = key.strip_prefix("perturbation_") {
2033 if let Some(pert) = self.perturbations.get(name) {
2034 let shape = pert.raw_dim();
2035 if let Ok(arr) = Array::from_shape_vec(shape, values) {
2036 self.perturbations.insert(name.to_string(), arr);
2037 }
2038 }
2039 }
2040 }
2041 }
2042}
2043
#[cfg(test)]
mod tests {
    //! Smoke tests for the optimizer implementations: each test takes one or
    //! more steps on a tiny parameter map and checks that the parameters
    //! moved in the descent direction (and that serialized state contains
    //! the expected keys).
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sgd_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            momentum: 0.9,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);
    }

    #[test]
    fn test_adam_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let mut optimizer = AdamOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
    }

    #[test]
    fn test_adamw_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            weight_decay: 0.01,
            ..Default::default()
        };
        let mut optimizer = AdamWOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
    }

    #[test]
    fn test_gradient_clipping() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            grad_clip: Some(0.05),
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        // A large gradient that must be clipped down to 0.05.
        // (Formatting fix: the insert and step were previously jammed onto
        // one line.)
        grads.insert("w".to_string(), array![[10.0]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        // With a zero-initialized velocity the step is lr * clipped_grad.
        assert!((w[[0, 0]] - (1.0 - 0.1 * 0.05)).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let mut optimizer = RMSpropOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
    }

    #[test]
    fn test_adagrad_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = AdagradOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);
    }

    #[test]
    fn test_nadam_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.002,
            ..Default::default()
        };
        let mut optimizer = NAdamOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
    }

    #[test]
    fn test_lamb_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            weight_decay: 0.01,
            ..Default::default()
        };
        let mut optimizer = LambOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
    }

    #[test]
    fn test_adamax_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.002,
            ..Default::default()
        };
        let mut optimizer = AdaMaxOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        // Take a few steps so the infinity-norm buffer is exercised.
        for _ in 0..3 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);
        assert!(w[[1, 0]] < 3.0);
        assert!(w[[1, 1]] < 4.0);

        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("u_w"));
    }

    #[test]
    fn test_lookahead_optimizer() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let inner_optimizer = AdamOptimizer::new(inner_config);

        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 5).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        // 10 steps = two full synchronization cycles with k = 5.
        for _ in 0..10 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);

        assert_eq!(optimizer.get_lr(), 0.01);

        optimizer.set_lr(0.02);
        assert_eq!(optimizer.get_lr(), 0.02);
    }

    #[test]
    fn test_lookahead_invalid_alpha() {
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());

        let result = LookaheadOptimizer::new(inner_optimizer, 1.5, 5);
        assert!(result.is_err());

        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let result = LookaheadOptimizer::new(inner_optimizer, -0.1, 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_lookahead_invalid_k() {
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());

        let result = LookaheadOptimizer::new(inner_optimizer, 0.5, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_lookahead_synchronization() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let inner_optimizer = SgdOptimizer::new(inner_config);

        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 3).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1]]);

        let initial_w = params.get("w").unwrap()[[0, 0]];

        // Exactly k steps, so a slow/fast weight synchronization occurs.
        for _ in 0..3 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        let w_after_sync = params.get("w").unwrap()[[0, 0]];

        assert_ne!(w_after_sync, initial_w);
        assert!(w_after_sync < initial_w);
    }

    #[test]
    fn test_adabelief_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            weight_decay: 0.01,
            ..Default::default()
        };
        let mut optimizer = AdaBeliefOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        for _ in 0..5 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[1, 1]] < 4.0);

        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("s_w"));
    }

    #[test]
    fn test_radam_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let mut optimizer = RAdamOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);

        // Enough steps to pass through the non-rectified warmup phase.
        for _ in 0..10 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);

        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("v_w"));
    }

    #[test]
    fn test_lars_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            momentum: 0.9,
            weight_decay: 0.0001,
            ..Default::default()
        };
        let mut optimizer = LarsOptimizer::new(config, 0.001, true);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[1, 1]] < 4.0);

        let state = optimizer.state_dict();
        assert!(state.contains_key("trust_coef"));
        assert!(state.contains_key("exclude_bias"));
        assert!(state.contains_key("velocity_w"));
    }

    #[test]
    fn test_lars_bias_exclusion() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            momentum: 0.9,
            ..Default::default()
        };

        let mut optimizer = LarsOptimizer::new(config.clone(), 0.001, true);

        let mut params = HashMap::new();
        params.insert("weights".to_string(), array![[1.0, 2.0]]);
        params.insert("bias".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("weights".to_string(), array![[0.1, 0.1]]);
        grads.insert("bias".to_string(), array![[0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        // Both parameters should move; only "bias" skips trust-ratio scaling.
        let weights = params.get("weights").unwrap();
        let bias = params.get("bias").unwrap();
        assert!(weights[[0, 0]] < 1.0);
        assert!(bias[[0, 0]] < 1.0);
    }

    #[test]
    fn test_sam_optimizer() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let inner_optimizer = SgdOptimizer::new(inner_config);

        let mut optimizer = SamOptimizer::new(inner_optimizer, 0.05).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        let original_w = params.get("w").unwrap().clone();

        // Phase 1 perturbs the weights toward the local worst case.
        optimizer.first_step(&mut params, &grads).unwrap();

        let perturbed_w = params.get("w").unwrap();
        assert_ne!(perturbed_w[[0, 0]], original_w[[0, 0]]);

        // Phase 2 restores the weights and applies the base update.
        optimizer.second_step(&mut params, &grads).unwrap();

        let final_w = params.get("w").unwrap();
        assert!(final_w[[0, 0]] < original_w[[0, 0]]);

        let state = optimizer.state_dict();
        assert!(state.contains_key("rho"));
    }

    #[test]
    fn test_sam_invalid_rho() {
        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());

        let result = SamOptimizer::new(inner_optimizer, 0.0);
        assert!(result.is_err());

        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
        let result = SamOptimizer::new(inner_optimizer, -0.1);
        assert!(result.is_err());
    }
}