use crate::Optimizer;

/// Common interface implemented by all learning-rate schedulers in this module.
pub trait LrScheduler {
    /// Advances the schedule by one step (or epoch) and pushes the new
    /// learning rate into the optimizer.
    fn step(&mut self, optimizer: &mut dyn Optimizer);

    /// Returns the learning rate currently held by the scheduler.
    fn get_lr(&self) -> f64;

    /// Serializes the scheduler state into a flat `String -> f64` map.
    fn state_dict(&self) -> std::collections::HashMap<String, f64>;

    /// Restores the scheduler state from a map produced by `state_dict`.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()>;
}
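
// A minimal usage sketch of the trait (illustration only; it assumes the
// `SgdOptimizer` and `OptimizerConfig` types exercised in the tests below,
// but any `Optimizer` implementation works the same way):
//
//     let mut optimizer = SgdOptimizer::new(OptimizerConfig::default());
//     let mut scheduler = StepLrScheduler::new(0.1, 30, 0.1);
//     for _epoch in 0..100 {
//         // ... train for one epoch with `optimizer` ...
//         scheduler.step(&mut optimizer);
//         println!("lr = {}", scheduler.get_lr());
//     }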

/// Step decay: multiplies the learning rate by `gamma` every `step_size` epochs.
#[derive(Debug, Clone)]
pub struct StepLrScheduler {
    /// Learning rate at epoch 0.
    pub initial_lr: f64,
    /// Number of epochs between decays.
    pub step_size: usize,
    /// Multiplicative decay factor.
    pub gamma: f64,
    current_epoch: usize,
    current_lr: f64,
}

impl StepLrScheduler {
    pub fn new(initial_lr: f64, step_size: usize, gamma: f64) -> Self {
        Self {
            initial_lr,
            step_size,
            gamma,
            current_epoch: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for StepLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;

        if self.current_epoch.is_multiple_of(self.step_size) {
            self.current_lr *= self.gamma;
            optimizer.set_lr(self.current_lr);
        }
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("step_size".to_string(), self.step_size as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}
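
// Worked example for `StepLrScheduler` (mirrors `test_step_lr_scheduler`
// below): with `initial_lr = 0.1`, `step_size = 2`, `gamma = 0.5`, the
// learning rate after epochs 1..=4 is 0.1, 0.05, 0.05, 0.025.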

/// Exponential decay: `lr_t = initial_lr * gamma^t` after `t` epochs.
#[derive(Debug, Clone)]
pub struct ExponentialLrScheduler {
    /// Learning rate at epoch 0.
    pub initial_lr: f64,
    /// Per-epoch decay factor.
    pub gamma: f64,
    current_epoch: usize,
    current_lr: f64,
}

impl ExponentialLrScheduler {
    pub fn new(initial_lr: f64, gamma: f64) -> Self {
        Self {
            initial_lr,
            gamma,
            current_epoch: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for ExponentialLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;
        self.current_lr = self.initial_lr * self.gamma.powi(self.current_epoch as i32);
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}

/// Cosine annealing from `initial_lr` down to `min_lr` over `t_max` epochs.
#[derive(Debug, Clone)]
pub struct CosineAnnealingLrScheduler {
    /// Learning rate at the start of the schedule.
    pub initial_lr: f64,
    /// Learning rate reached at the end of the schedule.
    pub min_lr: f64,
    /// Number of epochs over which to anneal.
    pub t_max: usize,
    current_epoch: usize,
    current_lr: f64,
}

impl CosineAnnealingLrScheduler {
    pub fn new(initial_lr: f64, min_lr: f64, t_max: usize) -> Self {
        Self {
            initial_lr,
            min_lr,
            t_max,
            current_epoch: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for CosineAnnealingLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;

        let progress = (self.current_epoch as f64) / (self.t_max as f64);
        let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
        self.current_lr = self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay;

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("t_max".to_string(), self.t_max as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}
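
// Worked example for `CosineAnnealingLrScheduler` (mirrors
// `test_cosine_annealing_scheduler` below): with `initial_lr = 0.1`,
// `min_lr = 0.01`, `t_max = 10`, the halfway point (epoch 5) gives
// lr = 0.01 + (0.1 - 0.01) * 0.5 * (1 + cos(pi / 2)) = 0.055.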

/// Linear warmup from 0 up to `target_lr` over `warmup_steps` steps, then constant.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct WarmupScheduler {
    /// Learning rate reached once warmup is complete.
    pub target_lr: f64,
    /// Number of steps over which to warm up.
    pub warmup_steps: usize,
    current_step: usize,
    current_lr: f64,
}

impl WarmupScheduler {
    #[allow(dead_code)]
    pub fn new(target_lr: f64, warmup_steps: usize) -> Self {
        Self {
            target_lr,
            warmup_steps,
            current_step: 0,
            current_lr: 0.0,
        }
    }
}

impl LrScheduler for WarmupScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        if self.current_step < self.warmup_steps {
            self.current_lr =
                self.target_lr * (self.current_step as f64) / (self.warmup_steps as f64);
        } else {
            self.current_lr = self.target_lr;
        }

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("target_lr".to_string(), self.target_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}

/// One-cycle policy: the learning rate ramps linearly from `initial_lr` to
/// `max_lr` for the first `pct_start` fraction of `total_steps`, then decays
/// to `min_lr` with a cosine curve for the remainder.
#[derive(Debug, Clone)]
pub struct OneCycleLrScheduler {
    /// Learning rate at the start of the cycle.
    pub initial_lr: f64,
    /// Peak learning rate reached at the end of the ramp-up phase.
    pub max_lr: f64,
    /// Learning rate at the end of the cycle.
    pub min_lr: f64,
    /// Total number of scheduler steps in the cycle.
    pub total_steps: usize,
    /// Fraction of the cycle spent ramping up (e.g. 0.3).
    pub pct_start: f64,
    current_step: usize,
    current_lr: f64,
}

impl OneCycleLrScheduler {
    pub fn new(
        initial_lr: f64,
        max_lr: f64,
        min_lr: f64,
        total_steps: usize,
        pct_start: f64,
    ) -> Self {
        Self {
            initial_lr,
            max_lr,
            min_lr,
            total_steps,
            pct_start,
            current_step: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for OneCycleLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        let step_num = self.current_step.min(self.total_steps);
        let pct = step_num as f64 / self.total_steps as f64;

        if pct < self.pct_start {
            // Ramp-up phase: linear increase from initial_lr to max_lr.
            let phase_pct = pct / self.pct_start;
            self.current_lr = self.initial_lr + (self.max_lr - self.initial_lr) * phase_pct;
        } else {
            // Annealing phase: cosine decay from max_lr to min_lr.
            let phase_pct = (pct - self.pct_start) / (1.0 - self.pct_start);
            let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * phase_pct).cos());
            self.current_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay;
        }

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state.insert("pct_start".to_string(), self.pct_start);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}
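
// Worked example for `OneCycleLrScheduler` (mirrors `test_one_cycle_scheduler`
// below): with `initial_lr = 0.01`, `max_lr = 0.1`, `min_lr = 0.001`,
// `total_steps = 100`, `pct_start = 0.3`, step 15 sits halfway through the
// ramp (lr = 0.055), step 30 reaches the peak (lr = 0.1), and step 100 ends
// the cosine phase at lr = 0.001.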

/// Polynomial decay from `initial_lr` to `final_lr` over `decay_steps` steps,
/// with curvature controlled by `power`.
#[derive(Debug, Clone)]
pub struct PolynomialDecayLrScheduler {
    /// Learning rate at step 0.
    pub initial_lr: f64,
    /// Learning rate reached after `decay_steps` steps.
    pub final_lr: f64,
    /// Exponent of the decay polynomial (1.0 gives linear decay).
    pub power: f64,
    /// Number of steps over which to decay.
    pub decay_steps: usize,
    current_step: usize,
    current_lr: f64,
}

impl PolynomialDecayLrScheduler {
    pub fn new(initial_lr: f64, final_lr: f64, power: f64, decay_steps: usize) -> Self {
        Self {
            initial_lr,
            final_lr,
            power,
            decay_steps,
            current_step: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for PolynomialDecayLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        let step_num = self.current_step.min(self.decay_steps);
        let decay_factor = (1.0 - (step_num as f64 / self.decay_steps as f64)).powf(self.power);
        self.current_lr = (self.initial_lr - self.final_lr) * decay_factor + self.final_lr;

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("final_lr".to_string(), self.final_lr);
        state.insert("power".to_string(), self.power);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("decay_steps".to_string(), self.decay_steps as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}

/// Amplitude scaling mode for `CyclicLrScheduler`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CyclicLrMode {
    /// Constant amplitude on every cycle.
    Triangular,
    /// Amplitude halved on every cycle.
    Triangular2,
    /// Amplitude scaled by `gamma^step`.
    ExpRange,
}

/// Cyclical learning rate that oscillates between `base_lr` and `max_lr`
/// with a triangular waveform of half-period `step_size`.
#[derive(Debug, Clone)]
pub struct CyclicLrScheduler {
    /// Lower bound of the cycle.
    pub base_lr: f64,
    /// Upper bound of the cycle.
    pub max_lr: f64,
    /// Number of steps in each half cycle.
    pub step_size: usize,
    /// Amplitude scaling mode.
    pub mode: CyclicLrMode,
    /// Decay factor used by `CyclicLrMode::ExpRange`.
    pub gamma: f64,
    current_step: usize,
    current_lr: f64,
    cycle: usize,
}

impl CyclicLrScheduler {
    pub fn new(base_lr: f64, max_lr: f64, step_size: usize, mode: CyclicLrMode) -> Self {
        Self {
            base_lr,
            max_lr,
            step_size,
            mode,
            gamma: 0.99994,
            current_step: 0,
            current_lr: base_lr,
            cycle: 0,
        }
    }

    pub fn new_exp_range(base_lr: f64, max_lr: f64, step_size: usize, gamma: f64) -> Self {
        Self {
            base_lr,
            max_lr,
            step_size,
            mode: CyclicLrMode::ExpRange,
            gamma,
            current_step: 0,
            current_lr: base_lr,
            cycle: 0,
        }
    }
}

impl LrScheduler for CyclicLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        let cycle = (self.current_step - 1) / (2 * self.step_size);
        let x = ((self.current_step - 1) as f64 / self.step_size as f64).abs() % 2.0;

        let scale_fn = match self.mode {
            CyclicLrMode::Triangular => 1.0,
            CyclicLrMode::Triangular2 => 1.0 / 2.0_f64.powi(cycle as i32),
            CyclicLrMode::ExpRange => self.gamma.powi(self.current_step as i32),
        };

        if x <= 1.0 {
            // Rising half of the cycle.
            self.current_lr = self.base_lr + (self.max_lr - self.base_lr) * x * scale_fn;
        } else {
            // Falling half of the cycle.
            self.current_lr = self.base_lr + (self.max_lr - self.base_lr) * (2.0 - x) * scale_fn;
        }

        self.cycle = cycle;
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("base_lr".to_string(), self.base_lr);
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("step_size".to_string(), self.step_size as f64);
        state.insert("cycle".to_string(), self.cycle as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&cycle) = state.get("cycle") {
            self.cycle = cycle as usize;
        }
        Ok(())
    }
}
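
// Worked example for `CyclicLrScheduler` in triangular mode (mirrors
// `test_cyclic_lr_scheduler` below): with `base_lr = 0.01`, `max_lr = 0.1`,
// `step_size = 10`, the learning rate climbs from 0.01 towards 0.1 over the
// first half cycle, falls back towards 0.01 over the second, and then repeats.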

/// Linear warmup to `target_lr` over `warmup_steps`, followed by cosine decay
/// to `min_lr` over the remaining steps up to `total_steps`.
#[derive(Debug, Clone)]
pub struct WarmupCosineLrScheduler {
    /// Peak learning rate reached at the end of warmup.
    pub target_lr: f64,
    /// Learning rate at the end of the schedule.
    pub min_lr: f64,
    /// Number of warmup steps.
    pub warmup_steps: usize,
    /// Total number of scheduler steps.
    pub total_steps: usize,
    current_step: usize,
    current_lr: f64,
}

impl WarmupCosineLrScheduler {
    pub fn new(target_lr: f64, min_lr: f64, warmup_steps: usize, total_steps: usize) -> Self {
        Self {
            target_lr,
            min_lr,
            warmup_steps,
            total_steps,
            current_step: 0,
            current_lr: 0.0,
        }
    }
}

impl LrScheduler for WarmupCosineLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        if self.current_step <= self.warmup_steps {
            self.current_lr =
                self.target_lr * (self.current_step as f64 / self.warmup_steps as f64);
        } else {
            let progress = (self.current_step - self.warmup_steps) as f64
                / (self.total_steps - self.warmup_steps) as f64;
            let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
            self.current_lr = self.min_lr + (self.target_lr - self.min_lr) * cosine_decay;
        }

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("target_lr".to_string(), self.target_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}

/// Noam (Transformer) schedule:
/// `lr = scale_factor * model_dim^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)`,
/// i.e. a linear warmup followed by inverse-square-root decay.
#[derive(Debug, Clone)]
pub struct NoamScheduler {
    model_dim: f64,
    warmup_steps: usize,
    scale_factor: f64,
    current_step: usize,
    current_lr: f64,
}

impl NoamScheduler {
    pub fn new(model_dim: usize, warmup_steps: usize, scale_factor: f64) -> Self {
        let model_dim_f64 = model_dim as f64;
        let current_lr = scale_factor * model_dim_f64.powf(-0.5);

        Self {
            model_dim: model_dim_f64,
            warmup_steps,
            scale_factor,
            current_step: 0,
            current_lr,
        }
    }

    fn compute_lr(&self) -> f64 {
        let step = (self.current_step + 1) as f64;
        let warmup = self.warmup_steps as f64;

        self.scale_factor
            * self.model_dim.powf(-0.5)
            * step.powf(-0.5).min(step * warmup.powf(-1.5))
    }
}

impl LrScheduler for NoamScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;
        self.current_lr = self.compute_lr();
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("model_dim".to_string(), self.model_dim);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state.insert("scale_factor".to_string(), self.scale_factor);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("current_lr".to_string(), self.current_lr);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        Ok(())
    }
}
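
// Numeric sketch for `NoamScheduler` (mirrors `test_noam_scheduler` below):
// with `model_dim = 512`, `warmup_steps = 4000`, `scale_factor = 1.0`, the
// learning rate grows roughly linearly until about step 4000, peaks near
// 1.0 * 512^-0.5 * 4000^-0.5 ≈ 7.0e-4, and then decays as step^-0.5.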

/// Multi-step decay: multiplies the learning rate by `gamma` each time the
/// epoch count reaches one of the given `milestones`.
#[derive(Debug, Clone)]
pub struct MultiStepLrScheduler {
    /// Learning rate at epoch 0.
    pub initial_lr: f64,
    /// Epochs at which to decay (sorted ascending in `new`).
    pub milestones: Vec<usize>,
    /// Multiplicative decay factor applied at each milestone.
    pub gamma: f64,
    current_epoch: usize,
    current_lr: f64,
    next_milestone_idx: usize,
}

impl MultiStepLrScheduler {
    pub fn new(initial_lr: f64, mut milestones: Vec<usize>, gamma: f64) -> Self {
        milestones.sort_unstable();

        Self {
            initial_lr,
            milestones,
            gamma,
            current_epoch: 0,
            current_lr: initial_lr,
            next_milestone_idx: 0,
        }
    }
}

impl LrScheduler for MultiStepLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;

        if self.next_milestone_idx < self.milestones.len()
            && self.current_epoch >= self.milestones[self.next_milestone_idx]
        {
            self.current_lr *= self.gamma;
            self.next_milestone_idx += 1;
            optimizer.set_lr(self.current_lr);
        }
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("gamma".to_string(), self.gamma);
        state.insert(
            "next_milestone_idx".to_string(),
            self.next_milestone_idx as f64,
        );
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        if let Some(&next_milestone_idx) = state.get("next_milestone_idx") {
            self.next_milestone_idx = next_milestone_idx as usize;
        }
        Ok(())
    }
}
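
// Worked example for `MultiStepLrScheduler` (mirrors
// `test_multistep_lr_scheduler` below): with `initial_lr = 0.1`,
// `milestones = [10, 20, 30]`, `gamma = 0.1`, the learning rate is 0.1 for
// epochs 1-9, 0.01 for epochs 10-19, 0.001 for epochs 20-29, and 0.0001
// from epoch 30 onwards.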

/// Reduces the learning rate by `factor` once a monitored metric has stopped
/// improving for `patience` consecutive evaluations. Drive this scheduler
/// with `step_with_metric`; the plain `LrScheduler::step` is a no-op because
/// it has no metric to inspect.
#[derive(Debug, Clone)]
pub struct ReduceLROnPlateauScheduler {
    current_lr: f64,
    /// Multiplicative factor applied when the metric plateaus.
    pub factor: f64,
    /// Number of non-improving evaluations tolerated before reducing.
    pub patience: usize,
    /// Lower bound on the learning rate.
    pub min_lr: f64,
    /// Relative improvement required to count as an improvement.
    pub threshold: f64,
    /// Number of evaluations to skip after each reduction.
    pub cooldown: usize,
    best_metric: Option<f64>,
    num_bad_epochs: usize,
    cooldown_counter: usize,
    mode: PlateauMode,
}

/// Whether the monitored metric should be minimized or maximized.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PlateauMode {
    /// The metric is better when smaller (e.g. a loss).
    Min,
    /// The metric is better when larger (e.g. an accuracy).
    Max,
}

impl ReduceLROnPlateauScheduler {
    pub fn new(
        initial_lr: f64,
        mode: PlateauMode,
        factor: f64,
        patience: usize,
        threshold: f64,
        min_lr: f64,
        cooldown: usize,
    ) -> Self {
        Self {
            current_lr: initial_lr,
            factor,
            patience,
            min_lr,
            threshold,
            cooldown,
            best_metric: None,
            num_bad_epochs: 0,
            cooldown_counter: 0,
            mode,
        }
    }

    /// Updates the scheduler with the latest value of the monitored metric
    /// and reduces the learning rate if the metric has plateaued.
    pub fn step_with_metric(&mut self, optimizer: &mut dyn Optimizer, metric: f64) {
        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return;
        }

        let is_better = match self.best_metric {
            None => true,
            Some(best) => match self.mode {
                PlateauMode::Min => metric < best * (1.0 - self.threshold),
                PlateauMode::Max => metric > best * (1.0 + self.threshold),
            },
        };

        if is_better {
            self.best_metric = Some(metric);
            self.num_bad_epochs = 0;
        } else {
            self.num_bad_epochs += 1;

            if self.num_bad_epochs >= self.patience {
                let new_lr = (self.current_lr * self.factor).max(self.min_lr);

                if new_lr < self.current_lr {
                    self.current_lr = new_lr;
                    optimizer.set_lr(self.current_lr);
                    self.cooldown_counter = self.cooldown;
                    self.num_bad_epochs = 0;
                }
            }
        }
    }
}
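
// A minimal usage sketch for the plateau scheduler (illustration only; it
// assumes an `Optimizer` named `optimizer` and a `validate()` helper that
// returns the monitored validation loss):
//
//     let mut scheduler =
//         ReduceLROnPlateauScheduler::new(0.1, PlateauMode::Min, 0.5, 3, 0.01, 0.001, 2);
//     for _epoch in 0..100 {
//         // ... train for one epoch ...
//         let val_loss = validate();
//         scheduler.step_with_metric(&mut optimizer, val_loss);
//     }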

impl LrScheduler for ReduceLROnPlateauScheduler {
    fn step(&mut self, _optimizer: &mut dyn Optimizer) {
        // No-op: this scheduler needs a metric, so use `step_with_metric` instead.
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("factor".to_string(), self.factor);
        state.insert("patience".to_string(), self.patience as f64);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("threshold".to_string(), self.threshold);
        state.insert("cooldown".to_string(), self.cooldown as f64);
        state.insert(
            "best_metric".to_string(),
            self.best_metric.unwrap_or(f64::NAN),
        );
        state.insert("num_bad_epochs".to_string(), self.num_bad_epochs as f64);
        state.insert("cooldown_counter".to_string(), self.cooldown_counter as f64);
        state.insert(
            "mode".to_string(),
            match self.mode {
                PlateauMode::Min => 0.0,
                PlateauMode::Max => 1.0,
            },
        );
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&best_metric) = state.get("best_metric") {
            self.best_metric = if best_metric.is_nan() {
                None
            } else {
                Some(best_metric)
            };
        }
        if let Some(&num_bad_epochs) = state.get("num_bad_epochs") {
            self.num_bad_epochs = num_bad_epochs as usize;
        }
        if let Some(&cooldown_counter) = state.get("cooldown_counter") {
            self.cooldown_counter = cooldown_counter as usize;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OptimizerConfig, SgdOptimizer};

    #[test]
    fn test_step_lr_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = StepLrScheduler::new(0.1, 2, 0.5);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.05);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.05);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.025);
    }
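
    // Added sketch: a state_dict/load_state_dict round trip for the step
    // scheduler. It only asserts `is_ok()` so no assumption is made about the
    // concrete error type behind `TrainResult`.
    #[test]
    fn test_step_lr_state_dict_roundtrip() {
        let mut optimizer = SgdOptimizer::new(OptimizerConfig::default());
        let mut scheduler = StepLrScheduler::new(0.1, 2, 0.5);
        scheduler.step(&mut optimizer);
        scheduler.step(&mut optimizer);

        let state = scheduler.state_dict();
        let mut restored = StepLrScheduler::new(0.1, 2, 0.5);
        assert!(restored.load_state_dict(&state).is_ok());
        assert_eq!(restored.get_lr(), scheduler.get_lr());
    }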

    #[test]
    fn test_exponential_lr_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = ExponentialLrScheduler::new(0.1, 0.9);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = CosineAnnealingLrScheduler::new(0.1, 0.01, 10);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert!(scheduler.get_lr() < 0.1);
        assert!(scheduler.get_lr() > 0.01);

        for _ in 1..5 {
            scheduler.step(&mut optimizer);
        }
        let halfway_lr = scheduler.get_lr();
        assert!((halfway_lr - 0.055).abs() < 0.01);
    }

    #[test]
    fn test_warmup_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.0,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = WarmupScheduler::new(0.1, 10);

        assert_eq!(scheduler.get_lr(), 0.0);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);

        for _ in 1..10 {
            scheduler.step(&mut optimizer);
        }
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.1);
    }

    #[test]
    fn test_one_cycle_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = OneCycleLrScheduler::new(0.01, 0.1, 0.001, 100, 0.3);

        assert_eq!(scheduler.get_lr(), 0.01);

        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() > 0.01);
        assert!(scheduler.get_lr() <= 0.1);

        for _ in 30..100 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < 0.1);
    }

    #[test]
    fn test_polynomial_decay_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = PolynomialDecayLrScheduler::new(0.1, 0.001, 2.0, 100);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert!(scheduler.get_lr() < 0.1);

        for _ in 1..100 {
            scheduler.step(&mut optimizer);
        }
        assert!((scheduler.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_cyclic_lr_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = CyclicLrScheduler::new(0.01, 0.1, 10, CyclicLrMode::Triangular);

        assert_eq!(scheduler.get_lr(), 0.01);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() > 0.01);

        for _ in 10..20 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < 0.1);
    }

    #[test]
    fn test_warmup_cosine_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.0,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = WarmupCosineLrScheduler::new(0.1, 0.001, 10, 100);

        assert_eq!(scheduler.get_lr(), 0.0);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((scheduler.get_lr() - 0.1).abs() < 1e-6);

        for _ in 10..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < 0.1);
        assert!(scheduler.get_lr() > 0.001);

        for _ in 50..100 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < 0.1);
        assert!((scheduler.get_lr() - 0.001).abs() < 0.01);
    }

    #[test]
    fn test_noam_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.0,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = NoamScheduler::new(512, 4000, 1.0);

        let initial_lr = scheduler.get_lr();
        assert!(initial_lr > 0.0);

        scheduler.step(&mut optimizer);
        let step1_lr = scheduler.get_lr();

        assert!(step1_lr != initial_lr);

        for _ in 1..4000 {
            scheduler.step(&mut optimizer);
        }
        let peak_lr = scheduler.get_lr();

        for _ in 4000..8000 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < peak_lr);
    }

    #[test]
    fn test_multistep_lr_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = MultiStepLrScheduler::new(0.1, vec![10, 20, 30], 0.1);

        assert_eq!(scheduler.get_lr(), 0.1);

        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);

        for _ in 10..19 {
            scheduler.step(&mut optimizer);
        }
        assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.001).abs() < 1e-6);

        for _ in 20..29 {
            scheduler.step(&mut optimizer);
        }
        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.0001).abs() < 1e-6);
    }

    #[test]
    fn test_reduce_lr_on_plateau_min_mode() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler =
            ReduceLROnPlateauScheduler::new(0.1, PlateauMode::Min, 0.5, 3, 0.01, 0.001, 2);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 1.0);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert_eq!(scheduler.get_lr(), 0.05);

        scheduler.step_with_metric(&mut optimizer, 1.0);
        assert_eq!(scheduler.get_lr(), 0.05);

        scheduler.step_with_metric(&mut optimizer, 1.0);
        assert_eq!(scheduler.get_lr(), 0.05);
    }

    #[test]
    fn test_reduce_lr_on_plateau_max_mode() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler =
            ReduceLROnPlateauScheduler::new(0.1, PlateauMode::Max, 0.1, 2, 0.01, 0.001, 0);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.5);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.6);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.6);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.6);
        assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_sgdr_scheduler() {
        let mut scheduler = SgdrScheduler::new(0.1, 0.001, 10, 2.0);
        let mut optimizer = SgdOptimizer::new(OptimizerConfig::default());

        let initial_lr = scheduler.get_current_lr();
        assert!((initial_lr - 0.1).abs() < 1e-6);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }

        let mid_lr = scheduler.get_lr();
        assert!(mid_lr < initial_lr);

        for _ in 5..10 {
            scheduler.step(&mut optimizer);
        }

        // The next step completes the first period and triggers a warm restart.
        scheduler.step(&mut optimizer);
        let restart_lr = scheduler.get_lr();
        assert!(restart_lr > mid_lr);
        // After the restart the period has been multiplied by t_mult (10 * 2.0 = 20).
        assert_eq!(scheduler.current_period, 20);
    }
}

/// SGDR: cosine annealing with warm restarts. The learning rate follows a
/// cosine curve from `max_lr` to `min_lr` over a period of `t_0` steps, then
/// restarts at `max_lr` with the period multiplied by `t_mult`.
#[derive(Debug, Clone)]
pub struct SgdrScheduler {
    /// Learning rate at the start of each period.
    pub max_lr: f64,
    /// Learning rate at the end of each period.
    pub min_lr: f64,
    /// Length of the first period, in steps.
    pub t_0: usize,
    /// Factor by which the period grows after each restart.
    pub t_mult: f64,
    current_step: usize,
    current_period: usize,
    total_steps: usize,
}

impl SgdrScheduler {
    pub fn new(max_lr: f64, min_lr: f64, t_0: usize, t_mult: f64) -> Self {
        Self {
            max_lr,
            min_lr,
            t_0,
            t_mult,
            current_step: 0,
            current_period: t_0,
            total_steps: 0,
        }
    }

    fn get_current_lr(&self) -> f64 {
        let progress = self.current_step as f64 / self.current_period as f64;
        let cosine_factor = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
        self.min_lr + (self.max_lr - self.min_lr) * cosine_factor
    }
}

impl LrScheduler for SgdrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        let lr = self.get_current_lr();
        optimizer.set_lr(lr);

        self.current_step += 1;
        self.total_steps += 1;

        if self.current_step >= self.current_period {
            // Warm restart: reset within-period progress and grow the period.
            self.current_step = 0;
            self.current_period = (self.current_period as f64 * self.t_mult) as usize;
        }
    }

    fn get_lr(&self) -> f64 {
        self.get_current_lr()
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("t_0".to_string(), self.t_0 as f64);
        state.insert("t_mult".to_string(), self.t_mult);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("current_period".to_string(), self.current_period as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&max_lr) = state.get("max_lr") {
            self.max_lr = max_lr;
        }
        if let Some(&min_lr) = state.get("min_lr") {
            self.min_lr = min_lr;
        }
        if let Some(&t_0) = state.get("t_0") {
            self.t_0 = t_0 as usize;
        }
        if let Some(&t_mult) = state.get("t_mult") {
            self.t_mult = t_mult;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&current_period) = state.get("current_period") {
            self.current_period = current_period as usize;
        }
        if let Some(&total_steps) = state.get("total_steps") {
            self.total_steps = total_steps as usize;
        }
        Ok(())
    }
}
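
// Worked example for `SgdrScheduler` (mirrors `test_sgdr_scheduler` above):
// with `max_lr = 0.1`, `min_lr = 0.001`, `t_0 = 10`, `t_mult = 2.0`, the
// learning rate anneals from 0.1 towards 0.001 over the first 10 steps,
// restarts near 0.1, and then anneals again over a 20-step period, then a
// 40-step period, and so on.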