// tiny_recursive_rs/training/scheduler.rs
use std::f64::consts::PI;

/// Hyper-parameters for a learning-rate schedule with linear warmup followed
/// by cosine decay.
#[derive(Debug, Clone, PartialEq)]
pub struct CosineSchedulerConfig {
    /// Peak learning rate: the warmup ramp climbs towards it and the cosine
    /// decay starts from it.
    pub lr_init: f64,
    /// Learning rate reached once the cosine decay has finished.
    pub lr_min: f64,
    /// Number of initial steps during which the rate ramps up linearly from zero.
    pub warmup_steps: usize,
    /// Step index at which the cosine decay reaches `lr_min`.
    pub total_steps: usize,
}
16
17impl Default for CosineSchedulerConfig {
18 fn default() -> Self {
19 Self {
20 lr_init: 1e-3,
21 lr_min: 1e-5,
22 warmup_steps: 1000,
23 total_steps: 100000,
24 }
25 }
26}
27
28pub struct CosineScheduler {
34 config: CosineSchedulerConfig,
35 current_step: usize,
36}
37
38impl CosineScheduler {
39 pub fn new(config: CosineSchedulerConfig) -> Self {
41 Self {
42 config,
43 current_step: 0,
44 }
45 }
46
47 pub fn get_lr(&self) -> f64 {
49 self.get_lr_at_step(self.current_step)
50 }
51
52 pub fn get_lr_at_step(&self, step: usize) -> f64 {
54 if step < self.config.warmup_steps {
55 self.config.lr_init * (step as f64 / self.config.warmup_steps as f64)
57 } else {
58 let progress = (step - self.config.warmup_steps) as f64
60 / (self.config.total_steps - self.config.warmup_steps) as f64;
61
62 let progress = progress.min(1.0).max(0.0);
64
65 let cosine_factor = 0.5 * (1.0 + (PI * progress).cos());
68 self.config.lr_min + (self.config.lr_init - self.config.lr_min) * cosine_factor
69 }
70 }
71
72 pub fn step(&mut self) {
74 self.current_step += 1;
75 }
76
77 pub fn get_step(&self) -> usize {
79 self.current_step
80 }
81
82 pub fn reset(&mut self) {
84 self.current_step = 0;
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 1e-6;

    /// Asserts that `actual` lies within `EPS` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Builds a scheduler from the four schedule parameters.
    fn scheduler(
        lr_init: f64,
        lr_min: f64,
        warmup_steps: usize,
        total_steps: usize,
    ) -> CosineScheduler {
        CosineScheduler::new(CosineSchedulerConfig {
            lr_init,
            lr_min,
            warmup_steps,
            total_steps,
        })
    }

    #[test]
    fn test_warmup_phase() {
        let scheduler = scheduler(1.0, 0.0, 100, 1000);

        // Start of warmup: the rate is zero.
        assert_close(scheduler.get_lr_at_step(0), 0.0);
        // Halfway through warmup: half the peak rate.
        assert_close(scheduler.get_lr_at_step(50), 0.5);
        // End of warmup: the peak rate.
        assert_close(scheduler.get_lr_at_step(100), 1.0);
    }

    #[test]
    fn test_cosine_annealing() {
        let scheduler = scheduler(1.0, 0.0, 0, 1000);

        // Without warmup the schedule starts at the peak rate.
        assert_close(scheduler.get_lr_at_step(0), 1.0);

        // A quarter of the way in, a cosine curve is still well above the
        // 0.75 a linear decay would give.
        let quarter = 0.5 * (1.0 + std::f64::consts::FRAC_PI_4.cos());
        assert_close(scheduler.get_lr_at_step(250), quarter);

        // The midpoint of the cosine curve is halfway between peak and floor.
        assert_close(scheduler.get_lr_at_step(500), 0.5);

        // The decay ends at the floor.
        assert_close(scheduler.get_lr_at_step(1000), 0.0);
    }

    #[test]
    fn test_scheduler_stepping() {
        let mut scheduler = scheduler(1.0, 0.1, 10, 100);

        assert_eq!(scheduler.get_step(), 0);

        scheduler.step();
        assert_eq!(scheduler.get_step(), 1);

        scheduler.step();
        assert_eq!(scheduler.get_step(), 2);

        // `get_lr` must follow the internal step counter.
        assert_close(scheduler.get_lr(), scheduler.get_lr_at_step(2));
        assert_close(scheduler.get_lr(), 0.2);

        // During warmup the rate grows with the step index.
        let lr1 = scheduler.get_lr_at_step(5);
        let lr2 = scheduler.get_lr_at_step(8);
        assert!(lr2 > lr1);
    }

    #[test]
    fn test_reset() {
        let mut scheduler = CosineScheduler::new(CosineSchedulerConfig::default());

        scheduler.step();
        scheduler.step();
        assert_eq!(scheduler.get_step(), 2);

        scheduler.reset();
        assert_eq!(scheduler.get_step(), 0);
    }

    #[test]
    fn test_lr_never_exceeds_init() {
        let config = CosineSchedulerConfig {
            lr_init: 1.0,
            lr_min: 0.1,
            warmup_steps: 100,
            total_steps: 1000,
        };

        let scheduler = CosineScheduler::new(config.clone());

        // The rate never rises above the peak at any point of the schedule.
        for step in 0..=config.total_steps {
            let lr = scheduler.get_lr_at_step(step);
            assert!(
                lr <= config.lr_init + EPS,
                "LR {} exceeds max {} at step {}",
                lr,
                config.lr_init,
                step
            );
        }

        // After warmup the rate never drops below the floor.
        for step in config.warmup_steps..=config.total_steps {
            let lr = scheduler.get_lr_at_step(step);
            assert!(
                lr >= config.lr_min - EPS,
                "LR {} is below min {} at step {}",
                lr,
                config.lr_min,
                step
            );
        }
    }
}