// rust_lstm/schedulers.rs

1use std::f64::consts::PI;
2
/// Learning rate scheduler trait for adaptive learning rate adjustment during training.
pub trait LearningRateScheduler {
    /// Get the learning rate for the current epoch.
    ///
    /// `base_lr` is the optimizer's configured base rate; implementations derive the
    /// effective rate from it. Takes `&mut self` so stateful schedulers (e.g. warm
    /// restarts) can update internal counters here.
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64;
    
    /// Reset the scheduler state (useful for multiple training runs).
    fn reset(&mut self);
    
    /// Get the name of the scheduler for logging.
    fn name(&self) -> &'static str;
}
14
15/// Constant learning rate (no scheduling)
16#[derive(Clone, Debug)]
17pub struct ConstantLR;
18
19impl LearningRateScheduler for ConstantLR {
20    fn get_lr(&mut self, _epoch: usize, base_lr: f64) -> f64 {
21        base_lr
22    }
23    
24    fn reset(&mut self) {}
25    
26    fn name(&self) -> &'static str {
27        "ConstantLR"
28    }
29}
30
31/// Step decay scheduler: multiply LR by gamma every step_size epochs
32#[derive(Clone, Debug)]
33pub struct StepLR {
34    step_size: usize,
35    gamma: f64,
36}
37
38impl StepLR {
39    pub fn new(step_size: usize, gamma: f64) -> Self {
40        StepLR { step_size, gamma }
41    }
42}
43
44impl LearningRateScheduler for StepLR {
45    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
46        let steps = epoch / self.step_size;
47        base_lr * self.gamma.powi(steps as i32)
48    }
49    
50    fn reset(&mut self) {}
51    
52    fn name(&self) -> &'static str {
53        "StepLR"
54    }
55}
56
57/// Multi-step decay: multiply LR by gamma at specific milestones
58#[derive(Clone, Debug)]
59pub struct MultiStepLR {
60    milestones: Vec<usize>,
61    gamma: f64,
62}
63
64impl MultiStepLR {
65    pub fn new(milestones: Vec<usize>, gamma: f64) -> Self {
66        MultiStepLR { milestones, gamma }
67    }
68}
69
70impl LearningRateScheduler for MultiStepLR {
71    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
72        let num_reductions = self.milestones.iter()
73            .filter(|&&milestone| epoch >= milestone)
74            .count();
75        base_lr * self.gamma.powi(num_reductions as i32)
76    }
77    
78    fn reset(&mut self) {}
79    
80    fn name(&self) -> &'static str {
81        "MultiStepLR"
82    }
83}
84
85/// Exponential decay scheduler: multiply LR by gamma every epoch
86#[derive(Clone, Debug)]
87pub struct ExponentialLR {
88    gamma: f64,
89}
90
91impl ExponentialLR {
92    pub fn new(gamma: f64) -> Self {
93        ExponentialLR { gamma }
94    }
95}
96
97impl LearningRateScheduler for ExponentialLR {
98    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
99        base_lr * self.gamma.powi(epoch as i32)
100    }
101    
102    fn reset(&mut self) {}
103    
104    fn name(&self) -> &'static str {
105        "ExponentialLR"
106    }
107}
108
109/// Cosine annealing scheduler with warm restarts
110#[derive(Clone, Debug)]
111pub struct CosineAnnealingLR {
112    t_max: usize,
113    eta_min: f64,
114    last_epoch: usize,
115}
116
117impl CosineAnnealingLR {
118    pub fn new(t_max: usize, eta_min: f64) -> Self {
119        CosineAnnealingLR {
120            t_max,
121            eta_min,
122            last_epoch: 0,
123        }
124    }
125}
126
127impl LearningRateScheduler for CosineAnnealingLR {
128    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
129        self.last_epoch = epoch;
130        if epoch == 0 {
131            return base_lr;
132        }
133        
134        let t = epoch % self.t_max;
135        self.eta_min + (base_lr - self.eta_min) * 
136            (1.0 + (PI * t as f64 / self.t_max as f64).cos()) / 2.0
137    }
138    
139    fn reset(&mut self) {
140        self.last_epoch = 0;
141    }
142    
143    fn name(&self) -> &'static str {
144        "CosineAnnealingLR"
145    }
146}
147
148/// Cosine annealing with warm restarts
149#[derive(Clone, Debug)]
150pub struct CosineAnnealingWarmRestarts {
151    t_0: usize,
152    t_mult: usize,
153    eta_min: f64,
154    last_restart: usize,
155    restart_count: usize,
156}
157
158impl CosineAnnealingWarmRestarts {
159    pub fn new(t_0: usize, t_mult: usize, eta_min: f64) -> Self {
160        CosineAnnealingWarmRestarts {
161            t_0,
162            t_mult,
163            eta_min,
164            last_restart: 0,
165            restart_count: 0,
166        }
167    }
168}
169
170impl LearningRateScheduler for CosineAnnealingWarmRestarts {
171    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
172        if epoch == 0 {
173            return base_lr;
174        }
175        
176        let t_cur = epoch - self.last_restart;
177        let t_i = self.t_0 * self.t_mult.pow(self.restart_count as u32);
178        
179        if t_cur >= t_i {
180            self.last_restart = epoch;
181            self.restart_count += 1;
182            return base_lr;
183        }
184        
185        self.eta_min + (base_lr - self.eta_min) * 
186            (1.0 + (PI * t_cur as f64 / t_i as f64).cos()) / 2.0
187    }
188    
189    fn reset(&mut self) {
190        self.last_restart = 0;
191        self.restart_count = 0;
192    }
193    
194    fn name(&self) -> &'static str {
195        "CosineAnnealingWarmRestarts"
196    }
197}
198
/// One cycle learning rate policy (popular for modern deep learning)
#[derive(Clone, Debug)]
pub struct OneCycleLR {
    /// Peak learning rate, reached at the end of the warmup phase.
    max_lr: f64,
    /// Total number of scheduler steps in the cycle.
    total_steps: usize,
    /// Fraction of `total_steps` spent warming up (0.0–1.0).
    pct_start: f64,
    /// Shape of the post-warmup decay curve.
    anneal_strategy: AnnealStrategy,
    /// Initial LR is `max_lr / div_factor`.
    div_factor: f64,
    /// Final LR is `max_lr / final_div_factor`.
    final_div_factor: f64,
}

/// Annealing curve used by `OneCycleLR` after the warmup phase.
#[derive(Clone, Debug)]
pub enum AnnealStrategy {
    /// Cosine-shaped decay.
    Cos,
    /// Straight-line decay.
    Linear,
}
215
216impl OneCycleLR {
217    pub fn new(max_lr: f64, total_steps: usize) -> Self {
218        OneCycleLR {
219            max_lr,
220            total_steps,
221            pct_start: 0.3,
222            anneal_strategy: AnnealStrategy::Cos,
223            div_factor: 25.0,
224            final_div_factor: 10000.0,
225        }
226    }
227    
228    pub fn with_params(
229        max_lr: f64,
230        total_steps: usize,
231        pct_start: f64,
232        anneal_strategy: AnnealStrategy,
233        div_factor: f64,
234        final_div_factor: f64,
235    ) -> Self {
236        OneCycleLR {
237            max_lr,
238            total_steps,
239            pct_start,
240            anneal_strategy,
241            div_factor,
242            final_div_factor,
243        }
244    }
245}
246
247impl LearningRateScheduler for OneCycleLR {
248    fn get_lr(&mut self, epoch: usize, _base_lr: f64) -> f64 {
249        if epoch >= self.total_steps {
250            return self.max_lr / self.final_div_factor;
251        }
252        
253        let _step_ratio = epoch as f64 / self.total_steps as f64;
254        let warmup_steps = (self.total_steps as f64 * self.pct_start) as usize;
255        
256        if epoch < warmup_steps {
257            // Warmup phase
258            let warmup_ratio = epoch as f64 / warmup_steps as f64;
259            (self.max_lr / self.div_factor) + 
260                (self.max_lr - self.max_lr / self.div_factor) * warmup_ratio
261        } else {
262            // Annealing phase
263            let anneal_ratio = (epoch - warmup_steps) as f64 / 
264                (self.total_steps - warmup_steps) as f64;
265            
266            match self.anneal_strategy {
267                AnnealStrategy::Cos => {
268                    let cos_factor = (1.0 + (PI * anneal_ratio).cos()) / 2.0;
269                    (self.max_lr / self.final_div_factor) + 
270                        (self.max_lr - self.max_lr / self.final_div_factor) * cos_factor
271                },
272                AnnealStrategy::Linear => {
273                    self.max_lr - (self.max_lr - self.max_lr / self.final_div_factor) * anneal_ratio
274                }
275            }
276        }
277    }
278    
279    fn reset(&mut self) {}
280    
281    fn name(&self) -> &'static str {
282        "OneCycleLR"
283    }
284}
285
/// Reduce learning rate on plateau (when validation loss stops improving)
#[derive(Clone, Debug)]
pub struct ReduceLROnPlateau {
    // Multiplier applied to the LR on each reduction (e.g. 0.5 halves it).
    factor: f64,
    // Number of non-improving steps tolerated before reducing.
    patience: usize,
    // Minimum absolute improvement over best_loss that counts as progress.
    threshold: f64,
    // Steps to skip monitoring after a reduction.
    cooldown: usize,
    // Lower bound the LR is clamped to when reducing.
    min_lr: f64,
    // Best (lowest) validation loss observed so far.
    best_loss: f64,
    // Consecutive steps without sufficient improvement.
    wait_count: usize,
    // Remaining cooldown steps.
    cooldown_counter: usize,
    // Current LR; 0.0 doubles as the "not yet initialised" sentinel.
    current_lr: f64,
}
299
300impl ReduceLROnPlateau {
301    pub fn new(factor: f64, patience: usize) -> Self {
302        ReduceLROnPlateau {
303            factor,
304            patience,
305            threshold: 1e-4,
306            cooldown: 0,
307            min_lr: 0.0,
308            best_loss: f64::INFINITY,
309            wait_count: 0,
310            cooldown_counter: 0,
311            current_lr: 0.0,
312        }
313    }
314    
315    pub fn with_params(
316        factor: f64,
317        patience: usize,
318        threshold: f64,
319        cooldown: usize,
320        min_lr: f64,
321    ) -> Self {
322        ReduceLROnPlateau {
323            factor,
324            patience,
325            threshold,
326            cooldown,
327            min_lr,
328            best_loss: f64::INFINITY,
329            wait_count: 0,
330            cooldown_counter: 0,
331            current_lr: 0.0,
332        }
333    }
334    
335    /// Update the scheduler with the current validation loss
336    pub fn step(&mut self, val_loss: f64, base_lr: f64) -> f64 {
337        if self.current_lr == 0.0 {
338            self.current_lr = base_lr;
339        }
340        
341        if self.cooldown_counter > 0 {
342            self.cooldown_counter -= 1;
343            return self.current_lr;
344        }
345        
346        if val_loss < self.best_loss - self.threshold {
347            self.best_loss = val_loss;
348            self.wait_count = 0;
349        } else {
350            self.wait_count += 1;
351            
352            if self.wait_count >= self.patience {
353                let new_lr = self.current_lr * self.factor;
354                self.current_lr = new_lr.max(self.min_lr);
355                self.wait_count = 0;
356                self.cooldown_counter = self.cooldown;
357                println!("ReduceLROnPlateau: reducing learning rate to {:.2e}", self.current_lr);
358            }
359        }
360        
361        self.current_lr
362    }
363}
364
365impl LearningRateScheduler for ReduceLROnPlateau {
366    fn get_lr(&mut self, _epoch: usize, base_lr: f64) -> f64 {
367        if self.current_lr == 0.0 {
368            self.current_lr = base_lr;
369        }
370        self.current_lr
371    }
372    
373    fn reset(&mut self) {
374        self.best_loss = f64::INFINITY;
375        self.wait_count = 0;
376        self.cooldown_counter = 0;
377        self.current_lr = 0.0;
378    }
379    
380    fn name(&self) -> &'static str {
381        "ReduceLROnPlateau"
382    }
383}
384
385/// Linear learning rate schedule
386#[derive(Clone, Debug)]
387pub struct LinearLR {
388    start_factor: f64,
389    end_factor: f64,
390    total_iters: usize,
391}
392
393impl LinearLR {
394    pub fn new(start_factor: f64, end_factor: f64, total_iters: usize) -> Self {
395        LinearLR {
396            start_factor,
397            end_factor,
398            total_iters,
399        }
400    }
401}
402
403impl LearningRateScheduler for LinearLR {
404    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
405        if epoch >= self.total_iters {
406            return base_lr * self.end_factor;
407        }
408        
409        let progress = epoch as f64 / self.total_iters as f64;
410        let factor = self.start_factor + 
411            (self.end_factor - self.start_factor) * progress;
412        
413        base_lr * factor
414    }
415    
416    fn reset(&mut self) {}
417    
418    fn name(&self) -> &'static str {
419        "LinearLR"
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_lr() {
        let mut scheduler = ConstantLR;
        let base_lr = 0.01;
        
        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(10, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(100, base_lr), base_lr);
    }

    #[test]
    fn test_step_lr() {
        let mut scheduler = StepLR::new(10, 0.1);
        let base_lr = 0.01;
        
        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(9, base_lr), base_lr);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(20, base_lr) - base_lr * 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_exponential_lr() {
        let mut scheduler = ExponentialLR::new(0.9);
        let base_lr = 0.01;
        
        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(1, base_lr) - base_lr * 0.9).abs() < 1e-10);
        assert!((scheduler.get_lr(2, base_lr) - base_lr * 0.81).abs() < 1e-10);
    }

    #[test]
    fn test_multi_step_lr() {
        let mut scheduler = MultiStepLR::new(vec![10, 20], 0.1);
        let base_lr = 0.01;
        
        assert_eq!(scheduler.get_lr(5, base_lr), base_lr);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(15, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(20, base_lr) - base_lr * 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let mut scheduler = CosineAnnealingLR::new(10, 0.0);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        // Halfway through the cycle the cosine factor is exactly 1/2.
        assert!((scheduler.get_lr(5, base_lr) - base_lr * 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_one_cycle_lr() {
        let mut scheduler = OneCycleLR::new(0.1, 100);
        let base_lr = 0.01;
        
        let lr_0 = scheduler.get_lr(0, base_lr);
        let lr_30 = scheduler.get_lr(30, base_lr); // Should be close to max
        let lr_100 = scheduler.get_lr(100, base_lr); // Should be very small
        
        assert!(lr_0 < lr_30);
        assert!(lr_100 < lr_0);
        assert!(lr_30 <= 0.1);
    }

    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut scheduler = ReduceLROnPlateau::new(0.5, 2);
        let base_lr = 0.01;
        
        // Should not reduce initially
        let lr1 = scheduler.step(1.0, base_lr);
        assert_eq!(lr1, base_lr);
        
        // Should not reduce with improving loss
        let lr2 = scheduler.step(0.8, base_lr);
        assert_eq!(lr2, base_lr);
        
        // First stall: wait_count = 1, still below patience — no reduction.
        let lr3 = scheduler.step(0.9, base_lr);
        assert_eq!(lr3, base_lr);
        // Second stall hits patience — LR is halved here.
        let lr4 = scheduler.step(0.9, base_lr);
        assert!((lr4 - base_lr * 0.5).abs() < 1e-10);
        // Further stalls stay at the reduced rate (patience not yet re-hit).
        let lr5 = scheduler.step(0.9, base_lr);
        
        assert!(lr5 < base_lr);
        assert!((lr5 - base_lr * 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_linear_lr() {
        let mut scheduler = LinearLR::new(1.0, 0.1, 10);
        let base_lr = 0.01;
        
        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(5, base_lr) - base_lr * 0.55).abs() < 1e-10);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-10);
    }
}