Skip to main content

tensorlogic_train/
lr_scheduler.rs

1//! Learning rate schedulers for TensorLogic training.
2//!
3//! Provides classic and adaptive scheduling strategies:
4//! - Step decay
5//! - Cosine annealing (with optional warm restarts)
6//! - Linear warmup
7//! - Cyclical learning rates
8//! - One-cycle policy
9
10use thiserror::Error;
11
12/// Error types for scheduler operations.
13#[derive(Debug, Error)]
14pub enum SchedulerError {
15    #[error("Invalid config: {0}")]
16    InvalidConfig(String),
17    #[error("Scheduler exhausted after {0} steps")]
18    Exhausted(usize),
19}
20
/// Common interface implemented by every learning rate scheduler in this module.
///
/// The `Send` supertrait lets a boxed scheduler be moved to another thread.
pub trait LrSchedulerV2: Send {
    /// Moves the schedule forward by one step and yields the resulting
    /// learning rate.
    fn step(&mut self) -> f64;

    /// Reports the learning rate at the current position; never advances.
    fn current_lr(&self) -> f64;

    /// Puts the scheduler back into its initial state.
    fn reset(&mut self);

    /// Total number of steps taken.
    fn steps_taken(&self) -> usize;

    /// Reports whether a full cycle has been completed.
    ///
    /// Schedulers that do not override this method inherit the default,
    /// which always answers `false`.
    fn completed_cycle(&self) -> bool {
        false
    }
}
36
37// ------- StepDecayScheduler -------
38
39/// Multiplies the learning rate by `gamma` every `step_size` steps.
40///
41/// lr_t = base_lr * gamma^floor(t / step_size)
42pub struct StepDecayScheduler {
43    base_lr: f64,
44    gamma: f64,
45    step_size: usize,
46    current_step: usize,
47}
48
49impl StepDecayScheduler {
50    /// Create a new step decay scheduler.
51    ///
52    /// # Errors
53    /// Returns [`SchedulerError::InvalidConfig`] if any parameter is invalid.
54    pub fn new(base_lr: f64, gamma: f64, step_size: usize) -> Result<Self, SchedulerError> {
55        if base_lr <= 0.0 {
56            return Err(SchedulerError::InvalidConfig(
57                "base_lr must be positive".into(),
58            ));
59        }
60        if !(0.0..=1.0).contains(&gamma) {
61            return Err(SchedulerError::InvalidConfig(
62                "gamma must be in (0, 1]".into(),
63            ));
64        }
65        if step_size == 0 {
66            return Err(SchedulerError::InvalidConfig(
67                "step_size must be > 0".into(),
68            ));
69        }
70        Ok(StepDecayScheduler {
71            base_lr,
72            gamma,
73            step_size,
74            current_step: 0,
75        })
76    }
77}
78
79impl LrSchedulerV2 for StepDecayScheduler {
80    fn step(&mut self) -> f64 {
81        self.current_step += 1;
82        self.current_lr()
83    }
84
85    fn current_lr(&self) -> f64 {
86        let exponent = self.current_step / self.step_size;
87        self.base_lr * self.gamma.powi(exponent as i32)
88    }
89
90    fn reset(&mut self) {
91        self.current_step = 0;
92    }
93
94    fn steps_taken(&self) -> usize {
95        self.current_step
96    }
97}
98
99// ------- CosineAnnealingScheduler -------
100
101/// Cosine annealing with optional warm restarts (SGDR).
102///
103/// lr_t = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * t_cur / t_max))
104///
105/// If `restart_period` is Some(T), restarts every T steps (warm restarts).
106pub struct CosineAnnealingScheduler {
107    max_lr: f64,
108    min_lr: f64,
109    t_max: usize,
110    restart_period: Option<usize>,
111    current_step: usize,
112    cycle_count: usize,
113}
114
115impl CosineAnnealingScheduler {
116    /// Create a new cosine annealing scheduler.
117    ///
118    /// # Errors
119    /// Returns [`SchedulerError::InvalidConfig`] if any parameter is invalid.
120    pub fn new(max_lr: f64, min_lr: f64, t_max: usize) -> Result<Self, SchedulerError> {
121        if max_lr < min_lr {
122            return Err(SchedulerError::InvalidConfig(
123                "max_lr must be >= min_lr".into(),
124            ));
125        }
126        if t_max == 0 {
127            return Err(SchedulerError::InvalidConfig("t_max must be > 0".into()));
128        }
129        Ok(CosineAnnealingScheduler {
130            max_lr,
131            min_lr,
132            t_max,
133            restart_period: None,
134            current_step: 0,
135            cycle_count: 0,
136        })
137    }
138
139    /// Enable warm restarts every `period` steps.
140    pub fn with_warm_restarts(mut self, period: usize) -> Self {
141        self.restart_period = Some(period);
142        self
143    }
144}
145
146impl LrSchedulerV2 for CosineAnnealingScheduler {
147    fn step(&mut self) -> f64 {
148        self.current_step += 1;
149        if let Some(period) = self.restart_period {
150            if period > 0 && self.current_step.is_multiple_of(period) {
151                self.current_step = 0;
152                self.cycle_count += 1;
153            }
154        }
155        self.current_lr()
156    }
157
158    fn current_lr(&self) -> f64 {
159        let t_cur = self.current_step.min(self.t_max) as f64;
160        let t_max = self.t_max as f64;
161        let cos_val = (std::f64::consts::PI * t_cur / t_max).cos();
162        self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + cos_val)
163    }
164
165    fn reset(&mut self) {
166        self.current_step = 0;
167        self.cycle_count = 0;
168    }
169
170    fn steps_taken(&self) -> usize {
171        self.current_step
172    }
173
174    fn completed_cycle(&self) -> bool {
175        self.cycle_count > 0
176    }
177}
178
179// ------- WarmupScheduler -------
180
181/// Linear warmup followed by another scheduler.
182///
183/// During warmup: lr = warmup_start_lr + (warmup_end_lr - warmup_start_lr) * (step / warmup_steps)
184/// After warmup: delegates to the inner scheduler.
185pub struct WarmupScheduler {
186    warmup_steps: usize,
187    warmup_start_lr: f64,
188    warmup_end_lr: f64,
189    inner: Box<dyn LrSchedulerV2>,
190    current_step: usize,
191    inner_started: bool,
192}
193
194impl WarmupScheduler {
195    /// Create a new warmup scheduler wrapping an inner scheduler.
196    ///
197    /// # Errors
198    /// Returns [`SchedulerError::InvalidConfig`] if `warmup_steps` is zero.
199    pub fn new(
200        warmup_steps: usize,
201        warmup_start_lr: f64,
202        warmup_end_lr: f64,
203        inner: Box<dyn LrSchedulerV2>,
204    ) -> Result<Self, SchedulerError> {
205        if warmup_steps == 0 {
206            return Err(SchedulerError::InvalidConfig(
207                "warmup_steps must be > 0".into(),
208            ));
209        }
210        Ok(WarmupScheduler {
211            warmup_steps,
212            warmup_start_lr,
213            warmup_end_lr,
214            inner,
215            current_step: 0,
216            inner_started: false,
217        })
218    }
219}
220
221impl LrSchedulerV2 for WarmupScheduler {
222    fn step(&mut self) -> f64 {
223        self.current_step += 1;
224        if self.current_step >= self.warmup_steps {
225            self.inner_started = true;
226            self.inner.step()
227        } else {
228            self.current_lr()
229        }
230    }
231
232    fn current_lr(&self) -> f64 {
233        if self.inner_started || self.current_step >= self.warmup_steps {
234            self.inner.current_lr()
235        } else {
236            let frac = self.current_step as f64 / self.warmup_steps as f64;
237            self.warmup_start_lr + frac * (self.warmup_end_lr - self.warmup_start_lr)
238        }
239    }
240
241    fn reset(&mut self) {
242        self.current_step = 0;
243        self.inner_started = false;
244        self.inner.reset();
245    }
246
247    fn steps_taken(&self) -> usize {
248        self.current_step
249    }
250}
251
252// ------- CyclicalScheduler -------
253
254/// Cyclical learning rates (CLR) — oscillates between min_lr and max_lr.
255///
256/// Uses triangular policy: linear up then linear down, period = 2 * step_size.
257pub struct CyclicalScheduler {
258    min_lr: f64,
259    max_lr: f64,
260    step_size: usize,
261    current_step: usize,
262}
263
264impl CyclicalScheduler {
265    /// Create a new cyclical learning rate scheduler.
266    ///
267    /// # Errors
268    /// Returns [`SchedulerError::InvalidConfig`] if any parameter is invalid.
269    pub fn new(min_lr: f64, max_lr: f64, step_size: usize) -> Result<Self, SchedulerError> {
270        if max_lr <= min_lr {
271            return Err(SchedulerError::InvalidConfig(
272                "max_lr must be > min_lr".into(),
273            ));
274        }
275        if step_size == 0 {
276            return Err(SchedulerError::InvalidConfig(
277                "step_size must be > 0".into(),
278            ));
279        }
280        Ok(CyclicalScheduler {
281            min_lr,
282            max_lr,
283            step_size,
284            current_step: 0,
285        })
286    }
287}
288
289impl LrSchedulerV2 for CyclicalScheduler {
290    fn step(&mut self) -> f64 {
291        self.current_step += 1;
292        self.current_lr()
293    }
294
295    fn current_lr(&self) -> f64 {
296        let cycle = self.current_step / (2 * self.step_size);
297        let x = (self.current_step as f64 / self.step_size as f64) - 2.0 * cycle as f64 - 1.0;
298        let frac = (1.0 - x.abs()).max(0.0);
299        self.min_lr + (self.max_lr - self.min_lr) * frac
300    }
301
302    fn reset(&mut self) {
303        self.current_step = 0;
304    }
305
306    fn steps_taken(&self) -> usize {
307        self.current_step
308    }
309}
310
311// ------- OneCycleLrScheduler -------
312
313/// One-cycle learning rate policy.
314///
315/// Phase 1 (pct_start of total_steps): linear ramp from base_lr to max_lr
316/// Phase 2 (remaining): cosine decay from max_lr to min_lr
317pub struct OneCycleLrScheduler {
318    base_lr: f64,
319    max_lr: f64,
320    min_lr: f64,
321    total_steps: usize,
322    pct_start: f64,
323    current_step: usize,
324}
325
326impl OneCycleLrScheduler {
327    /// Create a new one-cycle learning rate scheduler.
328    ///
329    /// # Errors
330    /// Returns [`SchedulerError::InvalidConfig`] if any parameter is invalid.
331    pub fn new(
332        base_lr: f64,
333        max_lr: f64,
334        min_lr: f64,
335        total_steps: usize,
336        pct_start: f64,
337    ) -> Result<Self, SchedulerError> {
338        if max_lr <= base_lr {
339            return Err(SchedulerError::InvalidConfig(
340                "max_lr must be > base_lr".into(),
341            ));
342        }
343        if !(0.0..=1.0).contains(&pct_start) {
344            return Err(SchedulerError::InvalidConfig(
345                "pct_start must be in [0, 1]".into(),
346            ));
347        }
348        if total_steps == 0 {
349            return Err(SchedulerError::InvalidConfig(
350                "total_steps must be > 0".into(),
351            ));
352        }
353        Ok(OneCycleLrScheduler {
354            base_lr,
355            max_lr,
356            min_lr,
357            total_steps,
358            pct_start,
359            current_step: 0,
360        })
361    }
362}
363
364impl LrSchedulerV2 for OneCycleLrScheduler {
365    fn step(&mut self) -> f64 {
366        self.current_step = (self.current_step + 1).min(self.total_steps);
367        self.current_lr()
368    }
369
370    fn current_lr(&self) -> f64 {
371        let warmup_steps = (self.total_steps as f64 * self.pct_start) as usize;
372        if self.current_step <= warmup_steps {
373            let frac = if warmup_steps == 0 {
374                1.0
375            } else {
376                self.current_step as f64 / warmup_steps as f64
377            };
378            self.base_lr + frac * (self.max_lr - self.base_lr)
379        } else {
380            let decay_steps = self.total_steps - warmup_steps;
381            let t = self.current_step - warmup_steps;
382            let frac = if decay_steps == 0 {
383                1.0
384            } else {
385                t as f64 / decay_steps as f64
386            };
387            let cos_val = (std::f64::consts::PI * frac).cos();
388            self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + cos_val)
389        }
390    }
391
392    fn reset(&mut self) {
393        self.current_step = 0;
394    }
395
396    fn steps_taken(&self) -> usize {
397        self.current_step
398    }
399}
400
401// ------- SchedulerConfig -------
402
/// Plain-data description of a scheduler: which algorithm to use plus the
/// parameters that algorithm needs. Parameters that do not apply are `None`.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Which scheduling algorithm this configuration describes.
    pub scheduler_type: SchedulerType,
    /// Base learning rate.
    pub base_lr: f64,
    /// Maximum learning rate, when the algorithm has one (cyclical, one-cycle).
    pub max_lr: Option<f64>,
    /// Minimum learning rate floor, when the algorithm has one.
    pub min_lr: Option<f64>,
    /// Total number of training steps, when known.
    pub total_steps: Option<usize>,
    /// Step size for decay, or the half-period of a cyclical schedule.
    pub step_size: Option<usize>,
    /// Decay factor (step decay).
    pub gamma: Option<f64>,
    /// Number of warmup steps.
    pub warmup_steps: Option<usize>,
    /// Fraction of the total steps spent warming up (one-cycle).
    pub pct_start: Option<f64>,
}

/// Identifies the scheduling algorithm a [`SchedulerConfig`] refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulerType {
    /// Step decay: the rate is multiplied by gamma every step_size steps.
    StepDecay,
    /// Cosine annealing, no restarts.
    CosineAnnealing,
    /// Cosine annealing with warm restarts (SGDR).
    CosineAnnealingWarmRestarts,
    /// Linear warmup that hands over to an inner scheduler.
    Warmup,
    /// Cyclical (triangular) learning rates.
    Cyclical,
    /// One-cycle learning rate policy.
    OneCycle,
}

impl SchedulerConfig {
    /// Configuration with every optional parameter unset; the public
    /// builders override only the fields relevant to their algorithm.
    fn unset(scheduler_type: SchedulerType, base_lr: f64) -> Self {
        SchedulerConfig {
            scheduler_type,
            base_lr,
            max_lr: None,
            min_lr: None,
            total_steps: None,
            step_size: None,
            gamma: None,
            warmup_steps: None,
            pct_start: None,
        }
    }

    /// Create a step-decay scheduler configuration.
    pub fn step_decay(base_lr: f64, gamma: f64, step_size: usize) -> Self {
        SchedulerConfig {
            step_size: Some(step_size),
            gamma: Some(gamma),
            ..Self::unset(SchedulerType::StepDecay, base_lr)
        }
    }

    /// Create a cosine annealing scheduler configuration.
    ///
    /// `t_max` is stored in the `total_steps` field.
    pub fn cosine(base_lr: f64, min_lr: f64, t_max: usize) -> Self {
        SchedulerConfig {
            min_lr: Some(min_lr),
            total_steps: Some(t_max),
            ..Self::unset(SchedulerType::CosineAnnealing, base_lr)
        }
    }

    /// Create a one-cycle scheduler configuration.
    ///
    /// The floor defaults to 1% of `base_lr` and the warmup fraction to 0.3.
    pub fn one_cycle(base_lr: f64, max_lr: f64, total_steps: usize) -> Self {
        SchedulerConfig {
            max_lr: Some(max_lr),
            min_lr: Some(base_lr * 0.01),
            total_steps: Some(total_steps),
            pct_start: Some(0.3),
            ..Self::unset(SchedulerType::OneCycle, base_lr)
        }
    }
}
489
490// ------- Tests -------
491
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Steps `scheduler` forward `count` times and reports the learning rate
    /// it ends up at.
    fn advance<S: LrSchedulerV2>(scheduler: &mut S, count: usize) -> f64 {
        (0..count).for_each(|_| {
            scheduler.step();
        });
        scheduler.current_lr()
    }

    // ---- StepDecayScheduler ----

    #[test]
    fn test_step_decay_initial_lr() {
        let sched = StepDecayScheduler::new(0.1, 0.5, 10).expect("valid config");
        assert_abs_diff_eq!(sched.current_lr(), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_step_decay_after_step_size() {
        let mut sched = StepDecayScheduler::new(0.1, 0.5, 5).expect("valid config");
        // One full step_size elapsed: 0.1 * 0.5^1 = 0.05
        let lr_now = advance(&mut sched, 5);
        assert_abs_diff_eq!(lr_now, 0.05, epsilon = 1e-10);
    }

    #[test]
    fn test_step_decay_multiple_decays() {
        let mut sched = StepDecayScheduler::new(0.1, 0.5, 4).expect("valid config");
        // Three decays after 12 steps: 0.1 * 0.5^3 = 0.0125
        let lr_now = advance(&mut sched, 12);
        assert_abs_diff_eq!(lr_now, 0.0125, epsilon = 1e-10);
    }

    #[test]
    fn test_step_decay_invalid_gamma() {
        let outcome = StepDecayScheduler::new(0.1, 1.5, 10);
        assert!(outcome.is_err(), "gamma > 1.0 should return Err");
    }

    #[test]
    fn test_step_decay_reset() {
        let mut sched = StepDecayScheduler::new(0.1, 0.5, 5).expect("valid config");
        let decayed = advance(&mut sched, 10);
        assert!(decayed < 0.1, "LR should have decayed");
        sched.reset();
        assert_abs_diff_eq!(sched.current_lr(), 0.1, epsilon = 1e-10);
        assert_eq!(sched.steps_taken(), 0);
    }

    // ---- CosineAnnealingScheduler ----

    #[test]
    fn test_cosine_initial_is_max() {
        let sched = CosineAnnealingScheduler::new(0.1, 0.001, 100).expect("valid config");
        // cos(0) = 1, so the untouched scheduler sits at max_lr.
        assert_abs_diff_eq!(sched.current_lr(), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_at_tmax() {
        let mut sched = CosineAnnealingScheduler::new(0.1, 0.001, 100).expect("valid config");
        // cos(pi) = -1, so the rate bottoms out at min_lr after t_max steps.
        let lr_now = advance(&mut sched, 100);
        assert_abs_diff_eq!(lr_now, 0.001, epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_monotone_decrease() {
        let mut sched = CosineAnnealingScheduler::new(0.1, 0.001, 50).expect("valid config");
        let mut prev = sched.current_lr();
        for _ in 0..50 {
            let lr = sched.step();
            assert!(
                lr <= prev + 1e-12,
                "LR should not increase: prev={prev}, lr={lr}"
            );
            prev = lr;
        }
    }

    #[test]
    fn test_cosine_warm_restarts_resets() {
        let period = 10;
        let mut sched = CosineAnnealingScheduler::new(0.1, 0.001, 100)
            .expect("valid config")
            .with_warm_restarts(period);

        // Walk up to the last step before the restart boundary.
        let lr_before_restart = advance(&mut sched, period - 1);

        // The next step lands exactly on the period and triggers the restart.
        let lr_after_restart = sched.step();

        // The in-cycle position is back at 0, so the rate jumps towards max_lr.
        assert!(
            lr_after_restart > lr_before_restart,
            "LR should increase after warm restart: before={lr_before_restart}, after={lr_after_restart}"
        );
        assert!(sched.completed_cycle());
    }

    #[test]
    fn test_cosine_invalid_config() {
        let outcome = CosineAnnealingScheduler::new(0.001, 0.1, 100);
        assert!(outcome.is_err(), "max_lr < min_lr should return Err");
    }

    // ---- WarmupScheduler ----

    #[test]
    fn test_warmup_starts_low() {
        let inner = Box::new(CosineAnnealingScheduler::new(0.1, 0.001, 100).expect("valid inner"));
        let mut sched = WarmupScheduler::new(10, 1e-6, 0.1, inner).expect("valid warmup config");
        // First step: one tenth of the way from 1e-6 to 0.1.
        let lr = sched.step();
        assert!(
            lr < 0.1,
            "First warmup LR should be much less than warmup_end_lr"
        );
        assert!(lr > 0.0, "First warmup LR should be positive");
    }

    #[test]
    fn test_warmup_ends_high() {
        let inner = Box::new(CosineAnnealingScheduler::new(0.1, 0.001, 100).expect("valid inner"));
        let mut sched = WarmupScheduler::new(5, 0.0, 0.1, inner).expect("valid warmup config");
        // The fifth step is the hand-over: the inner scheduler has been
        // stepped once, so the reported rate comes from it.
        let lr = advance(&mut sched, 5);
        assert!(
            lr >= 0.001,
            "After warmup, LR should be from inner scheduler (>= min_lr)"
        );
    }

    #[test]
    fn test_warmup_invalid_zero_steps() {
        let inner = Box::new(CosineAnnealingScheduler::new(0.1, 0.001, 100).expect("valid inner"));
        let outcome = WarmupScheduler::new(0, 0.0, 0.1, inner);
        assert!(outcome.is_err(), "warmup_steps=0 should return Err");
    }

    // ---- CyclicalScheduler ----

    #[test]
    fn test_cyclical_min_at_start() {
        let sched = CyclicalScheduler::new(0.001, 0.1, 5).expect("valid config");
        // Step 0 is the bottom of the triangle.
        assert_abs_diff_eq!(sched.current_lr(), 0.001, epsilon = 1e-10);
    }

    #[test]
    fn test_cyclical_max_at_half_period() {
        let mut sched = CyclicalScheduler::new(0.001, 0.1, 5).expect("valid config");
        // After step_size steps the triangle is at its peak.
        let lr_now = advance(&mut sched, 5);
        assert_abs_diff_eq!(lr_now, 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_cyclical_min_at_full_period() {
        let step_size = 5;
        let mut sched = CyclicalScheduler::new(0.001, 0.1, step_size).expect("valid config");
        // A full period (2 * step_size) brings the rate back to the bottom.
        let lr_now = advance(&mut sched, 2 * step_size);
        assert_abs_diff_eq!(lr_now, 0.001, epsilon = 1e-10);
    }

    // ---- OneCycleLrScheduler ----

    #[test]
    fn test_one_cycle_starts_at_base() {
        let sched = OneCycleLrScheduler::new(0.001, 0.1, 0.0001, 100, 0.3).expect("valid config");
        // No steps taken yet: the ramp fraction is 0, giving base_lr.
        assert_abs_diff_eq!(sched.current_lr(), 0.001, epsilon = 1e-10);
    }

    #[test]
    fn test_one_cycle_peaks_at_warmup_end() {
        let total_steps = 100;
        let pct_start = 0.3;
        let base_lr = 0.001;
        let max_lr = 0.1;
        let mut sched = OneCycleLrScheduler::new(base_lr, max_lr, 0.0001, total_steps, pct_start)
            .expect("valid config");
        let warmup_steps = (total_steps as f64 * pct_start) as usize; // 30
        // End of the ramp: fraction 1.0, so the rate equals max_lr.
        let lr_now = advance(&mut sched, warmup_steps);
        assert_abs_diff_eq!(lr_now, max_lr, epsilon = 1e-9);
    }

    #[test]
    fn test_one_cycle_ends_at_min() {
        let total_steps = 100;
        let min_lr = 0.0001;
        let mut sched =
            OneCycleLrScheduler::new(0.001, 0.1, min_lr, total_steps, 0.3).expect("valid config");
        // End of the decay: cos(pi) = -1 leaves exactly min_lr.
        let lr_now = advance(&mut sched, total_steps);
        assert_abs_diff_eq!(lr_now, min_lr, epsilon = 1e-9);
    }

    // ---- SchedulerConfig builders ----

    #[test]
    fn test_scheduler_config_builders() {
        let step_cfg = SchedulerConfig::step_decay(0.1, 0.5, 10);
        assert_eq!(step_cfg.scheduler_type, SchedulerType::StepDecay);
        assert_abs_diff_eq!(step_cfg.base_lr, 0.1, epsilon = 1e-10);
        assert_eq!(step_cfg.gamma, Some(0.5));
        assert_eq!(step_cfg.step_size, Some(10));

        let cosine_cfg = SchedulerConfig::cosine(0.1, 0.001, 100);
        assert_eq!(cosine_cfg.scheduler_type, SchedulerType::CosineAnnealing);
        assert_abs_diff_eq!(cosine_cfg.base_lr, 0.1, epsilon = 1e-10);
        assert_eq!(cosine_cfg.min_lr, Some(0.001));
        assert_eq!(cosine_cfg.total_steps, Some(100));

        let oc_cfg = SchedulerConfig::one_cycle(0.001, 0.1, 500);
        assert_eq!(oc_cfg.scheduler_type, SchedulerType::OneCycle);
        assert_abs_diff_eq!(oc_cfg.base_lr, 0.001, epsilon = 1e-10);
        assert_eq!(oc_cfg.max_lr, Some(0.1));
        assert_eq!(oc_cfg.total_steps, Some(500));
        assert_eq!(oc_cfg.pct_start, Some(0.3));
        // The floor defaults to 1% of base_lr.
        assert_abs_diff_eq!(oc_cfg.min_lr.unwrap(), 0.001 * 0.01, epsilon = 1e-15);
    }
}
735        assert_eq!(oc_cfg.pct_start, Some(0.3));
736        // min_lr should be base_lr * 0.01
737        assert_abs_diff_eq!(oc_cfg.min_lr.unwrap(), 0.001 * 0.01, epsilon = 1e-15);
738    }
739}