1use thiserror::Error;
11
12#[derive(Debug, Error)]
14pub enum SchedulerError {
15 #[error("Invalid config: {0}")]
16 InvalidConfig(String),
17 #[error("Scheduler exhausted after {0} steps")]
18 Exhausted(usize),
19}
20
/// Common interface of the learning-rate schedulers defined in this file.
///
/// `Send` is a supertrait, so boxed schedulers can be moved across threads.
pub trait LrSchedulerV2: Send {
    /// Advances the schedule by one step and returns a learning rate.
    fn step(&mut self) -> f64;

    /// Returns the learning rate for the current position without advancing.
    fn current_lr(&self) -> f64;

    /// Puts the scheduler back into its initial state.
    fn reset(&mut self);

    /// Returns the scheduler's step counter.
    fn steps_taken(&self) -> usize;

    /// Reports whether a cycle has been completed.
    ///
    /// The provided implementation always answers `false`; implementors that
    /// track cycles override it.
    fn completed_cycle(&self) -> bool {
        false
    }
}
36
/// Step-decay schedule: the learning rate is `base_lr` multiplied by `gamma`
/// once for every `step_size` completed steps.
///
/// Derives `Debug`, `Clone` and `PartialEq` so the scheduler state can be
/// logged, snapshotted and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct StepDecayScheduler {
    /// Learning rate at step 0.
    base_lr: f64,
    /// Multiplicative factor applied once per `step_size` steps.
    gamma: f64,
    /// Number of steps between two successive decays.
    step_size: usize,
    /// Steps taken since construction or the last reset.
    current_step: usize,
}
48
49impl StepDecayScheduler {
50 pub fn new(base_lr: f64, gamma: f64, step_size: usize) -> Result<Self, SchedulerError> {
55 if base_lr <= 0.0 {
56 return Err(SchedulerError::InvalidConfig(
57 "base_lr must be positive".into(),
58 ));
59 }
60 if !(0.0..=1.0).contains(&gamma) {
61 return Err(SchedulerError::InvalidConfig(
62 "gamma must be in (0, 1]".into(),
63 ));
64 }
65 if step_size == 0 {
66 return Err(SchedulerError::InvalidConfig(
67 "step_size must be > 0".into(),
68 ));
69 }
70 Ok(StepDecayScheduler {
71 base_lr,
72 gamma,
73 step_size,
74 current_step: 0,
75 })
76 }
77}
78
79impl LrSchedulerV2 for StepDecayScheduler {
80 fn step(&mut self) -> f64 {
81 self.current_step += 1;
82 self.current_lr()
83 }
84
85 fn current_lr(&self) -> f64 {
86 let exponent = self.current_step / self.step_size;
87 self.base_lr * self.gamma.powi(exponent as i32)
88 }
89
90 fn reset(&mut self) {
91 self.current_step = 0;
92 }
93
94 fn steps_taken(&self) -> usize {
95 self.current_step
96 }
97}
98
/// Cosine-annealing schedule from `max_lr` down to `min_lr` over `t_max`
/// steps, with optional periodic warm restarts.
///
/// Derives `Debug`, `Clone` and `PartialEq` so the scheduler state can be
/// logged, snapshotted and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct CosineAnnealingScheduler {
    /// Learning rate at step 0 of a cycle.
    max_lr: f64,
    /// Learning rate reached once `t_max` steps have been taken.
    min_lr: f64,
    /// Number of steps over which the rate is annealed.
    t_max: usize,
    /// When set, the step counter returns to 0 every this many steps.
    restart_period: Option<usize>,
    /// Steps taken inside the current cycle.
    current_step: usize,
    /// Number of warm restarts performed so far.
    cycle_count: usize,
}
114
115impl CosineAnnealingScheduler {
116 pub fn new(max_lr: f64, min_lr: f64, t_max: usize) -> Result<Self, SchedulerError> {
121 if max_lr < min_lr {
122 return Err(SchedulerError::InvalidConfig(
123 "max_lr must be >= min_lr".into(),
124 ));
125 }
126 if t_max == 0 {
127 return Err(SchedulerError::InvalidConfig("t_max must be > 0".into()));
128 }
129 Ok(CosineAnnealingScheduler {
130 max_lr,
131 min_lr,
132 t_max,
133 restart_period: None,
134 current_step: 0,
135 cycle_count: 0,
136 })
137 }
138
139 pub fn with_warm_restarts(mut self, period: usize) -> Self {
141 self.restart_period = Some(period);
142 self
143 }
144}
145
146impl LrSchedulerV2 for CosineAnnealingScheduler {
147 fn step(&mut self) -> f64 {
148 self.current_step += 1;
149 if let Some(period) = self.restart_period {
150 if period > 0 && self.current_step.is_multiple_of(period) {
151 self.current_step = 0;
152 self.cycle_count += 1;
153 }
154 }
155 self.current_lr()
156 }
157
158 fn current_lr(&self) -> f64 {
159 let t_cur = self.current_step.min(self.t_max) as f64;
160 let t_max = self.t_max as f64;
161 let cos_val = (std::f64::consts::PI * t_cur / t_max).cos();
162 self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + cos_val)
163 }
164
165 fn reset(&mut self) {
166 self.current_step = 0;
167 self.cycle_count = 0;
168 }
169
170 fn steps_taken(&self) -> usize {
171 self.current_step
172 }
173
174 fn completed_cycle(&self) -> bool {
175 self.cycle_count > 0
176 }
177}
178
179pub struct WarmupScheduler {
186 warmup_steps: usize,
187 warmup_start_lr: f64,
188 warmup_end_lr: f64,
189 inner: Box<dyn LrSchedulerV2>,
190 current_step: usize,
191 inner_started: bool,
192}
193
194impl WarmupScheduler {
195 pub fn new(
200 warmup_steps: usize,
201 warmup_start_lr: f64,
202 warmup_end_lr: f64,
203 inner: Box<dyn LrSchedulerV2>,
204 ) -> Result<Self, SchedulerError> {
205 if warmup_steps == 0 {
206 return Err(SchedulerError::InvalidConfig(
207 "warmup_steps must be > 0".into(),
208 ));
209 }
210 Ok(WarmupScheduler {
211 warmup_steps,
212 warmup_start_lr,
213 warmup_end_lr,
214 inner,
215 current_step: 0,
216 inner_started: false,
217 })
218 }
219}
220
221impl LrSchedulerV2 for WarmupScheduler {
222 fn step(&mut self) -> f64 {
223 self.current_step += 1;
224 if self.current_step >= self.warmup_steps {
225 self.inner_started = true;
226 self.inner.step()
227 } else {
228 self.current_lr()
229 }
230 }
231
232 fn current_lr(&self) -> f64 {
233 if self.inner_started || self.current_step >= self.warmup_steps {
234 self.inner.current_lr()
235 } else {
236 let frac = self.current_step as f64 / self.warmup_steps as f64;
237 self.warmup_start_lr + frac * (self.warmup_end_lr - self.warmup_start_lr)
238 }
239 }
240
241 fn reset(&mut self) {
242 self.current_step = 0;
243 self.inner_started = false;
244 self.inner.reset();
245 }
246
247 fn steps_taken(&self) -> usize {
248 self.current_step
249 }
250}
251
/// Triangular cyclical schedule that ramps linearly between `min_lr` and
/// `max_lr`, taking `step_size` steps for each half of a cycle.
///
/// Derives `Debug`, `Clone` and `PartialEq` so the scheduler state can be
/// logged, snapshotted and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct CyclicalScheduler {
    /// Learning rate at the start and end of every cycle.
    min_lr: f64,
    /// Learning rate at the middle of every cycle.
    max_lr: f64,
    /// Number of steps in half a cycle.
    step_size: usize,
    /// Steps taken since construction or the last reset.
    current_step: usize,
}
263
264impl CyclicalScheduler {
265 pub fn new(min_lr: f64, max_lr: f64, step_size: usize) -> Result<Self, SchedulerError> {
270 if max_lr <= min_lr {
271 return Err(SchedulerError::InvalidConfig(
272 "max_lr must be > min_lr".into(),
273 ));
274 }
275 if step_size == 0 {
276 return Err(SchedulerError::InvalidConfig(
277 "step_size must be > 0".into(),
278 ));
279 }
280 Ok(CyclicalScheduler {
281 min_lr,
282 max_lr,
283 step_size,
284 current_step: 0,
285 })
286 }
287}
288
289impl LrSchedulerV2 for CyclicalScheduler {
290 fn step(&mut self) -> f64 {
291 self.current_step += 1;
292 self.current_lr()
293 }
294
295 fn current_lr(&self) -> f64 {
296 let cycle = self.current_step / (2 * self.step_size);
297 let x = (self.current_step as f64 / self.step_size as f64) - 2.0 * cycle as f64 - 1.0;
298 let frac = (1.0 - x.abs()).max(0.0);
299 self.min_lr + (self.max_lr - self.min_lr) * frac
300 }
301
302 fn reset(&mut self) {
303 self.current_step = 0;
304 }
305
306 fn steps_taken(&self) -> usize {
307 self.current_step
308 }
309}
310
/// One-cycle schedule: a linear ramp from `base_lr` up to `max_lr`, followed
/// by a cosine decay down to `min_lr`.
///
/// Derives `Debug`, `Clone` and `PartialEq` so the scheduler state can be
/// logged, snapshotted and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct OneCycleLrScheduler {
    /// Learning rate at step 0.
    base_lr: f64,
    /// Peak learning rate, reached at the end of the ramp.
    max_lr: f64,
    /// Learning rate reached at `total_steps`.
    min_lr: f64,
    /// Length of the whole schedule; the step counter never exceeds it.
    total_steps: usize,
    /// Fraction of `total_steps` spent in the ramp-up phase.
    pct_start: f64,
    /// Steps taken since construction or the last reset.
    current_step: usize,
}
325
326impl OneCycleLrScheduler {
327 pub fn new(
332 base_lr: f64,
333 max_lr: f64,
334 min_lr: f64,
335 total_steps: usize,
336 pct_start: f64,
337 ) -> Result<Self, SchedulerError> {
338 if max_lr <= base_lr {
339 return Err(SchedulerError::InvalidConfig(
340 "max_lr must be > base_lr".into(),
341 ));
342 }
343 if !(0.0..=1.0).contains(&pct_start) {
344 return Err(SchedulerError::InvalidConfig(
345 "pct_start must be in [0, 1]".into(),
346 ));
347 }
348 if total_steps == 0 {
349 return Err(SchedulerError::InvalidConfig(
350 "total_steps must be > 0".into(),
351 ));
352 }
353 Ok(OneCycleLrScheduler {
354 base_lr,
355 max_lr,
356 min_lr,
357 total_steps,
358 pct_start,
359 current_step: 0,
360 })
361 }
362}
363
364impl LrSchedulerV2 for OneCycleLrScheduler {
365 fn step(&mut self) -> f64 {
366 self.current_step = (self.current_step + 1).min(self.total_steps);
367 self.current_lr()
368 }
369
370 fn current_lr(&self) -> f64 {
371 let warmup_steps = (self.total_steps as f64 * self.pct_start) as usize;
372 if self.current_step <= warmup_steps {
373 let frac = if warmup_steps == 0 {
374 1.0
375 } else {
376 self.current_step as f64 / warmup_steps as f64
377 };
378 self.base_lr + frac * (self.max_lr - self.base_lr)
379 } else {
380 let decay_steps = self.total_steps - warmup_steps;
381 let t = self.current_step - warmup_steps;
382 let frac = if decay_steps == 0 {
383 1.0
384 } else {
385 t as f64 / decay_steps as f64
386 };
387 let cos_val = (std::f64::consts::PI * frac).cos();
388 self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + cos_val)
389 }
390 }
391
392 fn reset(&mut self) {
393 self.current_step = 0;
394 }
395
396 fn steps_taken(&self) -> usize {
397 self.current_step
398 }
399}
400
401#[derive(Debug, Clone)]
405pub struct SchedulerConfig {
406 pub scheduler_type: SchedulerType,
408 pub base_lr: f64,
410 pub max_lr: Option<f64>,
412 pub min_lr: Option<f64>,
414 pub total_steps: Option<usize>,
416 pub step_size: Option<usize>,
418 pub gamma: Option<f64>,
420 pub warmup_steps: Option<usize>,
422 pub pct_start: Option<f64>,
424}
425
/// Tag that identifies the kind of scheduler a [`SchedulerConfig`] describes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SchedulerType {
    /// Step decay.
    StepDecay,
    /// Cosine annealing.
    CosineAnnealing,
    /// Cosine annealing with warm restarts.
    CosineAnnealingWarmRestarts,
    /// Warm-up.
    Warmup,
    /// Cyclical.
    Cyclical,
    /// One-cycle.
    OneCycle,
}
442
443impl SchedulerConfig {
444 pub fn step_decay(base_lr: f64, gamma: f64, step_size: usize) -> Self {
446 SchedulerConfig {
447 scheduler_type: SchedulerType::StepDecay,
448 base_lr,
449 max_lr: None,
450 min_lr: None,
451 total_steps: None,
452 step_size: Some(step_size),
453 gamma: Some(gamma),
454 warmup_steps: None,
455 pct_start: None,
456 }
457 }
458
459 pub fn cosine(base_lr: f64, min_lr: f64, t_max: usize) -> Self {
461 SchedulerConfig {
462 scheduler_type: SchedulerType::CosineAnnealing,
463 base_lr,
464 max_lr: None,
465 min_lr: Some(min_lr),
466 total_steps: Some(t_max),
467 step_size: None,
468 gamma: None,
469 warmup_steps: None,
470 pct_start: None,
471 }
472 }
473
474 pub fn one_cycle(base_lr: f64, max_lr: f64, total_steps: usize) -> Self {
476 SchedulerConfig {
477 scheduler_type: SchedulerType::OneCycle,
478 base_lr,
479 max_lr: Some(max_lr),
480 min_lr: Some(base_lr * 0.01),
481 total_steps: Some(total_steps),
482 step_size: None,
483 gamma: None,
484 warmup_steps: None,
485 pct_start: Some(0.3),
486 }
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Calls `step()` on `scheduler` exactly `n` times, discarding the
    /// returned learning rates.
    fn advance<S: LrSchedulerV2>(scheduler: &mut S, n: usize) {
        (0..n).for_each(|_| {
            scheduler.step();
        });
    }

    /// Boxed cosine scheduler (0.1 -> 0.001 over 100 steps) used as the
    /// wrapped scheduler in the warm-up tests.
    fn boxed_cosine() -> Box<CosineAnnealingScheduler> {
        Box::new(CosineAnnealingScheduler::new(0.1, 0.001, 100).expect("valid inner"))
    }

    // ---- StepDecayScheduler ----

    #[test]
    fn test_step_decay_initial_lr() {
        let s = StepDecayScheduler::new(0.1, 0.5, 10).expect("valid config");
        assert_abs_diff_eq!(s.current_lr(), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_step_decay_after_step_size() {
        let mut s = StepDecayScheduler::new(0.1, 0.5, 5).expect("valid config");
        advance(&mut s, 5);
        // One decay: 0.1 * 0.5.
        assert_abs_diff_eq!(s.current_lr(), 0.05, epsilon = 1e-10);
    }

    #[test]
    fn test_step_decay_multiple_decays() {
        let mut s = StepDecayScheduler::new(0.1, 0.5, 4).expect("valid config");
        advance(&mut s, 12);
        // Three decays: 0.1 * 0.5^3.
        assert_abs_diff_eq!(s.current_lr(), 0.0125, epsilon = 1e-10);
    }

    #[test]
    fn test_step_decay_invalid_gamma() {
        let result = StepDecayScheduler::new(0.1, 1.5, 10);
        assert!(result.is_err(), "gamma > 1.0 should return Err");
    }

    #[test]
    fn test_step_decay_reset() {
        let mut s = StepDecayScheduler::new(0.1, 0.5, 5).expect("valid config");
        advance(&mut s, 10);
        let after_steps = s.current_lr();
        assert!(after_steps < 0.1, "LR should have decayed");

        s.reset();
        assert_abs_diff_eq!(s.current_lr(), 0.1, epsilon = 1e-10);
        assert_eq!(s.steps_taken(), 0);
    }

    // ---- CosineAnnealingScheduler ----

    #[test]
    fn test_cosine_initial_is_max() {
        let s = CosineAnnealingScheduler::new(0.1, 0.001, 100).expect("valid config");
        assert_abs_diff_eq!(s.current_lr(), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_at_tmax() {
        let mut s = CosineAnnealingScheduler::new(0.1, 0.001, 100).expect("valid config");
        advance(&mut s, 100);
        assert_abs_diff_eq!(s.current_lr(), 0.001, epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_monotone_decrease() {
        let mut s = CosineAnnealingScheduler::new(0.1, 0.001, 50).expect("valid config");
        let mut prev = s.current_lr();
        for _ in 0..50 {
            let lr = s.step();
            assert!(
                lr <= prev + 1e-12,
                "LR should not increase: prev={prev}, lr={lr}"
            );
            prev = lr;
        }
    }

    #[test]
    fn test_cosine_warm_restarts_resets() {
        let period = 10;
        let mut s = CosineAnnealingScheduler::new(0.1, 0.001, 100)
            .expect("valid config")
            .with_warm_restarts(period);

        // Stop one step short of the restart boundary.
        advance(&mut s, period - 1);
        let lr_before_restart = s.current_lr();

        // This step crosses the boundary.
        let lr_after_restart = s.step();

        assert!(
            lr_after_restart > lr_before_restart,
            "LR should increase after warm restart: before={lr_before_restart}, after={lr_after_restart}"
        );
        assert!(s.completed_cycle());
    }

    #[test]
    fn test_cosine_invalid_config() {
        let result = CosineAnnealingScheduler::new(0.001, 0.1, 100);
        assert!(result.is_err(), "max_lr < min_lr should return Err");
    }

    // ---- WarmupScheduler ----

    #[test]
    fn test_warmup_starts_low() {
        let mut s =
            WarmupScheduler::new(10, 1e-6, 0.1, boxed_cosine()).expect("valid warmup config");
        let lr = s.step();
        assert!(
            lr < 0.1,
            "First warmup LR should be much less than warmup_end_lr"
        );
        assert!(lr > 0.0, "First warmup LR should be positive");
    }

    #[test]
    fn test_warmup_ends_high() {
        let mut s =
            WarmupScheduler::new(5, 0.0, 0.1, boxed_cosine()).expect("valid warmup config");
        advance(&mut s, 5);
        let lr = s.current_lr();
        assert!(
            lr >= 0.001,
            "After warmup, LR should be from inner scheduler (>= min_lr)"
        );
    }

    #[test]
    fn test_warmup_invalid_zero_steps() {
        let result = WarmupScheduler::new(0, 0.0, 0.1, boxed_cosine());
        assert!(result.is_err(), "warmup_steps=0 should return Err");
    }

    // ---- CyclicalScheduler ----

    #[test]
    fn test_cyclical_min_at_start() {
        let s = CyclicalScheduler::new(0.001, 0.1, 5).expect("valid config");
        assert_abs_diff_eq!(s.current_lr(), 0.001, epsilon = 1e-10);
    }

    #[test]
    fn test_cyclical_max_at_half_period() {
        let mut s = CyclicalScheduler::new(0.001, 0.1, 5).expect("valid config");
        advance(&mut s, 5);
        assert_abs_diff_eq!(s.current_lr(), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_cyclical_min_at_full_period() {
        let step_size = 5;
        let mut s = CyclicalScheduler::new(0.001, 0.1, step_size).expect("valid config");
        advance(&mut s, 2 * step_size);
        assert_abs_diff_eq!(s.current_lr(), 0.001, epsilon = 1e-10);
    }

    // ---- OneCycleLrScheduler ----

    #[test]
    fn test_one_cycle_starts_at_base() {
        let s = OneCycleLrScheduler::new(0.001, 0.1, 0.0001, 100, 0.3).expect("valid config");
        assert_abs_diff_eq!(s.current_lr(), 0.001, epsilon = 1e-10);
    }

    #[test]
    fn test_one_cycle_peaks_at_warmup_end() {
        let total_steps = 100;
        let pct_start = 0.3;
        let base_lr = 0.001;
        let max_lr = 0.1;
        let mut s = OneCycleLrScheduler::new(base_lr, max_lr, 0.0001, total_steps, pct_start)
            .expect("valid config");

        // Same truncating computation the scheduler uses for its ramp length.
        let warmup_steps = (total_steps as f64 * pct_start) as usize;
        advance(&mut s, warmup_steps);
        assert_abs_diff_eq!(s.current_lr(), max_lr, epsilon = 1e-9);
    }

    #[test]
    fn test_one_cycle_ends_at_min() {
        let total_steps = 100;
        let min_lr = 0.0001;
        let mut s =
            OneCycleLrScheduler::new(0.001, 0.1, min_lr, total_steps, 0.3).expect("valid config");
        advance(&mut s, total_steps);
        assert_abs_diff_eq!(s.current_lr(), min_lr, epsilon = 1e-9);
    }

    // ---- SchedulerConfig ----

    #[test]
    fn test_scheduler_config_builders() {
        let step_cfg = SchedulerConfig::step_decay(0.1, 0.5, 10);
        assert_eq!(step_cfg.scheduler_type, SchedulerType::StepDecay);
        assert_abs_diff_eq!(step_cfg.base_lr, 0.1, epsilon = 1e-10);
        assert_eq!(step_cfg.gamma, Some(0.5));
        assert_eq!(step_cfg.step_size, Some(10));

        let cosine_cfg = SchedulerConfig::cosine(0.1, 0.001, 100);
        assert_eq!(cosine_cfg.scheduler_type, SchedulerType::CosineAnnealing);
        assert_abs_diff_eq!(cosine_cfg.base_lr, 0.1, epsilon = 1e-10);
        assert_eq!(cosine_cfg.min_lr, Some(0.001));
        assert_eq!(cosine_cfg.total_steps, Some(100));

        let oc_cfg = SchedulerConfig::one_cycle(0.001, 0.1, 500);
        assert_eq!(oc_cfg.scheduler_type, SchedulerType::OneCycle);
        assert_abs_diff_eq!(oc_cfg.base_lr, 0.001, epsilon = 1e-10);
        assert_eq!(oc_cfg.max_lr, Some(0.1));
        assert_eq!(oc_cfg.total_steps, Some(500));
        assert_eq!(oc_cfg.pct_start, Some(0.3));
        // The builder derives min_lr as 1% of base_lr.
        assert_abs_diff_eq!(oc_cfg.min_lr.unwrap(), 0.001 * 0.01, epsilon = 1e-15);
    }
}