Skip to main content

tiny_recursive_rs/training/
scheduler.rs

//! Cosine learning rate scheduler with linear warmup.
2use std::f64::consts::PI;
3
/// Configuration for a cosine-annealing learning rate schedule with linear warmup.
///
/// The schedule ramps linearly from 0 up to `lr_init` during the first
/// `warmup_steps` steps, then follows a half cosine from `lr_init` down to
/// `lr_min` over the steps that remain until `total_steps`.
#[derive(Debug, Clone, PartialEq)]
pub struct CosineSchedulerConfig {
    /// Peak learning rate: reached when warmup completes and used as the
    /// starting point of the cosine decay.
    pub lr_init: f64,
    /// Minimum learning rate (the value at the end of the schedule).
    pub lr_min: f64,
    /// Number of linear warmup steps; `0` disables warmup.
    pub warmup_steps: usize,
    /// Total number of training steps, warmup included.
    pub total_steps: usize,
}
16
17impl Default for CosineSchedulerConfig {
18    fn default() -> Self {
19        Self {
20            lr_init: 1e-3,
21            lr_min: 1e-5,
22            warmup_steps: 1000,
23            total_steps: 100000,
24        }
25    }
26}
27
28/// Cosine learning rate scheduler
29///
30/// Implements cosine annealing with linear warmup:
31/// - Linear warmup from 0 to lr_init over warmup_steps
32/// - Cosine annealing from lr_init to lr_min over remaining steps
33pub struct CosineScheduler {
34    config: CosineSchedulerConfig,
35    current_step: usize,
36}
37
38impl CosineScheduler {
39    /// Create new cosine scheduler
40    pub fn new(config: CosineSchedulerConfig) -> Self {
41        Self {
42            config,
43            current_step: 0,
44        }
45    }
46
47    /// Get learning rate for current step
48    pub fn get_lr(&self) -> f64 {
49        self.get_lr_at_step(self.current_step)
50    }
51
52    /// Get learning rate for a specific step
53    pub fn get_lr_at_step(&self, step: usize) -> f64 {
54        if step < self.config.warmup_steps {
55            // Linear warmup: lr = lr_init * (step / warmup_steps)
56            self.config.lr_init * (step as f64 / self.config.warmup_steps as f64)
57        } else {
58            // Cosine annealing
59            let progress = (step - self.config.warmup_steps) as f64
60                / (self.config.total_steps - self.config.warmup_steps) as f64;
61
62            // Clamp progress to [0, 1]
63            let progress = progress.min(1.0).max(0.0);
64
65            // Cosine annealing formula:
66            // lr = lr_min + (lr_init - lr_min) * 0.5 * (1 + cos(π * progress))
67            let cosine_factor = 0.5 * (1.0 + (PI * progress).cos());
68            self.config.lr_min + (self.config.lr_init - self.config.lr_min) * cosine_factor
69        }
70    }
71
72    /// Step the scheduler (increment step counter)
73    pub fn step(&mut self) {
74        self.current_step += 1;
75    }
76
77    /// Get current step
78    pub fn get_step(&self) -> usize {
79        self.current_step
80    }
81
82    /// Reset scheduler to initial state
83    pub fn reset(&mut self) {
84        self.current_step = 0;
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// The warmup phase must ramp linearly from 0 up to `lr_init`.
    #[test]
    fn test_warmup_phase() {
        let config = CosineSchedulerConfig {
            lr_init: 1.0,
            lr_min: 0.0,
            warmup_steps: 100,
            total_steps: 1000,
        };

        let scheduler = CosineScheduler::new(config);

        // At step 0, lr should be 0
        assert!((scheduler.get_lr_at_step(0) - 0.0).abs() < 1e-6);

        // At step 50 (halfway through warmup), lr should be 0.5
        assert!((scheduler.get_lr_at_step(50) - 0.5).abs() < 1e-6);

        // At step 100 (end of warmup), lr should be 1.0
        assert!((scheduler.get_lr_at_step(100) - 1.0).abs() < 1e-6);
    }

    /// Without warmup the schedule is a pure half cosine from `lr_init` to `lr_min`.
    #[test]
    fn test_cosine_annealing() {
        let config = CosineSchedulerConfig {
            lr_init: 1.0,
            lr_min: 0.0,
            warmup_steps: 0,
            total_steps: 1000,
        };

        let scheduler = CosineScheduler::new(config);

        // At step 0, lr should be lr_init
        assert!((scheduler.get_lr_at_step(0) - 1.0).abs() < 1e-6);

        // At step 500 (halfway), cos(π/2) = 0, so lr is exactly the midpoint 0.5.
        // A tight tolerance is used so a wrong curve shape cannot slip through.
        let lr_mid = scheduler.get_lr_at_step(500);
        assert!((lr_mid - 0.5).abs() < 1e-6);

        // At step 1000 (end), lr should be lr_min
        assert!((scheduler.get_lr_at_step(1000) - 0.0).abs() < 1e-6);

        // Past the end of the schedule, lr stays at lr_min
        assert!((scheduler.get_lr_at_step(2000) - 0.0).abs() < 1e-6);
    }

    /// `step()` advances the counter and `get_lr()` follows it.
    #[test]
    fn test_scheduler_stepping() {
        let config = CosineSchedulerConfig {
            lr_init: 1.0,
            lr_min: 0.1,
            warmup_steps: 10,
            total_steps: 100,
        };

        let mut scheduler = CosineScheduler::new(config);

        assert_eq!(scheduler.get_step(), 0);

        scheduler.step();
        assert_eq!(scheduler.get_step(), 1);

        scheduler.step();
        assert_eq!(scheduler.get_step(), 2);

        // get_lr() must report the rate for the current step (2 of 10 warmup steps)
        assert_eq!(scheduler.get_lr(), scheduler.get_lr_at_step(2));
        assert!((scheduler.get_lr() - 0.2).abs() < 1e-6);

        // LR should be increasing during warmup
        let lr1 = scheduler.get_lr_at_step(5);
        let lr2 = scheduler.get_lr_at_step(8);
        assert!(lr2 > lr1);
    }

    /// `reset()` returns the counter, and therefore the rate, to the start.
    #[test]
    fn test_reset() {
        let config = CosineSchedulerConfig::default();
        let mut scheduler = CosineScheduler::new(config);

        scheduler.step();
        scheduler.step();
        assert_eq!(scheduler.get_step(), 2);

        scheduler.reset();
        assert_eq!(scheduler.get_step(), 0);

        // Back at step 0 of the warmup, the rate is 0 again
        assert!((scheduler.get_lr() - 0.0).abs() < 1e-12);
    }

    /// The rate stays within `[lr_min, lr_init]` once warmup is over and never
    /// exceeds `lr_init` at any step.
    #[test]
    fn test_lr_never_exceeds_init() {
        let config = CosineSchedulerConfig {
            lr_init: 1.0,
            lr_min: 0.1,
            warmup_steps: 100,
            total_steps: 1000,
        };

        let scheduler = CosineScheduler::new(config.clone());

        // Test that LR never exceeds lr_init
        for step in 0..=config.total_steps {
            let lr = scheduler.get_lr_at_step(step);
            assert!(
                lr <= config.lr_init + 1e-6,
                "LR {} exceeds max {} at step {}",
                lr,
                config.lr_init,
                step
            );
        }

        // After warmup, LR should be >= lr_min
        for step in config.warmup_steps..=config.total_steps {
            let lr = scheduler.get_lr_at_step(step);
            assert!(
                lr >= config.lr_min - 1e-6,
                "LR {} is below min {} at step {}",
                lr,
                config.lr_min,
                step
            );
        }
    }
}
196}