//! A collection of epoch-based learning-rate schedulers sharing the
//! `LearningRateScheduler` trait.

use std::f64::consts::PI;

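/// Common interface for the schedulers in this module. Implementations map
/// an epoch index and a base learning rate to the rate to use that epoch;
/// stateful schedulers (the warm-restart and plateau variants) can be
/// returned to their initial state with `reset`.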
pub trait LearningRateScheduler {
    /// Returns the learning rate to use at `epoch`, given the optimizer's `base_lr`.
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64;

    /// Restores any internal state so the scheduler can be reused from epoch 0.
    fn reset(&mut self);

    /// A short, human-readable identifier for logging.
    fn name(&self) -> &'static str;
}

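/// Keeps the learning rate fixed at `base_lr` for every epoch.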
#[derive(Clone, Debug)]
pub struct ConstantLR;

impl LearningRateScheduler for ConstantLR {
    fn get_lr(&mut self, _epoch: usize, base_lr: f64) -> f64 {
        base_lr
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "ConstantLR"
    }
}

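/// Decays the learning rate by `gamma` every `step_size` epochs:
/// `lr = base_lr * gamma^(epoch / step_size)` (integer division).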
#[derive(Clone, Debug)]
pub struct StepLR {
    step_size: usize,
    gamma: f64,
}

impl StepLR {
    pub fn new(step_size: usize, gamma: f64) -> Self {
        StepLR { step_size, gamma }
    }
}

impl LearningRateScheduler for StepLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        let steps = epoch / self.step_size;
        base_lr * self.gamma.powi(steps as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "StepLR"
    }
}

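/// Decays the learning rate by `gamma` at each milestone:
/// `lr = base_lr * gamma^k`, where `k` counts the milestones that are
/// less than or equal to the current epoch.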
#[derive(Clone, Debug)]
pub struct MultiStepLR {
    milestones: Vec<usize>,
    gamma: f64,
}

impl MultiStepLR {
    pub fn new(milestones: Vec<usize>, gamma: f64) -> Self {
        MultiStepLR { milestones, gamma }
    }
}

impl LearningRateScheduler for MultiStepLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        let num_reductions = self.milestones.iter()
            .filter(|&&milestone| epoch >= milestone)
            .count();
        base_lr * self.gamma.powi(num_reductions as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "MultiStepLR"
    }
}

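/// Decays the learning rate every epoch: `lr = base_lr * gamma^epoch`.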
#[derive(Clone, Debug)]
pub struct ExponentialLR {
    gamma: f64,
}

impl ExponentialLR {
    pub fn new(gamma: f64) -> Self {
        ExponentialLR { gamma }
    }
}

impl LearningRateScheduler for ExponentialLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        base_lr * self.gamma.powi(epoch as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "ExponentialLR"
    }
}

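/// Cosine annealing:
/// `lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / t_max)) / 2`
/// with `t = epoch % t_max`. Because of the modulo, the schedule jumps back
/// to `base_lr` every `t_max` epochs instead of settling at `eta_min`.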
#[derive(Clone, Debug)]
pub struct CosineAnnealingLR {
    t_max: usize,
    eta_min: f64,
    last_epoch: usize,
}

impl CosineAnnealingLR {
    pub fn new(t_max: usize, eta_min: f64) -> Self {
        CosineAnnealingLR {
            t_max,
            eta_min,
            last_epoch: 0,
        }
    }
}

impl LearningRateScheduler for CosineAnnealingLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        self.last_epoch = epoch;
        if epoch == 0 {
            return base_lr;
        }

        // Wrap the epoch so the schedule restarts every `t_max` epochs.
        let t = epoch % self.t_max;
        self.eta_min + (base_lr - self.eta_min) *
            (1.0 + (PI * t as f64 / self.t_max as f64).cos()) / 2.0
    }

    fn reset(&mut self) {
        self.last_epoch = 0;
    }

    fn name(&self) -> &'static str {
        "CosineAnnealingLR"
    }
}

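/// Cosine annealing with warm restarts in the style of SGDR (Loshchilov &
/// Hutter, 2017): when the current period of `t_i` epochs ends, the rate
/// jumps back to `base_lr` and the next period is `t_mult` times longer.
/// The restart bookkeeping assumes `get_lr` is called with non-decreasing
/// epoch values.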
#[derive(Clone, Debug)]
pub struct CosineAnnealingWarmRestarts {
    t_0: usize,
    t_mult: usize,
    eta_min: f64,
    last_restart: usize,
    restart_count: usize,
}

impl CosineAnnealingWarmRestarts {
    pub fn new(t_0: usize, t_mult: usize, eta_min: f64) -> Self {
        CosineAnnealingWarmRestarts {
            t_0,
            t_mult,
            eta_min,
            last_restart: 0,
            restart_count: 0,
        }
    }
}

impl LearningRateScheduler for CosineAnnealingWarmRestarts {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch == 0 {
            return base_lr;
        }

        // Position within the current period, whose length grows by a
        // factor of `t_mult` after every restart.
        let t_cur = epoch - self.last_restart;
        let t_i = self.t_0 * self.t_mult.pow(self.restart_count as u32);

        if t_cur >= t_i {
            // Period finished: restart from `base_lr`.
            self.last_restart = epoch;
            self.restart_count += 1;
            return base_lr;
        }

        self.eta_min + (base_lr - self.eta_min) *
            (1.0 + (PI * t_cur as f64 / t_i as f64).cos()) / 2.0
    }

    fn reset(&mut self) {
        self.last_restart = 0;
        self.restart_count = 0;
    }

    fn name(&self) -> &'static str {
        "CosineAnnealingWarmRestarts"
    }
}

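/// One-cycle policy: the rate warms up linearly from `max_lr / div_factor`
/// to `max_lr` over the first `pct_start` fraction of `total_steps`, then
/// anneals down to `max_lr / final_div_factor`. Note that the `base_lr`
/// argument to `get_lr` is ignored; the cycle is defined entirely by
/// `max_lr` and the two divisors.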
#[derive(Clone, Debug)]
pub struct OneCycleLR {
    max_lr: f64,
    total_steps: usize,
    pct_start: f64,
    anneal_strategy: AnnealStrategy,
    div_factor: f64,
    final_div_factor: f64,
}

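/// How `OneCycleLR` anneals from `max_lr` down to the final rate.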
#[derive(Clone, Debug)]
pub enum AnnealStrategy {
    Cos,
    Linear,
}

impl OneCycleLR {
    pub fn new(max_lr: f64, total_steps: usize) -> Self {
        OneCycleLR {
            max_lr,
            total_steps,
            pct_start: 0.3,
            anneal_strategy: AnnealStrategy::Cos,
            div_factor: 25.0,
            final_div_factor: 10000.0,
        }
    }

    pub fn with_params(
        max_lr: f64,
        total_steps: usize,
        pct_start: f64,
        anneal_strategy: AnnealStrategy,
        div_factor: f64,
        final_div_factor: f64,
    ) -> Self {
        OneCycleLR {
            max_lr,
            total_steps,
            pct_start,
            anneal_strategy,
            div_factor,
            final_div_factor,
        }
    }
}

impl LearningRateScheduler for OneCycleLR {
    fn get_lr(&mut self, epoch: usize, _base_lr: f64) -> f64 {
        if epoch >= self.total_steps {
            return self.max_lr / self.final_div_factor;
        }

        let warmup_steps = (self.total_steps as f64 * self.pct_start) as usize;

        if epoch < warmup_steps {
            // Linear warmup from max_lr / div_factor up to max_lr.
            let warmup_ratio = epoch as f64 / warmup_steps as f64;
            (self.max_lr / self.div_factor) +
                (self.max_lr - self.max_lr / self.div_factor) * warmup_ratio
        } else {
            // Anneal from max_lr down to max_lr / final_div_factor.
            let anneal_ratio = (epoch - warmup_steps) as f64 /
                (self.total_steps - warmup_steps) as f64;

            match self.anneal_strategy {
                AnnealStrategy::Cos => {
                    let cos_factor = (1.0 + (PI * anneal_ratio).cos()) / 2.0;
                    (self.max_lr / self.final_div_factor) +
                        (self.max_lr - self.max_lr / self.final_div_factor) * cos_factor
                },
                AnnealStrategy::Linear => {
                    self.max_lr - (self.max_lr - self.max_lr / self.final_div_factor) * anneal_ratio
                }
            }
        }
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "OneCycleLR"
    }
}

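/// Reduces the learning rate by `factor` once a monitored validation loss
/// has failed to improve by more than `threshold` for `patience` calls,
/// then ignores the next `cooldown` calls. Unlike the epoch-driven
/// schedulers, this one is driven through `step(val_loss, base_lr)`;
/// `get_lr` merely reports the most recently computed rate.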
#[derive(Clone, Debug)]
pub struct ReduceLROnPlateau {
    factor: f64,
    patience: usize,
    threshold: f64,
    cooldown: usize,
    min_lr: f64,
    best_loss: f64,
    wait_count: usize,
    cooldown_counter: usize,
    current_lr: f64,
}

impl ReduceLROnPlateau {
    pub fn new(factor: f64, patience: usize) -> Self {
        ReduceLROnPlateau {
            factor,
            patience,
            threshold: 1e-4,
            cooldown: 0,
            min_lr: 0.0,
            best_loss: f64::INFINITY,
            wait_count: 0,
            cooldown_counter: 0,
            current_lr: 0.0,
        }
    }

    pub fn with_params(
        factor: f64,
        patience: usize,
        threshold: f64,
        cooldown: usize,
        min_lr: f64,
    ) -> Self {
        ReduceLROnPlateau {
            factor,
            patience,
            threshold,
            cooldown,
            min_lr,
            best_loss: f64::INFINITY,
            wait_count: 0,
            cooldown_counter: 0,
            current_lr: 0.0,
        }
    }

    pub fn step(&mut self, val_loss: f64, base_lr: f64) -> f64 {
        // `current_lr == 0.0` doubles as the "not yet initialized" sentinel.
        if self.current_lr == 0.0 {
            self.current_lr = base_lr;
        }

        // During cooldown, neither improvements nor stalls are counted.
        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return self.current_lr;
        }

        if val_loss < self.best_loss - self.threshold {
            // Loss improved by more than the threshold.
            self.best_loss = val_loss;
            self.wait_count = 0;
        } else {
            self.wait_count += 1;

            if self.wait_count >= self.patience {
                let new_lr = self.current_lr * self.factor;
                self.current_lr = new_lr.max(self.min_lr);
                self.wait_count = 0;
                self.cooldown_counter = self.cooldown;
                println!("ReduceLROnPlateau: reducing learning rate to {:.2e}", self.current_lr);
            }
        }

        self.current_lr
    }
}

impl LearningRateScheduler for ReduceLROnPlateau {
    fn get_lr(&mut self, _epoch: usize, base_lr: f64) -> f64 {
        if self.current_lr == 0.0 {
            self.current_lr = base_lr;
        }
        self.current_lr
    }

    fn reset(&mut self) {
        self.best_loss = f64::INFINITY;
        self.wait_count = 0;
        self.cooldown_counter = 0;
        self.current_lr = 0.0;
    }

    fn name(&self) -> &'static str {
        "ReduceLROnPlateau"
    }
}

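/// Scales `base_lr` by a factor interpolated linearly from `start_factor`
/// to `end_factor` over `total_iters` epochs, holding `end_factor`
/// thereafter.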
#[derive(Clone, Debug)]
pub struct LinearLR {
    start_factor: f64,
    end_factor: f64,
    total_iters: usize,
}

impl LinearLR {
    pub fn new(start_factor: f64, end_factor: f64, total_iters: usize) -> Self {
        LinearLR {
            start_factor,
            end_factor,
            total_iters,
        }
    }
}

impl LearningRateScheduler for LinearLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch >= self.total_iters {
            return base_lr * self.end_factor;
        }

        let progress = epoch as f64 / self.total_iters as f64;
        let factor = self.start_factor +
            (self.end_factor - self.start_factor) * progress;

        base_lr * factor
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "LinearLR"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_lr() {
        let mut scheduler = ConstantLR;
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(10, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(100, base_lr), base_lr);
    }

    #[test]
    fn test_step_lr() {
        let mut scheduler = StepLR::new(10, 0.1);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(9, base_lr), base_lr);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(20, base_lr) - base_lr * 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_exponential_lr() {
        let mut scheduler = ExponentialLR::new(0.9);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(1, base_lr) - base_lr * 0.9).abs() < 1e-10);
        assert!((scheduler.get_lr(2, base_lr) - base_lr * 0.81).abs() < 1e-10);
    }

    #[test]
    fn test_multi_step_lr() {
        let mut scheduler = MultiStepLR::new(vec![10, 20], 0.1);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(5, base_lr), base_lr);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(15, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(20, base_lr) - base_lr * 0.01).abs() < 1e-15);
    }

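    // The two cosine schedulers had no coverage; these tests are a minimal
    // sketch of the behavior implied by their formulas, using t_max = 10
    // (and t_mult = 2) with eta_min = 0.
    #[test]
    fn test_cosine_annealing_lr() {
        let mut scheduler = CosineAnnealingLR::new(10, 0.0);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        // Halfway through the period the cosine factor is 1/2.
        assert!((scheduler.get_lr(5, base_lr) - base_lr / 2.0).abs() < 1e-12);
        // The modulo in get_lr restarts the schedule every t_max epochs.
        assert!((scheduler.get_lr(10, base_lr) - base_lr).abs() < 1e-12);
    }

    #[test]
    fn test_cosine_annealing_warm_restarts() {
        let mut scheduler = CosineAnnealingWarmRestarts::new(10, 2, 0.0);
        let base_lr = 0.1;

        // The scheduler is stateful, so epochs are visited in order.
        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        let lr_5 = scheduler.get_lr(5, base_lr);
        assert!(lr_5 < base_lr);
        // The first period ends after t_0 = 10 epochs: the rate restarts.
        assert_eq!(scheduler.get_lr(10, base_lr), base_lr);
        // The second period is t_0 * t_mult = 20 epochs, so epoch 15 sits a
        // quarter of the way in and has decayed less than epoch 5 had.
        let lr_15 = scheduler.get_lr(15, base_lr);
        assert!(lr_5 < lr_15 && lr_15 < base_lr);
    }
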
    #[test]
    fn test_one_cycle_lr() {
        let mut scheduler = OneCycleLR::new(0.1, 100);
        let base_lr = 0.01;

        let lr_0 = scheduler.get_lr(0, base_lr);
        let lr_30 = scheduler.get_lr(30, base_lr);
        let lr_100 = scheduler.get_lr(100, base_lr);

        assert!(lr_0 < lr_30);
        assert!(lr_100 < lr_0);
        assert!(lr_30 <= 0.1);
    }

    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut scheduler = ReduceLROnPlateau::new(0.5, 2);
        let base_lr = 0.01;

        let lr1 = scheduler.step(1.0, base_lr);
        assert_eq!(lr1, base_lr);

        // Loss improves, so the rate is unchanged.
        let lr2 = scheduler.step(0.8, base_lr);
        assert_eq!(lr2, base_lr);

        // Two epochs without improvement exhaust the patience (2) and halve
        // the rate; the third call observes the reduced rate.
        scheduler.step(0.9, base_lr);
        scheduler.step(0.9, base_lr);
        let lr5 = scheduler.step(0.9, base_lr);

        assert!(lr5 < base_lr);
        assert!((lr5 - base_lr * 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_linear_lr() {
        let mut scheduler = LinearLR::new(1.0, 0.1, 10);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(5, base_lr) - base_lr * 0.55).abs() < 1e-10);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-10);
    }
}