tensorlogic_train/scheduler.rs

//! Learning rate schedulers.

use crate::Optimizer;

/// Trait for learning rate schedulers.
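///
/// A minimal sketch of how a scheduler is usually driven once per epoch,
/// assuming the crate's `SgdOptimizer` and `OptimizerConfig` (used in the
/// tests at the bottom of this file); the training code itself is elided.
///
/// ```ignore
/// let mut optimizer = SgdOptimizer::new(OptimizerConfig::default());
/// let mut scheduler = StepLrScheduler::new(0.1, 30, 0.1);
/// for _epoch in 0..90 {
///     // ... train for one epoch ...
///     scheduler.step(&mut optimizer);
///     println!("lr = {}", scheduler.get_lr());
/// }
/// ```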
pub trait LrScheduler {
    /// Update learning rate based on current step/epoch.
    fn step(&mut self, optimizer: &mut dyn Optimizer);

    /// Get current learning rate.
    fn get_lr(&self) -> f64;

    /// Get scheduler state as a dictionary.
    fn state_dict(&self) -> std::collections::HashMap<String, f64>;

    /// Load scheduler state from a dictionary.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()>;
}

/// Step-based learning rate scheduler.
/// Decreases the learning rate by a factor of `gamma` every `step_size` epochs.
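///
/// After `n` calls to `step`, the rate produced by the update rule below is
/// equivalent to:
///
/// ```text
/// lr(n) = initial_lr * gamma^floor(n / step_size)
/// ```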
#[derive(Debug, Clone)]
pub struct StepLrScheduler {
    /// Initial learning rate.
    pub initial_lr: f64,
    /// Step size (epochs).
    pub step_size: usize,
    /// Multiplicative factor of learning rate decay.
    pub gamma: f64,
    /// Current epoch counter.
    current_epoch: usize,
    /// Current learning rate.
    current_lr: f64,
}

impl StepLrScheduler {
    /// Create a new step LR scheduler.
    pub fn new(initial_lr: f64, step_size: usize, gamma: f64) -> Self {
        Self {
            initial_lr,
            step_size,
            gamma,
            current_epoch: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for StepLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;

        if self.current_epoch.is_multiple_of(self.step_size) {
            self.current_lr *= self.gamma;
            optimizer.set_lr(self.current_lr);
        }
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("step_size".to_string(), self.step_size as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}

/// Exponential learning rate scheduler.
/// Decreases learning rate by a factor of gamma every epoch.
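///
/// The closed form implemented in `step` is:
///
/// ```text
/// lr(epoch) = initial_lr * gamma^epoch
/// ```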
#[derive(Debug, Clone)]
pub struct ExponentialLrScheduler {
    /// Initial learning rate.
    pub initial_lr: f64,
    /// Multiplicative factor of learning rate decay.
    pub gamma: f64,
    /// Current epoch counter.
    current_epoch: usize,
    /// Current learning rate.
    current_lr: f64,
}

impl ExponentialLrScheduler {
    /// Create a new exponential LR scheduler.
    pub fn new(initial_lr: f64, gamma: f64) -> Self {
        Self {
            initial_lr,
            gamma,
            current_epoch: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for ExponentialLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;
        self.current_lr = self.initial_lr * self.gamma.powi(self.current_epoch as i32);
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}

/// Cosine annealing learning rate scheduler.
/// Anneals learning rate using a cosine schedule.
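///
/// The schedule implemented in `step` follows the usual cosine form:
///
/// ```text
/// lr(t) = min_lr + (initial_lr - min_lr) * 0.5 * (1 + cos(pi * t / t_max))
/// ```
///
/// Note that `t` keeps increasing past `t_max`, so stepping beyond `t_max`
/// epochs walks back up the cosine curve.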
#[derive(Debug, Clone)]
pub struct CosineAnnealingLrScheduler {
    /// Initial learning rate.
    pub initial_lr: f64,
    /// Minimum learning rate.
    pub min_lr: f64,
    /// Total number of epochs.
    pub t_max: usize,
    /// Current epoch counter.
    current_epoch: usize,
    /// Current learning rate.
    current_lr: f64,
}

impl CosineAnnealingLrScheduler {
    /// Create a new cosine annealing LR scheduler.
    pub fn new(initial_lr: f64, min_lr: f64, t_max: usize) -> Self {
        Self {
            initial_lr,
            min_lr,
            t_max,
            current_epoch: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for CosineAnnealingLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;

        let progress = (self.current_epoch as f64) / (self.t_max as f64);
        let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
        self.current_lr = self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay;

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("t_max".to_string(), self.t_max as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}

/// Warmup scheduler that linearly increases the learning rate from zero to
/// `target_lr` over `warmup_steps` steps.
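///
/// Equivalently (with `step` counted from 1):
///
/// ```text
/// lr(step) = target_lr * min(step / warmup_steps, 1)
/// ```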
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct WarmupScheduler {
    /// Target learning rate after warmup.
    pub target_lr: f64,
    /// Number of warmup steps.
    pub warmup_steps: usize,
    /// Current step counter.
    current_step: usize,
    /// Current learning rate.
    current_lr: f64,
}

impl WarmupScheduler {
    /// Create a new warmup scheduler.
    #[allow(dead_code)]
    pub fn new(target_lr: f64, warmup_steps: usize) -> Self {
        Self {
            target_lr,
            warmup_steps,
            current_step: 0,
            current_lr: 0.0,
        }
    }
}

impl LrScheduler for WarmupScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        if self.current_step < self.warmup_steps {
            self.current_lr =
                self.target_lr * (self.current_step as f64) / (self.warmup_steps as f64);
        } else {
            self.current_lr = self.target_lr;
        }

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("target_lr".to_string(), self.target_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}

/// One-cycle learning rate scheduler.
/// Increases LR from initial to max, then decreases to min.
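///
/// With `pct = min(step, total_steps) / total_steps`, the two phases
/// implemented in `step` are:
///
/// ```text
/// pct <  pct_start: lr = initial_lr + (max_lr - initial_lr) * pct / pct_start
/// pct >= pct_start: lr = min_lr + (max_lr - min_lr)
///                        * 0.5 * (1 + cos(pi * (pct - pct_start) / (1 - pct_start)))
/// ```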
#[derive(Debug, Clone)]
pub struct OneCycleLrScheduler {
    /// Initial learning rate.
    pub initial_lr: f64,
    /// Maximum learning rate.
    pub max_lr: f64,
    /// Minimum learning rate (final).
    pub min_lr: f64,
    /// Total number of steps.
    pub total_steps: usize,
    /// Percentage of cycle spent increasing LR.
    pub pct_start: f64,
    /// Current step counter.
    current_step: usize,
    /// Current learning rate.
    current_lr: f64,
}

impl OneCycleLrScheduler {
    /// Create a new one-cycle LR scheduler.
    pub fn new(
        initial_lr: f64,
        max_lr: f64,
        min_lr: f64,
        total_steps: usize,
        pct_start: f64,
    ) -> Self {
        Self {
            initial_lr,
            max_lr,
            min_lr,
            total_steps,
            pct_start,
            current_step: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for OneCycleLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        let step_num = self.current_step.min(self.total_steps);
        let pct = step_num as f64 / self.total_steps as f64;

        if pct < self.pct_start {
            // Increasing phase
            let phase_pct = pct / self.pct_start;
            self.current_lr = self.initial_lr + (self.max_lr - self.initial_lr) * phase_pct;
        } else {
            // Decreasing phase
            let phase_pct = (pct - self.pct_start) / (1.0 - self.pct_start);
            // Cosine annealing for smooth decay
            let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * phase_pct).cos());
            self.current_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay;
        }

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state.insert("pct_start".to_string(), self.pct_start);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}

/// Polynomial decay learning rate scheduler.
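///
/// With `step` clamped to `decay_steps`, the decay implemented in `step` is:
///
/// ```text
/// lr(step) = (initial_lr - final_lr) * (1 - step / decay_steps)^power + final_lr
/// ```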
#[derive(Debug, Clone)]
pub struct PolynomialDecayLrScheduler {
    /// Initial learning rate.
    pub initial_lr: f64,
    /// Final learning rate.
    pub final_lr: f64,
    /// Power of the polynomial.
    pub power: f64,
    /// Total number of decay steps.
    pub decay_steps: usize,
    /// Current step counter.
    current_step: usize,
    /// Current learning rate.
    current_lr: f64,
}

impl PolynomialDecayLrScheduler {
    /// Create a new polynomial decay LR scheduler.
    pub fn new(initial_lr: f64, final_lr: f64, power: f64, decay_steps: usize) -> Self {
        Self {
            initial_lr,
            final_lr,
            power,
            decay_steps,
            current_step: 0,
            current_lr: initial_lr,
        }
    }
}

impl LrScheduler for PolynomialDecayLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        let step_num = self.current_step.min(self.decay_steps);
        let decay_factor = (1.0 - (step_num as f64 / self.decay_steps as f64)).powf(self.power);
        self.current_lr = (self.initial_lr - self.final_lr) * decay_factor + self.final_lr;

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("final_lr".to_string(), self.final_lr);
        state.insert("power".to_string(), self.power);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("decay_steps".to_string(), self.decay_steps as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}

/// Cyclic learning rate mode.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CyclicLrMode {
    /// Triangular (linear increase and decrease).
    Triangular,
    /// Triangular2 (amplitude decreases by half each cycle).
    Triangular2,
    /// Exponential range (amplitude decreases exponentially).
    ExpRange,
}

/// Cyclic learning rate scheduler.
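///
/// The rate moves between `base_lr` and `max_lr` in a triangular wave with a
/// full cycle of `2 * step_size` steps; roughly, as computed in `step`:
///
/// ```text
/// x     = ((step - 1) / step_size) mod 2
/// scale = 1 (Triangular) | 1 / 2^cycle (Triangular2) | gamma^step (ExpRange)
/// lr    = base_lr + (max_lr - base_lr) * min(x, 2 - x) * scale
/// ```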
#[derive(Debug, Clone)]
pub struct CyclicLrScheduler {
    /// Base learning rate.
    pub base_lr: f64,
    /// Maximum learning rate.
    pub max_lr: f64,
    /// Step size (half of cycle length).
    pub step_size: usize,
    /// Cyclic mode.
    pub mode: CyclicLrMode,
    /// Gamma for exponential range mode.
    pub gamma: f64,
    /// Current step counter.
    current_step: usize,
    /// Current learning rate.
    current_lr: f64,
    /// Current cycle number.
    cycle: usize,
}

impl CyclicLrScheduler {
    /// Create a new cyclic LR scheduler.
    pub fn new(base_lr: f64, max_lr: f64, step_size: usize, mode: CyclicLrMode) -> Self {
        Self {
            base_lr,
            max_lr,
            step_size,
            mode,
            gamma: 0.99994,
            current_step: 0,
            current_lr: base_lr,
            cycle: 0,
        }
    }

    /// Create a new cyclic LR scheduler with exponential range mode.
    pub fn new_exp_range(base_lr: f64, max_lr: f64, step_size: usize, gamma: f64) -> Self {
        Self {
            base_lr,
            max_lr,
            step_size,
            mode: CyclicLrMode::ExpRange,
            gamma,
            current_step: 0,
            current_lr: base_lr,
            cycle: 0,
        }
    }
}

impl LrScheduler for CyclicLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        // Determine cycle position
        let cycle = (self.current_step - 1) / (2 * self.step_size);
        let x = ((self.current_step - 1) as f64 / self.step_size as f64).abs() % 2.0;

        // Calculate scaling factor based on mode
        let scale_fn = match self.mode {
            CyclicLrMode::Triangular => 1.0,
            CyclicLrMode::Triangular2 => 1.0 / 2.0_f64.powi(cycle as i32),
            CyclicLrMode::ExpRange => self.gamma.powi(self.current_step as i32),
        };

        // Calculate current LR
        if x <= 1.0 {
            // Increasing phase
            self.current_lr = self.base_lr + (self.max_lr - self.base_lr) * x * scale_fn;
        } else {
            // Decreasing phase
            self.current_lr = self.base_lr + (self.max_lr - self.base_lr) * (2.0 - x) * scale_fn;
        }

        self.cycle = cycle;
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("base_lr".to_string(), self.base_lr);
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("step_size".to_string(), self.step_size as f64);
        state.insert("cycle".to_string(), self.cycle as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&cycle) = state.get("cycle") {
            self.cycle = cycle as usize;
        }
        Ok(())
    }
}

/// Linear warmup followed by cosine annealing scheduler.
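///
/// Two phases, switching at `warmup_steps`:
///
/// ```text
/// step <= warmup_steps: lr = target_lr * step / warmup_steps
/// step >  warmup_steps: lr = min_lr + (target_lr - min_lr) * 0.5
///                            * (1 + cos(pi * (step - warmup_steps) / (total_steps - warmup_steps)))
/// ```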
#[derive(Debug, Clone)]
pub struct WarmupCosineLrScheduler {
    /// Target learning rate after warmup.
    pub target_lr: f64,
    /// Minimum learning rate.
    pub min_lr: f64,
    /// Number of warmup steps.
    pub warmup_steps: usize,
    /// Total number of steps (including warmup).
    pub total_steps: usize,
    /// Current step counter.
    current_step: usize,
    /// Current learning rate.
    current_lr: f64,
}

impl WarmupCosineLrScheduler {
    /// Create a new warmup cosine LR scheduler.
    pub fn new(target_lr: f64, min_lr: f64, warmup_steps: usize, total_steps: usize) -> Self {
        Self {
            target_lr,
            min_lr,
            warmup_steps,
            total_steps,
            current_step: 0,
            current_lr: 0.0,
        }
    }
}

impl LrScheduler for WarmupCosineLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;

        if self.current_step <= self.warmup_steps {
            // Warmup phase: linear increase
            self.current_lr =
                self.target_lr * (self.current_step as f64 / self.warmup_steps as f64);
        } else {
            // Cosine annealing phase
            let progress = (self.current_step - self.warmup_steps) as f64
                / (self.total_steps - self.warmup_steps) as f64;
            let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
            self.current_lr = self.min_lr + (self.target_lr - self.min_lr) * cosine_decay;
        }

        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("target_lr".to_string(), self.target_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}

/// Noam scheduler (Transformer learning rate schedule).
///
/// This is the learning rate schedule used in "Attention is All You Need".
/// It increases linearly for warmup_steps, then decays proportionally to the
/// inverse square root of the step number.
///
/// Reference: Vaswani et al. "Attention is All You Need" (NIPS 2017)
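///
/// The schedule (also spelled out in `compute_lr`) is:
///
/// ```text
/// lr(step) = scale_factor * d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
/// // e.g. d_model = 512, warmup_steps = 4000  =>  peak lr ≈ 7.0e-4 at step 4000
/// ```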
#[derive(Debug, Clone)]
pub struct NoamScheduler {
    /// Model dimension (d_model) from the paper.
    model_dim: f64,
    /// Number of warmup steps.
    warmup_steps: usize,
    /// Scaling factor (typically 1.0).
    scale_factor: f64,
    /// Current step counter.
    current_step: usize,
    /// Current learning rate.
    current_lr: f64,
}

impl NoamScheduler {
    /// Create a new Noam scheduler.
    ///
    /// # Arguments
    /// * `model_dim` - Model dimension (d_model), typically 512 for Transformer
    /// * `warmup_steps` - Number of warmup steps, typically 4000
    /// * `scale_factor` - Scaling factor, typically 1.0 or 2.0
    pub fn new(model_dim: usize, warmup_steps: usize, scale_factor: f64) -> Self {
        let model_dim_f64 = model_dim as f64;
        let current_lr = scale_factor * model_dim_f64.powf(-0.5);

        Self {
            model_dim: model_dim_f64,
            warmup_steps,
            scale_factor,
            current_step: 0,
            current_lr,
        }
    }

    /// Compute learning rate for the current step.
    fn compute_lr(&self) -> f64 {
        let step = (self.current_step + 1) as f64; // +1 to avoid division by zero
        let warmup = self.warmup_steps as f64;

        // lr = scale * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))
        self.scale_factor
            * self.model_dim.powf(-0.5)
            * step.powf(-0.5).min(step * warmup.powf(-1.5))
    }
}

impl LrScheduler for NoamScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;
        self.current_lr = self.compute_lr();
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("model_dim".to_string(), self.model_dim);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state.insert("scale_factor".to_string(), self.scale_factor);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("current_lr".to_string(), self.current_lr);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        Ok(())
    }
}

/// Multi-step learning rate scheduler.
///
/// Decays the learning rate by gamma at specified milestones (epochs).
/// This is useful when you know specific points where you want to reduce LR.
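///
/// A hypothetical configuration (milestone epochs chosen purely for
/// illustration):
///
/// ```ignore
/// // Decay the LR by 10x at epochs 30, 60, and 80.
/// let mut scheduler = MultiStepLrScheduler::new(0.1, vec![30, 60, 80], 0.1);
/// ```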
#[derive(Debug, Clone)]
pub struct MultiStepLrScheduler {
    /// Initial learning rate.
    pub initial_lr: f64,
    /// Milestones (epochs) at which to decay LR.
    pub milestones: Vec<usize>,
    /// Multiplicative factor of learning rate decay.
    pub gamma: f64,
    /// Current epoch counter.
    current_epoch: usize,
    /// Current learning rate.
    current_lr: f64,
    /// Index of next milestone to trigger.
    next_milestone_idx: usize,
}

impl MultiStepLrScheduler {
    /// Create a new multi-step LR scheduler.
    ///
    /// # Arguments
    /// * `initial_lr` - Initial learning rate
    /// * `milestones` - Epochs at which to decay (should be sorted)
    /// * `gamma` - Multiplicative decay factor
    pub fn new(initial_lr: f64, mut milestones: Vec<usize>, gamma: f64) -> Self {
        // Ensure milestones are sorted
        milestones.sort_unstable();

        Self {
            initial_lr,
            milestones,
            gamma,
            current_epoch: 0,
            current_lr: initial_lr,
            next_milestone_idx: 0,
        }
    }
}

impl LrScheduler for MultiStepLrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;

        // Check if we've reached a milestone
        if self.next_milestone_idx < self.milestones.len()
            && self.current_epoch >= self.milestones[self.next_milestone_idx]
        {
            self.current_lr *= self.gamma;
            self.next_milestone_idx += 1;
            optimizer.set_lr(self.current_lr);
        }
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("gamma".to_string(), self.gamma);
        state.insert(
            "next_milestone_idx".to_string(),
            self.next_milestone_idx as f64,
        );
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        if let Some(&next_milestone_idx) = state.get("next_milestone_idx") {
            self.next_milestone_idx = next_milestone_idx as usize;
        }
        Ok(())
    }
}

/// Reduce learning rate on plateau (metric-based adaptive scheduler).
///
/// Reduces learning rate when a metric (e.g., validation loss) has stopped improving.
/// This scheduler requires explicit metric updates via `step_with_metric()`.
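///
/// Unlike the other schedulers, this one is driven with the metric itself.
/// A minimal sketch (the constructor arguments and the `val_loss` source are
/// illustrative, not prescriptive):
///
/// ```ignore
/// let mut scheduler = ReduceLROnPlateauScheduler::new(
///     0.1,              // initial_lr
///     PlateauMode::Min, // lower metric is better
///     0.5,              // factor
///     3,                // patience
///     0.01,             // threshold
///     1e-4,             // min_lr
///     2,                // cooldown
/// );
/// // At the end of each epoch:
/// let val_loss = evaluate(&model); // however validation is done
/// scheduler.step_with_metric(&mut optimizer, val_loss);
/// ```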
#[derive(Debug, Clone)]
pub struct ReduceLROnPlateauScheduler {
    /// Current learning rate.
    current_lr: f64,
    /// Decay factor.
    pub factor: f64,
    /// Number of epochs with no improvement after which LR will be reduced.
    pub patience: usize,
    /// Minimum LR.
    pub min_lr: f64,
    /// Threshold for measuring improvement (relative).
    pub threshold: f64,
    /// Number of epochs to wait before resuming normal operation after LR reduction.
    pub cooldown: usize,
    /// Best metric value seen so far.
    best_metric: Option<f64>,
    /// Number of epochs with no improvement.
    num_bad_epochs: usize,
    /// Epochs remaining in cooldown period.
    cooldown_counter: usize,
    /// Mode: "min" (lower is better) or "max" (higher is better).
    mode: PlateauMode,
}

/// Mode for ReduceLROnPlateau scheduler.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PlateauMode {
    /// Lower metric values are better (e.g., loss).
    Min,
    /// Higher metric values are better (e.g., accuracy).
    Max,
}

impl ReduceLROnPlateauScheduler {
    /// Create a new ReduceLROnPlateau scheduler.
    ///
    /// # Arguments
    /// * `initial_lr` - Initial learning rate
    /// * `mode` - Whether to minimize or maximize the metric
    /// * `factor` - Factor by which to reduce LR (new_lr = lr * factor)
    /// * `patience` - Number of epochs with no improvement to wait
    /// * `threshold` - Threshold for measuring improvement (relative)
    /// * `min_lr` - Minimum LR (won't reduce below this)
    /// * `cooldown` - Cooldown epochs after LR reduction
    pub fn new(
        initial_lr: f64,
        mode: PlateauMode,
        factor: f64,
        patience: usize,
        threshold: f64,
        min_lr: f64,
        cooldown: usize,
    ) -> Self {
        Self {
            current_lr: initial_lr,
            factor,
            patience,
            min_lr,
            threshold,
            cooldown,
            best_metric: None,
            num_bad_epochs: 0,
            cooldown_counter: 0,
            mode,
        }
    }

    /// Step with a metric value.
    ///
    /// This should be called with the validation metric at the end of each epoch.
    pub fn step_with_metric(&mut self, optimizer: &mut dyn Optimizer, metric: f64) {
        // Check if in cooldown period
        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return;
        }

        // Check if metric has improved
        let is_better = match self.best_metric {
            None => true, // First metric always sets the baseline
            Some(best) => match self.mode {
                PlateauMode::Min => metric < best * (1.0 - self.threshold),
                PlateauMode::Max => metric > best * (1.0 + self.threshold),
            },
        };

        if is_better {
            // Metric improved
            self.best_metric = Some(metric);
            self.num_bad_epochs = 0;
        } else {
            // Metric didn't improve
            self.num_bad_epochs += 1;

            if self.num_bad_epochs >= self.patience {
                // Reduce learning rate
                let new_lr = (self.current_lr * self.factor).max(self.min_lr);

                if new_lr < self.current_lr {
                    self.current_lr = new_lr;
                    optimizer.set_lr(self.current_lr);
                    self.cooldown_counter = self.cooldown;
                    self.num_bad_epochs = 0;
                }
            }
        }
    }
}

impl LrScheduler for ReduceLROnPlateauScheduler {
    fn step(&mut self, _optimizer: &mut dyn Optimizer) {
        // This scheduler needs metrics, so the default step() does nothing.
        // Users should call step_with_metric() instead.
    }

    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("factor".to_string(), self.factor);
        state.insert("patience".to_string(), self.patience as f64);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("threshold".to_string(), self.threshold);
        state.insert("cooldown".to_string(), self.cooldown as f64);
        state.insert(
            "best_metric".to_string(),
            self.best_metric.unwrap_or(f64::NAN),
        );
        state.insert("num_bad_epochs".to_string(), self.num_bad_epochs as f64);
        state.insert("cooldown_counter".to_string(), self.cooldown_counter as f64);
        state.insert(
            "mode".to_string(),
            match self.mode {
                PlateauMode::Min => 0.0,
                PlateauMode::Max => 1.0,
            },
        );
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&best_metric) = state.get("best_metric") {
            self.best_metric = if best_metric.is_nan() {
                None
            } else {
                Some(best_metric)
            };
        }
        if let Some(&num_bad_epochs) = state.get("num_bad_epochs") {
            self.num_bad_epochs = num_bad_epochs as usize;
        }
        if let Some(&cooldown_counter) = state.get("cooldown_counter") {
            self.cooldown_counter = cooldown_counter as usize;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OptimizerConfig, SgdOptimizer};

    #[test]
    fn test_step_lr_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = StepLrScheduler::new(0.1, 2, 0.5);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.05);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.05);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.025);
    }

    #[test]
    fn test_exponential_lr_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = ExponentialLrScheduler::new(0.1, 0.9);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = CosineAnnealingLrScheduler::new(0.1, 0.01, 10);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert!(scheduler.get_lr() < 0.1);
        assert!(scheduler.get_lr() > 0.01);

        // Step to halfway point
        for _ in 1..5 {
            scheduler.step(&mut optimizer);
        }
        let halfway_lr = scheduler.get_lr();
        assert!((halfway_lr - 0.055).abs() < 0.01); // Should be approximately at midpoint
    }

    #[test]
    fn test_warmup_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.0,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = WarmupScheduler::new(0.1, 10);

        assert_eq!(scheduler.get_lr(), 0.0);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);

        for _ in 1..10 {
            scheduler.step(&mut optimizer);
        }
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.get_lr(), 0.1); // Stays at target after warmup
    }

    #[test]
    fn test_one_cycle_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = OneCycleLrScheduler::new(0.01, 0.1, 0.001, 100, 0.3);

        assert_eq!(scheduler.get_lr(), 0.01);

        // Test increasing phase
        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() > 0.01);
        assert!(scheduler.get_lr() <= 0.1);

        // Test decreasing phase
        for _ in 30..100 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < 0.1);
    }

    #[test]
    fn test_polynomial_decay_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = PolynomialDecayLrScheduler::new(0.1, 0.001, 2.0, 100);

        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step(&mut optimizer);
        assert!(scheduler.get_lr() < 0.1);

        for _ in 1..100 {
            scheduler.step(&mut optimizer);
        }
        assert!((scheduler.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_cyclic_lr_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = CyclicLrScheduler::new(0.01, 0.1, 10, CyclicLrMode::Triangular);

        assert_eq!(scheduler.get_lr(), 0.01);

        // Test first cycle
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() > 0.01);

        for _ in 10..20 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < 0.1);
    }

    #[test]
    fn test_warmup_cosine_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.0,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = WarmupCosineLrScheduler::new(0.1, 0.001, 10, 100);

        assert_eq!(scheduler.get_lr(), 0.0);

        // Test warmup phase
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((scheduler.get_lr() - 0.1).abs() < 1e-6);

        // Test middle of cosine annealing phase
        for _ in 10..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < 0.1);
        assert!(scheduler.get_lr() > 0.001);

        // Test near end of cosine annealing phase
        for _ in 50..100 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < 0.1);
        // At the end, LR should be close to min_lr
        assert!((scheduler.get_lr() - 0.001).abs() < 0.01);
    }

    #[test]
    fn test_noam_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.0,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = NoamScheduler::new(512, 4000, 1.0);

        let initial_lr = scheduler.get_lr();
        assert!(initial_lr > 0.0);

        // Step once
        scheduler.step(&mut optimizer);
        let step1_lr = scheduler.get_lr();

        // After step, LR should change
        assert!(step1_lr != initial_lr);

        // Step up to the warmup peak, then verify the decay after it
        for _ in 1..4000 {
            scheduler.step(&mut optimizer);
        }
        let peak_lr = scheduler.get_lr();

        // After warmup, LR should decrease
        for _ in 4000..8000 {
            scheduler.step(&mut optimizer);
        }
        assert!(scheduler.get_lr() < peak_lr);
    }

    #[test]
    fn test_multistep_lr_scheduler() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = MultiStepLrScheduler::new(0.1, vec![10, 20, 30], 0.1);

        assert_eq!(scheduler.get_lr(), 0.1);

        // Before first milestone
        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert_eq!(scheduler.get_lr(), 0.1);

        // At first milestone (epoch 10)
        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);

        // Between first and second milestone
        for _ in 10..19 {
            scheduler.step(&mut optimizer);
        }
        assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);

        // At second milestone (epoch 20)
        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.001).abs() < 1e-6);

        // At third milestone (epoch 30)
        for _ in 20..29 {
            scheduler.step(&mut optimizer);
        }
        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.0001).abs() < 1e-6);
    }

    #[test]
    fn test_reduce_lr_on_plateau_min_mode() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = ReduceLROnPlateauScheduler::new(
            0.1,              // initial_lr
            PlateauMode::Min, // mode
            0.5,              // factor
            3,                // patience
            0.01,             // threshold
            0.001,            // min_lr
            2,                // cooldown
        );

        assert_eq!(scheduler.get_lr(), 0.1);

        // Metric improving - LR should not change
        scheduler.step_with_metric(&mut optimizer, 1.0);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert_eq!(scheduler.get_lr(), 0.1);

        // Metric plateaus for patience epochs
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.9);
        // After patience epochs, LR should be reduced
        assert_eq!(scheduler.get_lr(), 0.05);

        // Test cooldown - LR shouldn't change during cooldown
        scheduler.step_with_metric(&mut optimizer, 1.0);
        assert_eq!(scheduler.get_lr(), 0.05);

        scheduler.step_with_metric(&mut optimizer, 1.0);
        assert_eq!(scheduler.get_lr(), 0.05);
    }

    #[test]
    fn test_reduce_lr_on_plateau_max_mode() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);
        let mut scheduler = ReduceLROnPlateauScheduler::new(
            0.1,
            PlateauMode::Max, // Maximize metric (e.g., accuracy)
            0.1,
            2,
            0.01,
            0.001,
            0,
        );

        assert_eq!(scheduler.get_lr(), 0.1);

        // Metric improving (increasing) - LR should not change
        scheduler.step_with_metric(&mut optimizer, 0.5);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.6);
        assert_eq!(scheduler.get_lr(), 0.1);

        // Metric plateaus
        scheduler.step_with_metric(&mut optimizer, 0.6);
        assert_eq!(scheduler.get_lr(), 0.1);

        scheduler.step_with_metric(&mut optimizer, 0.6);
        // After patience epochs, LR should be reduced
        assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_sgdr_scheduler() {
        let mut scheduler = SgdrScheduler::new(0.1, 0.001, 10, 2.0);
        let mut optimizer = SgdOptimizer::new(OptimizerConfig::default());

        // At start, should be at max_lr
        let initial_lr = scheduler.get_current_lr();
        assert!((initial_lr - 0.1).abs() < 1e-6);

        // Step through first period
        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }

        // Should be decreasing
        let mid_lr = scheduler.get_lr();
        assert!(mid_lr < initial_lr);

        // Complete first period and check restart
        for _ in 5..10 {
            scheduler.step(&mut optimizer);
        }

        // After restart, should be back to max_lr
        scheduler.step(&mut optimizer);
        let restart_lr = scheduler.get_lr();
        assert!(restart_lr > mid_lr); // Should have restarted

        // Period should have doubled
        assert_eq!(scheduler.current_period, 20);
    }
}

/// SGDR: Stochastic Gradient Descent with Warm Restarts scheduler.
///
/// Based on "SGDR: Stochastic Gradient Descent with Warm Restarts" (Loshchilov & Hutter, 2017).
/// Periodically resets the learning rate to a high value and then decays it using cosine annealing.
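///
/// Within each period the rate follows cosine annealing from `max_lr` down
/// toward `min_lr`; at the end of a period it restarts at `max_lr` and the
/// period length is multiplied by `t_mult`:
///
/// ```text
/// lr(t)   = min_lr + (max_lr - min_lr) * 0.5 * (1 + cos(pi * t / T_i))
/// T_(i+1) = T_i * t_mult    (T_0 = t_0; t resets to 0 at each restart)
/// ```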
#[derive(Debug, Clone)]
pub struct SgdrScheduler {
    /// Initial learning rate (after restart).
    pub max_lr: f64,
    /// Minimum learning rate.
    pub min_lr: f64,
    /// Initial period length.
    pub t_0: usize,
    /// Period multiplication factor after each restart.
    pub t_mult: f64,
    /// Current step within the period.
    current_step: usize,
    /// Current period length.
    current_period: usize,
    /// Total steps taken.
    total_steps: usize,
}

impl SgdrScheduler {
    /// Create a new SGDR scheduler.
    ///
    /// # Arguments
    /// * `max_lr` - Maximum learning rate (used at restart)
    /// * `min_lr` - Minimum learning rate
    /// * `t_0` - Initial period length
    /// * `t_mult` - Period multiplication factor (typically 1.0 or 2.0)
    pub fn new(max_lr: f64, min_lr: f64, t_0: usize, t_mult: f64) -> Self {
        Self {
            max_lr,
            min_lr,
            t_0,
            t_mult,
            current_step: 0,
            current_period: t_0,
            total_steps: 0,
        }
    }

    /// Get current learning rate based on cosine annealing within the period.
    fn get_current_lr(&self) -> f64 {
        let progress = self.current_step as f64 / self.current_period as f64;
        let cosine_factor = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
        self.min_lr + (self.max_lr - self.min_lr) * cosine_factor
    }
}

impl LrScheduler for SgdrScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        let lr = self.get_current_lr();
        optimizer.set_lr(lr);

        self.current_step += 1;
        self.total_steps += 1;

        // Check if we need to restart
        if self.current_step >= self.current_period {
            self.current_step = 0;
            self.current_period = (self.current_period as f64 * self.t_mult) as usize;
            // After restart, LR resets to max_lr
        }
    }

    fn get_lr(&self) -> f64 {
        self.get_current_lr()
    }

    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("t_0".to_string(), self.t_0 as f64);
        state.insert("t_mult".to_string(), self.t_mult);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("current_period".to_string(), self.current_period as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state
    }

    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&max_lr) = state.get("max_lr") {
            self.max_lr = max_lr;
        }
        if let Some(&min_lr) = state.get("min_lr") {
            self.min_lr = min_lr;
        }
        if let Some(&t_0) = state.get("t_0") {
            self.t_0 = t_0 as usize;
        }
        if let Some(&t_mult) = state.get("t_mult") {
            self.t_mult = t_mult;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&current_period) = state.get("current_period") {
            self.current_period = current_period as usize;
        }
        if let Some(&total_steps) = state.get("total_steps") {
            self.total_steps = total_steps as usize;
        }
        Ok(())
    }
}