//! Learning-rate schedulers for the `yscv_optim` crate (`scheduler.rs`).
1use super::validate::{
2    validate_cosine_t_max, validate_lr, validate_one_cycle_final_div_factor,
3    validate_one_cycle_pct_start, validate_one_cycle_total_steps, validate_step_gamma,
4    validate_step_size, validate_warmup_steps,
5};
6use super::{LearningRate, OptimError};
7
/// Scheduler abstraction for stateful learning-rate policies.
pub trait LrScheduler {
    /// Advances scheduler by one epoch and returns resulting optimizer LR.
    ///
    /// # Errors
    /// Propagates any error returned by the optimizer when it rejects the
    /// newly computed learning rate.
    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError>;

    /// Returns number of already-processed step calls.
    fn epoch(&self) -> usize;

    /// Resets scheduler internal state.
    fn reset(&mut self);
}
19
/// Piecewise constant learning-rate scheduler.
///
/// Every `step_size` calls to [`StepLr::step`], the optimizer learning rate is
/// multiplied by `gamma`.
#[derive(Debug, Clone, PartialEq)]
pub struct StepLr {
    /// Number of step calls between consecutive decays (validated > 0).
    step_size: usize,
    /// Multiplicative decay factor (validated to be in (0, 1]).
    gamma: f32,
    /// Count of step calls processed so far.
    epoch: usize,
}
30
31impl StepLr {
32    /// Creates step scheduler with required `step_size > 0` and `gamma in (0, 1]`.
33    pub fn new(step_size: usize, gamma: f32) -> Result<Self, OptimError> {
34        validate_step_size(step_size)?;
35        validate_step_gamma(gamma)?;
36        Ok(Self {
37            step_size,
38            gamma,
39            epoch: 0,
40        })
41    }
42
43    /// Returns configured step size.
44    pub fn step_size(&self) -> usize {
45        self.step_size
46    }
47
48    /// Returns configured decay factor.
49    pub fn gamma(&self) -> f32 {
50        self.gamma
51    }
52
53    /// Returns number of already-processed step calls.
54    pub fn epoch(&self) -> usize {
55        self.epoch
56    }
57
58    /// Resets internal epoch counter.
59    pub fn reset(&mut self) {
60        <Self as LrScheduler>::reset(self);
61    }
62
63    /// Advances scheduler by one epoch and returns resulting optimizer LR.
64    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
65        <Self as LrScheduler>::step(self, optimizer)
66    }
67}
68
69impl LrScheduler for StepLr {
70    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
71        self.epoch = self.epoch.saturating_add(1);
72        if self.epoch.is_multiple_of(self.step_size) {
73            let next_lr = optimizer.learning_rate() * self.gamma;
74            optimizer.set_learning_rate(next_lr)?;
75        }
76        Ok(optimizer.learning_rate())
77    }
78
79    fn epoch(&self) -> usize {
80        self.epoch
81    }
82
83    fn reset(&mut self) {
84        self.epoch = 0;
85    }
86}
87
/// Cosine annealing learning-rate scheduler.
///
/// Computes:
/// `lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * t_cur / t_max))`,
/// where `t_cur` is clamped to `t_max`.
#[derive(Debug, Clone, PartialEq)]
pub struct CosineAnnealingLr {
    /// Number of steps over which the LR anneals from base to `min_lr`.
    t_max: usize,
    /// Floor learning rate reached at `t_max` (validated finite and >= 0).
    min_lr: f32,
    /// Count of step calls processed so far.
    epoch: usize,
    /// Base LR; pinned explicitly or captured lazily from the optimizer on
    /// the first `step` call.
    base_lr: Option<f32>,
}
100
101impl CosineAnnealingLr {
102    /// Creates cosine scheduler with `t_max > 0` and finite `min_lr >= 0`.
103    pub fn new(t_max: usize, min_lr: f32) -> Result<Self, OptimError> {
104        validate_cosine_t_max(t_max)?;
105        validate_lr(min_lr)?;
106        Ok(Self {
107            t_max,
108            min_lr,
109            epoch: 0,
110            base_lr: None,
111        })
112    }
113
114    /// Pins explicit base LR used by cosine schedule.
115    pub fn with_base_lr(mut self, base_lr: f32) -> Result<Self, OptimError> {
116        validate_lr(base_lr)?;
117        if self.min_lr > base_lr {
118            return Err(OptimError::SchedulerMinLrExceedsBase {
119                min_lr: self.min_lr,
120                base_lr,
121            });
122        }
123        self.base_lr = Some(base_lr);
124        Ok(self)
125    }
126
127    pub fn t_max(&self) -> usize {
128        self.t_max
129    }
130
131    pub fn min_lr(&self) -> f32 {
132        self.min_lr
133    }
134
135    pub fn base_lr(&self) -> Option<f32> {
136        self.base_lr
137    }
138
139    pub fn epoch(&self) -> usize {
140        self.epoch
141    }
142
143    pub fn reset(&mut self) {
144        <Self as LrScheduler>::reset(self);
145    }
146
147    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
148        <Self as LrScheduler>::step(self, optimizer)
149    }
150}
151
152impl LrScheduler for CosineAnnealingLr {
153    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
154        self.epoch = self.epoch.saturating_add(1);
155
156        let base_lr = match self.base_lr {
157            Some(base) => base,
158            None => {
159                let current = optimizer.learning_rate();
160                self.base_lr = Some(current);
161                current
162            }
163        };
164        if self.min_lr > base_lr {
165            return Err(OptimError::SchedulerMinLrExceedsBase {
166                min_lr: self.min_lr,
167                base_lr,
168            });
169        }
170
171        let t_cur = self.epoch.min(self.t_max) as f32;
172        let t_max = self.t_max as f32;
173        let cos_term = (std::f32::consts::PI * t_cur / t_max).cos();
174        let next_lr = self.min_lr + 0.5 * (base_lr - self.min_lr) * (1.0 + cos_term);
175        optimizer.set_learning_rate(next_lr)?;
176        Ok(next_lr)
177    }
178
179    fn epoch(&self) -> usize {
180        self.epoch
181    }
182
183    fn reset(&mut self) {
184        self.epoch = 0;
185    }
186}
187
/// Linear warmup learning-rate scheduler.
///
/// Computes:
/// `lr = start_lr + (base_lr - start_lr) * min(epoch, warmup_steps)/warmup_steps`.
#[derive(Debug, Clone, PartialEq)]
pub struct LinearWarmupLr {
    /// Number of steps over which the LR ramps from start to base (> 0).
    warmup_steps: usize,
    /// Warmup start LR; treated as 0.0 when not pinned.
    start_lr: Option<f32>,
    /// Warmup target LR; pinned explicitly or captured lazily from the
    /// optimizer on the first `step` call.
    base_lr: Option<f32>,
    /// Count of step calls processed so far.
    epoch: usize,
}
199
200impl LinearWarmupLr {
201    /// Creates warmup scheduler with `warmup_steps > 0`.
202    pub fn new(warmup_steps: usize) -> Result<Self, OptimError> {
203        validate_warmup_steps(warmup_steps)?;
204        Ok(Self {
205            warmup_steps,
206            start_lr: None,
207            base_lr: None,
208            epoch: 0,
209        })
210    }
211
212    /// Sets explicit warmup start learning rate.
213    pub fn with_start_lr(mut self, start_lr: f32) -> Result<Self, OptimError> {
214        validate_lr(start_lr)?;
215        if let Some(base_lr) = self.base_lr
216            && start_lr > base_lr
217        {
218            return Err(OptimError::SchedulerStartLrExceedsBase { start_lr, base_lr });
219        }
220        self.start_lr = Some(start_lr);
221        Ok(self)
222    }
223
224    /// Sets explicit warmup end/base learning rate.
225    pub fn with_base_lr(mut self, base_lr: f32) -> Result<Self, OptimError> {
226        validate_lr(base_lr)?;
227        if let Some(start_lr) = self.start_lr
228            && start_lr > base_lr
229        {
230            return Err(OptimError::SchedulerStartLrExceedsBase { start_lr, base_lr });
231        }
232        self.base_lr = Some(base_lr);
233        Ok(self)
234    }
235
236    pub fn warmup_steps(&self) -> usize {
237        self.warmup_steps
238    }
239
240    pub fn start_lr(&self) -> Option<f32> {
241        self.start_lr
242    }
243
244    pub fn base_lr(&self) -> Option<f32> {
245        self.base_lr
246    }
247
248    pub fn epoch(&self) -> usize {
249        self.epoch
250    }
251
252    pub fn reset(&mut self) {
253        <Self as LrScheduler>::reset(self);
254    }
255
256    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
257        <Self as LrScheduler>::step(self, optimizer)
258    }
259}
260
261impl LrScheduler for LinearWarmupLr {
262    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
263        self.epoch = self.epoch.saturating_add(1);
264
265        let base_lr = match self.base_lr {
266            Some(base_lr) => base_lr,
267            None => {
268                let current = optimizer.learning_rate();
269                self.base_lr = Some(current);
270                current
271            }
272        };
273        let start_lr = self.start_lr.unwrap_or(0.0);
274        if start_lr > base_lr {
275            return Err(OptimError::SchedulerStartLrExceedsBase { start_lr, base_lr });
276        }
277
278        let warmup_ratio = self.epoch.min(self.warmup_steps) as f32 / self.warmup_steps as f32;
279        let next_lr = start_lr + (base_lr - start_lr) * warmup_ratio;
280        optimizer.set_learning_rate(next_lr)?;
281        Ok(next_lr)
282    }
283
284    fn epoch(&self) -> usize {
285        self.epoch
286    }
287
288    fn reset(&mut self) {
289        self.epoch = 0;
290    }
291}
292
/// One-cycle learning-rate scheduler with linear warmup and linear cooldown.
#[derive(Debug, Clone, PartialEq)]
pub struct OneCycleLr {
    /// Total number of steps in the cycle.
    total_steps: usize,
    /// Peak LR reached at the end of the up phase.
    max_lr: f32,
    /// Fraction of the cycle spent ramping up (defaults to 0.3 in `new`).
    pct_start: f32,
    /// Divisor for the final LR: `final_lr = initial_lr / final_div_factor`
    /// (defaults to 1000 in `new`).
    final_div_factor: f32,
    /// Starting LR; pinned explicitly or captured lazily from the optimizer
    /// on the first `step` call.
    initial_lr: Option<f32>,
    /// Count of step calls processed so far.
    epoch: usize,
}
303
304impl OneCycleLr {
305    /// Creates one-cycle scheduler.
306    ///
307    /// - `total_steps > 0`
308    /// - `max_lr >= 0`
309    pub fn new(total_steps: usize, max_lr: f32) -> Result<Self, OptimError> {
310        validate_one_cycle_total_steps(total_steps)?;
311        validate_lr(max_lr)?;
312        Ok(Self {
313            total_steps,
314            max_lr,
315            pct_start: 0.3,
316            final_div_factor: 1_000.0,
317            initial_lr: None,
318            epoch: 0,
319        })
320    }
321
322    /// Sets fraction of cycle spent in the up phase.
323    pub fn with_pct_start(mut self, pct_start: f32) -> Result<Self, OptimError> {
324        validate_one_cycle_pct_start(pct_start)?;
325        self.pct_start = pct_start;
326        Ok(self)
327    }
328
329    /// Sets divisor used for final LR (`final_lr = initial_lr / final_div_factor`).
330    pub fn with_final_div_factor(mut self, final_div_factor: f32) -> Result<Self, OptimError> {
331        validate_one_cycle_final_div_factor(final_div_factor)?;
332        self.final_div_factor = final_div_factor;
333        Ok(self)
334    }
335
336    /// Pins explicit initial LR used by the schedule.
337    pub fn with_initial_lr(mut self, initial_lr: f32) -> Result<Self, OptimError> {
338        validate_lr(initial_lr)?;
339        if self.max_lr < initial_lr {
340            return Err(OptimError::SchedulerMaxLrBelowInitial {
341                max_lr: self.max_lr,
342                initial_lr,
343            });
344        }
345        self.initial_lr = Some(initial_lr);
346        Ok(self)
347    }
348
349    pub fn total_steps(&self) -> usize {
350        self.total_steps
351    }
352
353    pub fn max_lr(&self) -> f32 {
354        self.max_lr
355    }
356
357    pub fn pct_start(&self) -> f32 {
358        self.pct_start
359    }
360
361    pub fn final_div_factor(&self) -> f32 {
362        self.final_div_factor
363    }
364
365    pub fn initial_lr(&self) -> Option<f32> {
366        self.initial_lr
367    }
368
369    pub fn epoch(&self) -> usize {
370        self.epoch
371    }
372
373    pub fn reset(&mut self) {
374        <Self as LrScheduler>::reset(self);
375    }
376
377    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
378        <Self as LrScheduler>::step(self, optimizer)
379    }
380}
381
382impl LrScheduler for OneCycleLr {
383    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
384        self.epoch = self.epoch.saturating_add(1);
385
386        let initial_lr = match self.initial_lr {
387            Some(initial_lr) => initial_lr,
388            None => {
389                let current = optimizer.learning_rate();
390                self.initial_lr = Some(current);
391                current
392            }
393        };
394        if self.max_lr < initial_lr {
395            return Err(OptimError::SchedulerMaxLrBelowInitial {
396                max_lr: self.max_lr,
397                initial_lr,
398            });
399        }
400
401        let final_lr = initial_lr / self.final_div_factor;
402        let up_steps = one_cycle_up_steps(self.total_steps, self.pct_start);
403        let clamped_epoch = self.epoch.min(self.total_steps);
404        let next_lr = if clamped_epoch <= up_steps {
405            let progress = clamped_epoch as f32 / up_steps as f32;
406            initial_lr + (self.max_lr - initial_lr) * progress
407        } else {
408            let down_steps = self.total_steps.saturating_sub(up_steps).max(1);
409            let down_epoch = clamped_epoch - up_steps;
410            let progress = down_epoch as f32 / down_steps as f32;
411            self.max_lr - (self.max_lr - final_lr) * progress
412        };
413        optimizer.set_learning_rate(next_lr)?;
414        Ok(next_lr)
415    }
416
417    fn epoch(&self) -> usize {
418        self.epoch
419    }
420
421    fn reset(&mut self) {
422        self.epoch = 0;
423    }
424}
425
/// Number of steps in the ascending phase: `ceil(total_steps * pct_start)`,
/// clamped into `1..=total_steps` so both phases are well-defined.
fn one_cycle_up_steps(total_steps: usize, pct_start: f32) -> usize {
    let raw = (total_steps as f32 * pct_start).ceil() as usize;
    raw.clamp(1, total_steps)
}
429
/// Exponential learning-rate scheduler.
///
/// Every step, the optimizer learning rate is multiplied by `gamma`:
/// `lr = lr * gamma`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExponentialLr {
    /// Multiplicative decay factor (validated to be in (0, 1]).
    gamma: f32,
    /// Count of step calls processed so far.
    epoch: usize,
}
439
440impl ExponentialLr {
441    /// Creates exponential scheduler with `gamma in (0, 1]`.
442    pub fn new(gamma: f32) -> Result<Self, OptimError> {
443        validate_step_gamma(gamma)?;
444        Ok(Self { gamma, epoch: 0 })
445    }
446
447    /// Returns configured decay factor.
448    pub fn gamma(&self) -> f32 {
449        self.gamma
450    }
451
452    /// Returns number of already-processed step calls.
453    pub fn epoch(&self) -> usize {
454        self.epoch
455    }
456
457    /// Resets internal epoch counter.
458    pub fn reset(&mut self) {
459        <Self as LrScheduler>::reset(self);
460    }
461
462    /// Advances scheduler by one epoch and returns resulting optimizer LR.
463    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
464        <Self as LrScheduler>::step(self, optimizer)
465    }
466}
467
468impl LrScheduler for ExponentialLr {
469    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
470        self.epoch = self.epoch.saturating_add(1);
471        let next_lr = optimizer.learning_rate() * self.gamma;
472        optimizer.set_learning_rate(next_lr)?;
473        Ok(next_lr)
474    }
475
476    fn epoch(&self) -> usize {
477        self.epoch
478    }
479
480    fn reset(&mut self) {
481        self.epoch = 0;
482    }
483}
484
/// Polynomial decay learning-rate scheduler.
///
/// Decays the learning rate from its initial value to `end_lr` over `total_steps`
/// using a polynomial of the given `power`:
/// `lr = (base_lr - end_lr) * (1 - epoch/total_steps)^power + end_lr`.
#[derive(Debug, Clone, PartialEq)]
pub struct PolynomialDecayLr {
    /// Number of steps over which the LR decays to `end_lr` (> 0).
    total_steps: usize,
    /// Polynomial exponent (validated finite and > 0).
    power: f32,
    /// LR reached at `total_steps` (validated finite and >= 0).
    end_lr: f32,
    /// Base LR; pinned explicitly or captured lazily from the optimizer on
    /// the first `step` call.
    base_lr: Option<f32>,
    /// Count of step calls processed so far.
    epoch: usize,
}
498
499impl PolynomialDecayLr {
500    /// Creates polynomial decay scheduler.
501    ///
502    /// - `total_steps > 0`
503    /// - `power > 0` and finite
504    /// - `end_lr >= 0` and finite
505    pub fn new(total_steps: usize, power: f32, end_lr: f32) -> Result<Self, OptimError> {
506        if total_steps == 0 {
507            return Err(OptimError::InvalidStepSize {
508                step_size: total_steps,
509            });
510        }
511        if !power.is_finite() || power <= 0.0 {
512            return Err(OptimError::InvalidStepGamma { gamma: power });
513        }
514        validate_lr(end_lr)?;
515        Ok(Self {
516            total_steps,
517            power,
518            end_lr,
519            base_lr: None,
520            epoch: 0,
521        })
522    }
523
524    /// Pins explicit base LR used by the schedule.
525    pub fn with_base_lr(mut self, base_lr: f32) -> Result<Self, OptimError> {
526        validate_lr(base_lr)?;
527        self.base_lr = Some(base_lr);
528        Ok(self)
529    }
530
531    pub fn total_steps(&self) -> usize {
532        self.total_steps
533    }
534
535    pub fn power(&self) -> f32 {
536        self.power
537    }
538
539    pub fn end_lr(&self) -> f32 {
540        self.end_lr
541    }
542
543    pub fn base_lr(&self) -> Option<f32> {
544        self.base_lr
545    }
546
547    pub fn epoch(&self) -> usize {
548        self.epoch
549    }
550
551    pub fn reset(&mut self) {
552        <Self as LrScheduler>::reset(self);
553    }
554
555    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
556        <Self as LrScheduler>::step(self, optimizer)
557    }
558}
559
560impl LrScheduler for PolynomialDecayLr {
561    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
562        self.epoch = self.epoch.saturating_add(1);
563
564        let base_lr = match self.base_lr {
565            Some(base) => base,
566            None => {
567                let current = optimizer.learning_rate();
568                self.base_lr = Some(current);
569                current
570            }
571        };
572
573        let t = (self.epoch.min(self.total_steps) as f32) / (self.total_steps as f32);
574        let next_lr = (base_lr - self.end_lr) * (1.0 - t).powf(self.power) + self.end_lr;
575        optimizer.set_learning_rate(next_lr)?;
576        Ok(next_lr)
577    }
578
579    fn epoch(&self) -> usize {
580        self.epoch
581    }
582
583    fn reset(&mut self) {
584        self.epoch = 0;
585    }
586}
587
/// Reduce learning rate when a metric has stopped improving.
///
/// When the metric has not improved for `patience` consecutive calls to
/// [`ReduceLrOnPlateau::step_with_metric`], the learning rate is multiplied
/// by `factor` (clamped to `min_lr`).
#[derive(Debug, Clone, PartialEq)]
pub struct ReduceLrOnPlateau {
    /// Multiplicative LR reduction factor (validated to be in (0, 1]).
    factor: f32,
    /// Non-improving steps tolerated before a reduction (validated >= 1).
    patience: usize,
    /// Lower bound the reduced LR is clamped to.
    min_lr: f32,
    /// Best (lowest) metric seen so far; starts at +infinity.
    best_metric: f32,
    /// Consecutive non-improving steps since the last improvement/reduction.
    wait: usize,
    /// Count of step calls processed so far.
    epoch: usize,
}
602
603impl ReduceLrOnPlateau {
604    /// Creates a plateau scheduler.
605    ///
606    /// - `factor in (0, 1]`
607    /// - `patience >= 1`
608    /// - `min_lr >= 0` and finite
609    pub fn new(factor: f32, patience: usize, min_lr: f32) -> Result<Self, OptimError> {
610        validate_step_gamma(factor)?;
611        if patience == 0 {
612            return Err(OptimError::InvalidStepSize {
613                step_size: patience,
614            });
615        }
616        validate_lr(min_lr)?;
617        Ok(Self {
618            factor,
619            patience,
620            min_lr,
621            best_metric: f32::INFINITY,
622            wait: 0,
623            epoch: 0,
624        })
625    }
626
627    pub fn factor(&self) -> f32 {
628        self.factor
629    }
630
631    pub fn patience(&self) -> usize {
632        self.patience
633    }
634
635    pub fn min_lr(&self) -> f32 {
636        self.min_lr
637    }
638
639    pub fn best_metric(&self) -> f32 {
640        self.best_metric
641    }
642
643    pub fn wait(&self) -> usize {
644        self.wait
645    }
646
647    pub fn epoch(&self) -> usize {
648        self.epoch
649    }
650
651    pub fn reset(&mut self) {
652        <Self as LrScheduler>::reset(self);
653    }
654
655    /// Steps the scheduler with a metric value. If the metric has not improved
656    /// for `patience` consecutive steps, the LR is reduced by `factor`.
657    /// Lower metric is considered better.
658    pub fn step_with_metric<O: LearningRate>(
659        &mut self,
660        metric: f32,
661        optimizer: &mut O,
662    ) -> Result<f32, OptimError> {
663        self.epoch = self.epoch.saturating_add(1);
664
665        if metric < self.best_metric {
666            self.best_metric = metric;
667            self.wait = 0;
668        } else {
669            self.wait += 1;
670            if self.wait >= self.patience {
671                let next_lr = (optimizer.learning_rate() * self.factor).max(self.min_lr);
672                optimizer.set_learning_rate(next_lr)?;
673                self.wait = 0;
674            }
675        }
676
677        Ok(optimizer.learning_rate())
678    }
679}
680
681impl LrScheduler for ReduceLrOnPlateau {
682    /// Standard step without a metric does nothing to the LR
683    /// (use `step_with_metric` instead).
684    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
685        self.epoch = self.epoch.saturating_add(1);
686        Ok(optimizer.learning_rate())
687    }
688
689    fn epoch(&self) -> usize {
690        self.epoch
691    }
692
693    fn reset(&mut self) {
694        self.epoch = 0;
695        self.best_metric = f32::INFINITY;
696        self.wait = 0;
697    }
698}
699
/// Cyclic learning-rate scheduler with triangular policy.
///
/// Cycles the learning rate between `base_lr` and `max_lr` with a triangular
/// waveform defined by `step_size_up` (ascending half) and `step_size_down`
/// (descending half).
#[derive(Debug, Clone, PartialEq)]
pub struct CyclicLr {
    /// Lower LR bound of the triangle (validated >= 0).
    base_lr: f32,
    /// Upper LR bound of the triangle (validated >= `base_lr`).
    max_lr: f32,
    /// Length of the ascending half-cycle in steps (> 0).
    step_size_up: usize,
    /// Length of the descending half-cycle in steps (> 0).
    step_size_down: usize,
    /// Count of step calls processed so far.
    epoch: usize,
}
713
714impl CyclicLr {
715    /// Creates cyclic LR scheduler.
716    ///
717    /// - `base_lr >= 0`, `max_lr >= base_lr`
718    /// - `step_size_up > 0`, `step_size_down > 0`
719    pub fn new(
720        base_lr: f32,
721        max_lr: f32,
722        step_size_up: usize,
723        step_size_down: usize,
724    ) -> Result<Self, OptimError> {
725        validate_lr(base_lr)?;
726        validate_lr(max_lr)?;
727        if max_lr < base_lr {
728            return Err(OptimError::SchedulerMaxLrBelowInitial {
729                max_lr,
730                initial_lr: base_lr,
731            });
732        }
733        if step_size_up == 0 {
734            return Err(OptimError::InvalidStepSize {
735                step_size: step_size_up,
736            });
737        }
738        if step_size_down == 0 {
739            return Err(OptimError::InvalidStepSize {
740                step_size: step_size_down,
741            });
742        }
743        Ok(Self {
744            base_lr,
745            max_lr,
746            step_size_up,
747            step_size_down,
748            epoch: 0,
749        })
750    }
751
752    pub fn base_lr(&self) -> f32 {
753        self.base_lr
754    }
755
756    pub fn max_lr(&self) -> f32 {
757        self.max_lr
758    }
759
760    pub fn step_size_up(&self) -> usize {
761        self.step_size_up
762    }
763
764    pub fn step_size_down(&self) -> usize {
765        self.step_size_down
766    }
767
768    pub fn epoch(&self) -> usize {
769        self.epoch
770    }
771
772    pub fn reset(&mut self) {
773        <Self as LrScheduler>::reset(self);
774    }
775
776    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
777        <Self as LrScheduler>::step(self, optimizer)
778    }
779}
780
781impl LrScheduler for CyclicLr {
782    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
783        self.epoch = self.epoch.saturating_add(1);
784
785        let cycle_len = self.step_size_up + self.step_size_down;
786        let pos = (self.epoch - 1) % cycle_len; // 0-indexed position in current cycle
787
788        let next_lr = if pos < self.step_size_up {
789            // ascending
790            let progress = pos as f32 / self.step_size_up as f32;
791            self.base_lr + (self.max_lr - self.base_lr) * progress
792        } else {
793            // descending
794            let down_pos = pos - self.step_size_up;
795            let progress = down_pos as f32 / self.step_size_down as f32;
796            self.max_lr - (self.max_lr - self.base_lr) * progress
797        };
798
799        optimizer.set_learning_rate(next_lr)?;
800        Ok(next_lr)
801    }
802
803    fn epoch(&self) -> usize {
804        self.epoch
805    }
806
807    fn reset(&mut self) {
808        self.epoch = 0;
809    }
810}
811
/// Lambda learning-rate scheduler.
///
/// Computes `lr = base_lr * lr_lambda(step_count)` at each step, where
/// `lr_lambda` is a user-provided closure mapping epoch to a multiplicative
/// factor.
///
/// Unlike the other schedulers in this module, this type carries no
/// `Debug`/`Clone`/`PartialEq` derives — it stores a boxed closure.
pub struct LambdaLr {
    /// Base LR the lambda's factor is multiplied onto every step.
    base_lr: f32,
    /// Most recently computed LR (equals `base_lr` before the first step).
    current_lr: f32,
    /// Maps the 1-based step count to a multiplicative factor.
    lr_lambda: Box<dyn Fn(usize) -> f32>,
    /// Count of step calls processed so far.
    step_count: usize,
}
823
824impl LambdaLr {
825    /// Creates a lambda scheduler with the given base learning rate and lambda
826    /// function. The lambda receives the current epoch (after increment) and
827    /// returns a multiplicative factor applied to `base_lr`.
828    pub fn new(base_lr: f32, lr_lambda: Box<dyn Fn(usize) -> f32>) -> Self {
829        Self {
830            base_lr,
831            current_lr: base_lr,
832            lr_lambda,
833            step_count: 0,
834        }
835    }
836
837    /// Returns the base learning rate.
838    pub fn base_lr(&self) -> f32 {
839        self.base_lr
840    }
841
842    /// Returns the current learning rate.
843    pub fn current_lr(&self) -> f32 {
844        self.current_lr
845    }
846
847    /// Returns the current step count.
848    pub fn step_count(&self) -> usize {
849        self.step_count
850    }
851
852    /// Returns number of already-processed step calls.
853    pub fn epoch(&self) -> usize {
854        self.step_count
855    }
856
857    /// Resets internal state.
858    pub fn reset(&mut self) {
859        <Self as LrScheduler>::reset(self);
860    }
861
862    /// Advances scheduler by one epoch and returns resulting optimizer LR.
863    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
864        <Self as LrScheduler>::step(self, optimizer)
865    }
866}
867
868impl LrScheduler for LambdaLr {
869    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
870        self.step_count = self.step_count.saturating_add(1);
871        self.current_lr = self.base_lr * (self.lr_lambda)(self.step_count);
872        optimizer.set_learning_rate(self.current_lr)?;
873        Ok(self.current_lr)
874    }
875
876    fn epoch(&self) -> usize {
877        self.step_count
878    }
879
880    fn reset(&mut self) {
881        self.step_count = 0;
882        self.current_lr = self.base_lr;
883    }
884}
885
/// Multi-step learning-rate scheduler.
///
/// Drops the learning rate by `gamma` at each milestone epoch in `milestones`.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiStepLr {
    /// Sorted, de-duplicated epochs at which a decay applies.
    milestones: Vec<usize>,
    /// Multiplicative decay factor (validated to be in (0, 1]).
    gamma: f32,
    /// Count of step calls processed so far.
    epoch: usize,
    /// Base LR captured lazily from the optimizer on the first `step` call.
    base_lr: Option<f32>,
}
896
897impl MultiStepLr {
898    /// Creates multi-step scheduler with sorted `milestones` and `gamma in (0, 1]`.
899    pub fn new(mut milestones: Vec<usize>, gamma: f32) -> Result<Self, OptimError> {
900        validate_step_gamma(gamma)?;
901        milestones.sort();
902        milestones.dedup();
903        Ok(Self {
904            milestones,
905            gamma,
906            epoch: 0,
907            base_lr: None,
908        })
909    }
910
911    pub fn milestones(&self) -> &[usize] {
912        &self.milestones
913    }
914
915    pub fn gamma(&self) -> f32 {
916        self.gamma
917    }
918
919    pub fn epoch(&self) -> usize {
920        self.epoch
921    }
922
923    pub fn reset(&mut self) {
924        self.epoch = 0;
925        self.base_lr = None;
926    }
927
928    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
929        <Self as LrScheduler>::step(self, optimizer)
930    }
931}
932
933impl LrScheduler for MultiStepLr {
934    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
935        self.epoch = self.epoch.saturating_add(1);
936
937        let base_lr = match self.base_lr {
938            Some(base) => base,
939            None => {
940                let current = optimizer.learning_rate();
941                self.base_lr = Some(current);
942                current
943            }
944        };
945
946        let num_decays = self.milestones.iter().filter(|&&m| self.epoch >= m).count();
947        let next_lr = base_lr * self.gamma.powi(num_decays as i32);
948        optimizer.set_learning_rate(next_lr)?;
949        Ok(next_lr)
950    }
951
952    fn epoch(&self) -> usize {
953        self.epoch
954    }
955
956    fn reset(&mut self) {
957        self.epoch = 0;
958        self.base_lr = None;
959    }
960}
961
/// Cosine annealing with warm restarts learning-rate scheduler.
///
/// Within each period, LR follows cosine decay from `base_lr` to `eta_min`.
/// After `t_0` epochs, the schedule restarts and the next period length is
/// `t_0 * t_mult`.
///
/// Formula within each period:
/// `lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t_cur / t_i))`
#[derive(Debug, Clone, PartialEq)]
pub struct CosineAnnealingWarmRestarts {
    /// Initial period length in steps (validated > 0).
    t_0: usize,
    /// Period-length multiplier applied after each restart (validated >= 1).
    t_mult: usize,
    /// Floor learning rate within each period (validated finite and >= 0).
    eta_min: f32,
    /// Base LR; pinned explicitly or captured lazily from the optimizer on
    /// the first `step` call.
    base_lr: Option<f32>,
    /// Count of step calls processed so far.
    epoch: usize,
}
978
979impl CosineAnnealingWarmRestarts {
980    /// Creates a cosine warm restarts scheduler.
981    ///
982    /// - `t_0 > 0`: initial period length
983    /// - `t_mult >= 1`: period multiplier after each restart
984    /// - `eta_min >= 0`: minimum learning rate
985    pub fn new(t_0: usize, t_mult: usize, eta_min: f32) -> Result<Self, OptimError> {
986        validate_cosine_t_max(t_0)?;
987        if t_mult == 0 {
988            return Err(OptimError::InvalidStepSize { step_size: 0 });
989        }
990        validate_lr(eta_min)?;
991        Ok(Self {
992            t_0,
993            t_mult,
994            eta_min,
995            base_lr: None,
996            epoch: 0,
997        })
998    }
999
1000    /// Pins explicit base LR used by the schedule.
1001    pub fn with_base_lr(mut self, base_lr: f32) -> Result<Self, OptimError> {
1002        validate_lr(base_lr)?;
1003        if self.eta_min > base_lr {
1004            return Err(OptimError::SchedulerMinLrExceedsBase {
1005                min_lr: self.eta_min,
1006                base_lr,
1007            });
1008        }
1009        self.base_lr = Some(base_lr);
1010        Ok(self)
1011    }
1012
1013    pub fn t_0(&self) -> usize {
1014        self.t_0
1015    }
1016
1017    pub fn t_mult(&self) -> usize {
1018        self.t_mult
1019    }
1020
1021    pub fn eta_min(&self) -> f32 {
1022        self.eta_min
1023    }
1024
1025    pub fn base_lr(&self) -> Option<f32> {
1026        self.base_lr
1027    }
1028
1029    pub fn epoch(&self) -> usize {
1030        self.epoch
1031    }
1032
1033    pub fn reset(&mut self) {
1034        <Self as LrScheduler>::reset(self);
1035    }
1036
1037    pub fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
1038        <Self as LrScheduler>::step(self, optimizer)
1039    }
1040}
1041
1042impl LrScheduler for CosineAnnealingWarmRestarts {
1043    fn step<O: LearningRate>(&mut self, optimizer: &mut O) -> Result<f32, OptimError> {
1044        self.epoch = self.epoch.saturating_add(1);
1045
1046        let base_lr = match self.base_lr {
1047            Some(base) => base,
1048            None => {
1049                let current = optimizer.learning_rate();
1050                self.base_lr = Some(current);
1051                current
1052            }
1053        };
1054        if self.eta_min > base_lr {
1055            return Err(OptimError::SchedulerMinLrExceedsBase {
1056                min_lr: self.eta_min,
1057                base_lr,
1058            });
1059        }
1060
1061        // Determine current period and position within it
1062        let (t_cur, t_i) = cosine_warm_restarts_position(self.epoch, self.t_0, self.t_mult);
1063
1064        let cos_term = (std::f32::consts::PI * t_cur as f32 / t_i as f32).cos();
1065        let next_lr = self.eta_min + 0.5 * (base_lr - self.eta_min) * (1.0 + cos_term);
1066        optimizer.set_learning_rate(next_lr)?;
1067        Ok(next_lr)
1068    }
1069
1070    fn epoch(&self) -> usize {
1071        self.epoch
1072    }
1073
1074    fn reset(&mut self) {
1075        self.epoch = 0;
1076    }
1077}
1078
/// Returns `(t_cur, t_i)` where `t_cur` is the 1-based position within the
/// current period and `t_i` is the current period length.
///
/// Preconditions (upheld by the only caller, the
/// `CosineAnnealingWarmRestarts` step, which validates `t_0`/`t_mult` at
/// construction and increments `epoch` before calling): `epoch >= 1`,
/// `t_0 >= 1`, `t_mult >= 1`.
fn cosine_warm_restarts_position(epoch: usize, t_0: usize, t_mult: usize) -> (usize, usize) {
    // `epoch - 1` below would wrap in release builds for epoch == 0 and
    // silently return garbage; make the contract explicit instead.
    debug_assert!(epoch >= 1, "epoch must be >= 1");
    debug_assert!(t_0 >= 1, "t_0 must be >= 1");
    debug_assert!(t_mult >= 1, "t_mult must be >= 1");
    if t_mult == 1 {
        // All periods share length t_0; the position simply cycles 1..=t_0.
        let t_cur = ((epoch - 1) % t_0) + 1;
        (t_cur, t_0)
    } else {
        // Periods grow geometrically: t_0, t_0*t_mult, t_0*t_mult^2, ...
        // Walk periods until the cumulative length covers `epoch`.
        let mut t_i = t_0;
        let mut cumulative = 0usize;
        loop {
            if epoch <= cumulative + t_i {
                let t_cur = epoch - cumulative;
                return (t_cur, t_i);
            }
            cumulative += t_i;
            t_i *= t_mult;
        }
    }
}