Skip to main content

scirs2_optimize/stochastic/
schedules.rs

1//! Learning Rate Schedules
2//!
3//! Provides a trait and concrete implementations for learning rate scheduling
4//! in stochastic optimization. These schedules control how the learning rate
5//! changes over training to improve convergence and final performance.
6
7use std::f64::consts::PI;
8
/// Common interface for learning rate schedules.
///
/// A schedule maps a zero-based epoch (or step) counter plus a base learning
/// rate to the learning rate that should be used at that point. Any
/// schedule-specific hyperparameters are stored on the implementing type.
pub trait LrSchedule: Send + Sync {
    /// Returns the learning rate to use at `epoch`.
    ///
    /// # Arguments
    /// * `epoch` - Zero-based training epoch or step index
    /// * `base_lr` - Initial/base learning rate the schedule is derived from
    ///   (an implementation may carry its own bounds and ignore this value)
    ///
    /// # Returns
    /// The learning rate for `epoch`.
    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64;
}
24
25// ─── Step Decay ──────────────────────────────────────────────────────────────
26
27/// Step decay schedule: multiplies learning rate by `gamma` every `step_size` epochs.
28///
29/// `lr = base_lr * gamma^(floor(epoch / step_size))`
30#[derive(Debug, Clone)]
31pub struct StepDecay {
32    /// Number of epochs between each decay step
33    pub step_size: usize,
34    /// Multiplicative decay factor (typically 0 < gamma < 1)
35    pub gamma: f64,
36}
37
38impl StepDecay {
39    /// Create a new step decay schedule.
40    ///
41    /// # Arguments
42    /// * `step_size` - Epochs between decay applications
43    /// * `gamma` - Multiplicative decay factor
44    pub fn new(step_size: usize, gamma: f64) -> Self {
45        Self { step_size, gamma }
46    }
47}
48
49impl LrSchedule for StepDecay {
50    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
51        let steps = epoch / self.step_size.max(1);
52        base_lr * self.gamma.powi(steps as i32)
53    }
54}
55
56// ─── Cosine Annealing ────────────────────────────────────────────────────────
57
58/// Cosine annealing schedule: smoothly decays the learning rate following a
59/// cosine curve from `base_lr` down to `eta_min`.
60///
61/// `lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(π * epoch / t_max))`
62///
63/// Reference: Loshchilov & Hutter (2016), "SGDR: Stochastic Gradient Descent
64/// with Warm Restarts".
65#[derive(Debug, Clone)]
66pub struct CosineAnnealing {
67    /// Period of the cosine cycle (number of epochs for one full descent)
68    pub t_max: usize,
69    /// Minimum learning rate at the end of a cycle
70    pub eta_min: f64,
71}
72
73impl CosineAnnealing {
74    /// Create a new cosine annealing schedule.
75    ///
76    /// # Arguments
77    /// * `t_max` - Period (epochs) for one cosine cycle
78    /// * `eta_min` - Minimum learning rate
79    pub fn new(t_max: usize, eta_min: f64) -> Self {
80        Self { t_max, eta_min }
81    }
82}
83
84impl LrSchedule for CosineAnnealing {
85    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
86        let t_max = self.t_max.max(1) as f64;
87        let cos_val = (PI * epoch as f64 / t_max).cos();
88        self.eta_min + 0.5 * (base_lr - self.eta_min) * (1.0 + cos_val)
89    }
90}
91
92// ─── One Cycle ───────────────────────────────────────────────────────────────
93
94/// Annealing strategy for the one-cycle policy.
95#[derive(Debug, Clone, Copy, PartialEq)]
96pub enum AnnealStrategy {
97    /// Cosine annealing (smooth, recommended)
98    Cos,
99    /// Linear annealing (simple)
100    Linear,
101}
102
103/// One-cycle learning rate policy.
104///
105/// Implements Smith & Touvron's 1cycle policy: the learning rate rises from
106/// `base_lr` to `max_lr` over the first `pct_start` fraction of training,
107/// then anneals back down to a minimum learning rate over the remainder.
108///
109/// Reference: Smith (2018), "A disciplined approach to neural network
110/// hyper-parameters".
111#[derive(Debug, Clone)]
112pub struct OneCycle {
113    /// Maximum learning rate (peak of the cycle)
114    pub max_lr: f64,
115    /// Fraction of total epochs for the increasing phase (0 < pct_start < 1)
116    pub pct_start: f64,
117    /// Annealing strategy for the decreasing phase
118    pub anneal_strategy: AnnealStrategy,
119    /// Total number of training epochs
120    pub total_epochs: usize,
121    /// Minimum (final) learning rate as a fraction of `base_lr`
122    pub div_factor: f64,
123    /// Final learning rate divisor (final_lr = base_lr / final_div_factor)
124    pub final_div_factor: f64,
125}
126
127impl OneCycle {
128    /// Create a new one-cycle schedule.
129    ///
130    /// # Arguments
131    /// * `max_lr` - Peak learning rate
132    /// * `pct_start` - Fraction of epochs for the warmup/increase phase
133    /// * `anneal_strategy` - How to anneal during the decrease phase
134    /// * `total_epochs` - Total training epochs
135    pub fn new(
136        max_lr: f64,
137        pct_start: f64,
138        anneal_strategy: AnnealStrategy,
139        total_epochs: usize,
140    ) -> Self {
141        Self {
142            max_lr,
143            pct_start: pct_start.clamp(0.0, 1.0),
144            anneal_strategy,
145            total_epochs,
146            div_factor: 25.0,
147            final_div_factor: 1e4,
148        }
149    }
150
151    /// Apply the chosen annealing strategy over the progress fraction [0,1].
152    fn anneal(&self, start: f64, end: f64, pct: f64) -> f64 {
153        let p = pct.clamp(0.0, 1.0);
154        match self.anneal_strategy {
155            AnnealStrategy::Cos => end + (start - end) / 2.0 * (1.0 + (PI * p).cos()),
156            AnnealStrategy::Linear => start + (end - start) * p,
157        }
158    }
159}
160
161impl LrSchedule for OneCycle {
162    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
163        let total = self.total_epochs.max(1) as f64;
164        let pct = epoch as f64 / total;
165        let init_lr = base_lr / self.div_factor;
166        let final_lr = init_lr / self.final_div_factor;
167
168        if pct <= self.pct_start {
169            // Warmup / increasing phase
170            let phase_pct = if self.pct_start > 0.0 {
171                pct / self.pct_start
172            } else {
173                1.0
174            };
175            self.anneal(init_lr, self.max_lr, phase_pct)
176        } else {
177            // Annealing phase
178            let phase_pct = (pct - self.pct_start) / (1.0 - self.pct_start).max(1e-9);
179            self.anneal(self.max_lr, final_lr, phase_pct)
180        }
181    }
182}
183
184// ─── Warmup + Cosine ─────────────────────────────────────────────────────────
185
186/// Warmup followed by cosine decay schedule.
187///
188/// The learning rate increases linearly from 0 to `base_lr` over
189/// `warmup_steps` steps, then decays following a cosine curve down to
190/// `min_lr` over the remaining steps.
191///
192/// This is commonly used in Transformer training (Vaswani et al., 2017).
193#[derive(Debug, Clone)]
194pub struct WarmupCosine {
195    /// Number of warmup epochs/steps (linear ramp from 0 → base_lr)
196    pub warmup_steps: usize,
197    /// Total training epochs/steps
198    pub total_steps: usize,
199    /// Minimum learning rate at the end of cosine decay
200    pub min_lr: f64,
201}
202
203impl WarmupCosine {
204    /// Create a new warmup + cosine decay schedule.
205    ///
206    /// # Arguments
207    /// * `warmup_steps` - Epochs for linear warmup
208    /// * `total_steps` - Total training epochs
209    /// * `min_lr` - Minimum learning rate after full decay
210    pub fn new(warmup_steps: usize, total_steps: usize, min_lr: f64) -> Self {
211        Self {
212            warmup_steps,
213            total_steps,
214            min_lr,
215        }
216    }
217}
218
219impl LrSchedule for WarmupCosine {
220    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
221        if epoch < self.warmup_steps {
222            // Linear warmup
223            let warmup = self.warmup_steps.max(1) as f64;
224            base_lr * epoch as f64 / warmup
225        } else {
226            // Cosine decay from base_lr to min_lr
227            let decay_steps = (self.total_steps.saturating_sub(self.warmup_steps)).max(1) as f64;
228            let step = (epoch - self.warmup_steps) as f64;
229            let cos_val = (PI * step / decay_steps).cos();
230            self.min_lr + 0.5 * (base_lr - self.min_lr) * (1.0 + cos_val)
231        }
232    }
233}
234
235// ─── Exponential Decay ───────────────────────────────────────────────────────
236
237/// Exponential decay schedule.
238///
239/// `lr = base_lr * gamma^epoch`
240#[derive(Debug, Clone)]
241pub struct ExponentialDecay {
242    /// Decay factor per epoch (typically close to 1, e.g. 0.99)
243    pub gamma: f64,
244}
245
246impl ExponentialDecay {
247    /// Create a new exponential decay schedule.
248    pub fn new(gamma: f64) -> Self {
249        Self { gamma }
250    }
251}
252
253impl LrSchedule for ExponentialDecay {
254    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
255        base_lr * self.gamma.powi(epoch as i32)
256    }
257}
258
259// ─── Constant Schedule ───────────────────────────────────────────────────────
260
261/// Constant (no-op) schedule: always returns `base_lr` unchanged.
262#[derive(Debug, Clone, Default)]
263pub struct ConstantLr;
264
265impl LrSchedule for ConstantLr {
266    fn get_lr(&self, _epoch: usize, base_lr: f64) -> f64 {
267        base_lr
268    }
269}
270
271// ─── Polynomial Decay ────────────────────────────────────────────────────────
272
273/// Polynomial decay schedule.
274///
275/// `lr = base_lr * (1 - epoch / total_epochs)^power`
276#[derive(Debug, Clone)]
277pub struct PolynomialDecay {
278    /// Total epochs for decay
279    pub total_epochs: usize,
280    /// Power of the polynomial (1.0 = linear, 2.0 = quadratic)
281    pub power: f64,
282    /// Minimum learning rate floor
283    pub end_lr: f64,
284}
285
286impl PolynomialDecay {
287    /// Create a new polynomial decay schedule.
288    pub fn new(total_epochs: usize, power: f64, end_lr: f64) -> Self {
289        Self {
290            total_epochs,
291            power,
292            end_lr,
293        }
294    }
295}
296
297impl LrSchedule for PolynomialDecay {
298    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
299        let total = self.total_epochs.max(1);
300        if epoch >= total {
301            return self.end_lr;
302        }
303        let decay = (1.0 - epoch as f64 / total as f64).powf(self.power);
304        let lr = (base_lr - self.end_lr) * decay + self.end_lr;
305        lr.max(self.end_lr)
306    }
307}
308
309// ─── Cyclic LR ───────────────────────────────────────────────────────────────
310
311/// Cyclic learning rate schedule.
312///
313/// Alternates between `min_lr` and `max_lr` in a triangular wave pattern.
314/// Reference: Smith (2017), "Cyclical Learning Rates for Training Neural Networks".
315#[derive(Debug, Clone)]
316pub struct CyclicLr {
317    /// Base (minimum) learning rate
318    pub base_lr: f64,
319    /// Maximum learning rate
320    pub max_lr: f64,
321    /// Half-cycle length in epochs
322    pub step_size: usize,
323}
324
325impl CyclicLr {
326    /// Create a new cyclic learning rate schedule.
327    pub fn new(base_lr: f64, max_lr: f64, step_size: usize) -> Self {
328        Self {
329            base_lr,
330            max_lr,
331            step_size: step_size.max(1),
332        }
333    }
334}
335
336impl LrSchedule for CyclicLr {
337    fn get_lr(&self, epoch: usize, _base_lr: f64) -> f64 {
338        let cycle = epoch / (2 * self.step_size);
339        let x = (epoch as f64 / self.step_size as f64) - 2.0 * cycle as f64 - 1.0;
340        let scale = (1.0 - x.abs()).max(0.0);
341        self.base_lr + (self.max_lr - self.base_lr) * scale
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_step_decay() {
        let sched = StepDecay::new(10, 0.5);
        assert_abs_diff_eq!(sched.get_lr(0, 0.1), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(9, 0.1), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 0.1), 0.05, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(20, 0.1), 0.025, epsilon = 1e-12);
    }

    #[test]
    fn test_step_decay_zero_step_size_is_treated_as_one() {
        let sched = StepDecay::new(0, 0.5);
        assert_abs_diff_eq!(sched.get_lr(3, 1.0), 0.125, epsilon = 1e-12);
    }

    #[test]
    fn test_cosine_annealing() {
        let sched = CosineAnnealing::new(100, 0.0);
        let lr_start = sched.get_lr(0, 1.0);
        let lr_mid = sched.get_lr(50, 1.0);
        let lr_end = sched.get_lr(100, 1.0);
        assert_abs_diff_eq!(lr_start, 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(lr_mid, 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(lr_end, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_one_cycle_warmup_peak() {
        let sched = OneCycle::new(0.1, 0.3, AnnealStrategy::Cos, 100);
        // At pct=0: init_lr = base_lr/div_factor
        let lr_start = sched.get_lr(0, 0.01);
        // At pct_start=30%: should be near max_lr
        let lr_peak = sched.get_lr(30, 0.01);
        assert!(lr_peak >= lr_start, "peak must exceed start");
        assert_abs_diff_eq!(lr_peak, sched.max_lr, epsilon = 1e-10);
    }

    #[test]
    fn test_one_cycle_reaches_and_holds_final_lr() {
        let sched = OneCycle::new(0.1, 0.3, AnnealStrategy::Cos, 100);
        let base_lr = 0.01;
        let final_lr = base_lr / sched.div_factor / sched.final_div_factor;
        assert_abs_diff_eq!(sched.get_lr(100, base_lr), final_lr, epsilon = 1e-15);
        // Past the end of the cycle the rate stays at the final value.
        assert_abs_diff_eq!(sched.get_lr(150, base_lr), final_lr, epsilon = 1e-15);
    }

    #[test]
    fn test_one_cycle_linear_warmup_midpoint() {
        let sched = OneCycle::new(1.0, 0.5, AnnealStrategy::Linear, 10);
        let base_lr = 0.25;
        let init_lr = base_lr / sched.div_factor;
        // Epoch 2 is 40% of the way through the 5-epoch warmup.
        let expected = init_lr + (sched.max_lr - init_lr) * 0.4;
        assert_abs_diff_eq!(sched.get_lr(2, base_lr), expected, epsilon = 1e-12);
    }

    #[test]
    fn test_warmup_cosine() {
        let sched = WarmupCosine::new(10, 100, 0.0);
        // During warmup: should be linear
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(5, 1.0), 0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 1.0), 1.0, epsilon = 1e-12);
        // After warmup: cosine decay
        let lr_after = sched.get_lr(55, 1.0);
        assert!(lr_after < 1.0, "should decay after warmup");
        assert!(lr_after >= 0.0, "should not go below min_lr");
    }

    #[test]
    fn test_exponential_decay() {
        let sched = ExponentialDecay::new(0.9);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(1, 1.0), 0.9, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(2, 1.0), 0.81, epsilon = 1e-12);
    }

    #[test]
    fn test_constant_lr() {
        let sched = ConstantLr;
        for epoch in 0..100 {
            assert_abs_diff_eq!(sched.get_lr(epoch, 0.01), 0.01, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_polynomial_decay() {
        let sched = PolynomialDecay::new(10, 2.0, 0.1);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 1.0, epsilon = 1e-12);
        // (1.0 - 0.1) * 0.5^2 + 0.1
        assert_abs_diff_eq!(sched.get_lr(5, 1.0), 0.325, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 1.0), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(50, 1.0), 0.1, epsilon = 1e-12);
    }

    #[test]
    fn test_cyclic_lr() {
        let sched = CyclicLr::new(0.001, 0.01, 5);
        // At epoch 0: base_lr
        let lr0 = sched.get_lr(0, 0.0);
        // At epoch 5: max_lr
        let lr5 = sched.get_lr(5, 0.0);
        assert_abs_diff_eq!(lr5, sched.max_lr, epsilon = 1e-10);
        // At epoch 10: back to base_lr
        let lr10 = sched.get_lr(10, 0.0);
        assert_abs_diff_eq!(lr10, sched.base_lr, epsilon = 1e-10);
        assert!(lr5 > lr0, "peak should exceed start");
    }
}