//! Learning-rate schedulers: strategies for decaying or warming up the
//! learning rate over the course of training.

use std::f32::consts::PI;

/// Strategies for adjusting the learning rate as training progresses.
#[derive(Debug, Clone)]
pub enum SchedulerType {
    /// Keeps the learning rate fixed at the base value.
    Constant,

    /// Multiplies the learning rate by `gamma` every `step_size` steps.
    StepDecay { step_size: usize, gamma: f32 },

    /// Multiplies the learning rate by `gamma` after every step.
    Exponential { gamma: f32 },

    /// Anneals the learning rate from the base value down to `eta_min`
    /// along a cosine curve, restarting every `t_max` steps.
    CosineAnnealing { t_max: usize, eta_min: f32 },

    /// Ramps the learning rate linearly from zero to the base value over
    /// `warmup_steps`, then decays it linearly to zero at `total_steps`.
    WarmupLinear {
        warmup_steps: usize,
        total_steps: usize,
    },

    /// Multiplies the learning rate by `factor` (clamped at `min_lr`)
    /// whenever the monitored metric fails to improve for `patience` steps.
    ReduceOnPlateau {
        factor: f32,
        patience: usize,
        min_lr: f32,
    },
}

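/// Stateful learning-rate scheduler driven by a [`SchedulerType`].
///
/// A minimal usage sketch (the surrounding training loop and optimizer are
/// illustrative assumptions, not part of this module):
///
/// ```ignore
/// let mut scheduler = LearningRateScheduler::new(
///     SchedulerType::StepDecay { step_size: 10, gamma: 0.5 },
///     0.1,
/// );
/// for _epoch in 0..30 {
///     let lr = scheduler.step();
///     // apply `lr` to the optimizer for this epoch
/// }
/// ```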
#[derive(Debug, Clone)]
pub struct LearningRateScheduler {
    scheduler_type: SchedulerType,
    base_lr: f32,
    current_lr: f32,
    step_count: usize,
    best_metric: f32,
    patience_counter: usize,
}

impl LearningRateScheduler {
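    /// Creates a scheduler of the given type, starting at `base_lr`.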
    pub fn new(scheduler_type: SchedulerType, base_lr: f32) -> Self {
        Self {
            scheduler_type,
            base_lr,
            current_lr: base_lr,
            step_count: 0,
            best_metric: f32::INFINITY,
            patience_counter: 0,
        }
    }

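    /// Advances the step counter and returns the updated learning rate.
    ///
    /// For `ReduceOnPlateau`, use [`step_with_metric`](Self::step_with_metric)
    /// instead; plain `step` leaves its learning rate unchanged.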
    pub fn step(&mut self) -> f32 {
        self.step_count += 1;
        self.current_lr = self.calculate_lr();
        self.current_lr
    }

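    /// Advances the scheduler using a validation metric (lower is better).
    ///
    /// Only `ReduceOnPlateau` consults the metric; every other scheduler
    /// type falls back to its step-based schedule.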
    pub fn step_with_metric(&mut self, metric: f32) -> f32 {
        self.step_count += 1;

        match &self.scheduler_type {
            SchedulerType::ReduceOnPlateau {
                factor,
                patience,
                min_lr,
            } => {
                if metric < self.best_metric - 1e-8 {
                    // Metric improved: record it and reset the patience counter.
                    self.best_metric = metric;
                    self.patience_counter = 0;
                } else {
                    self.patience_counter += 1;

                    if self.patience_counter >= *patience {
                        // Plateau detected: decay the LR, clamped at min_lr.
                        self.current_lr = (self.current_lr * factor).max(*min_lr);
                        self.patience_counter = 0;
                    }
                }
            }
            _ => {
                self.current_lr = self.calculate_lr();
            }
        }

        self.current_lr
    }

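    /// Returns the current learning rate without advancing the scheduler.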
    pub fn get_lr(&self) -> f32 {
        self.current_lr
    }

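    /// Restores the scheduler to its initial state (base LR, step count 0).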
    pub fn reset(&mut self) {
        self.current_lr = self.base_lr;
        self.step_count = 0;
        self.best_metric = f32::INFINITY;
        self.patience_counter = 0;
    }

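    /// Computes the learning rate for the current step count.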
    fn calculate_lr(&self) -> f32 {
        match &self.scheduler_type {
            SchedulerType::Constant => self.base_lr,

            SchedulerType::StepDecay { step_size, gamma } => {
                let decay_factor = (*gamma).powi((self.step_count / step_size) as i32);
                self.base_lr * decay_factor
            }

            SchedulerType::Exponential { gamma } => {
                let decay_factor = (*gamma).powi(self.step_count as i32);
                self.base_lr * decay_factor
            }

            SchedulerType::CosineAnnealing { t_max, eta_min } => {
                // eta_min + (base_lr - eta_min) * (1 + cos(pi * t / t_max)) / 2,
                // where t restarts at 0 every t_max steps.
                let cycle_step = self.step_count % t_max;
                let cos_term = (PI * cycle_step as f32 / *t_max as f32).cos();
                eta_min + 0.5 * (self.base_lr - eta_min) * (1.0 + cos_term)
            }

            SchedulerType::WarmupLinear {
                warmup_steps,
                total_steps,
            } => {
                if self.step_count < *warmup_steps {
                    // Linear ramp from 0 up to base_lr.
                    self.base_lr * (self.step_count as f32 / *warmup_steps as f32)
                } else if self.step_count < *total_steps {
                    // Linear decay from base_lr down to 0.
                    let remaining_steps = *total_steps - self.step_count;
                    let total_decay_steps = *total_steps - *warmup_steps;
                    self.base_lr * (remaining_steps as f32 / total_decay_steps as f32)
                } else {
                    0.0
                }
            }

            SchedulerType::ReduceOnPlateau { .. } => {
                // Metric-driven; updated in step_with_metric instead.
                self.current_lr
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-6;

    fn assert_close(a: f32, b: f32, msg: &str) {
        assert!((a - b).abs() < EPSILON, "{}: {} != {}", msg, a, b);
    }

    #[test]
    fn test_constant_scheduler() {
        let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.01);

        assert_close(scheduler.get_lr(), 0.01, "Initial LR");

        for i in 1..=10 {
            let lr = scheduler.step();
            assert_close(lr, 0.01, &format!("Step {} LR", i));
        }
    }

    #[test]
    fn test_step_decay() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::StepDecay {
                step_size: 5,
                gamma: 0.5,
            },
            0.1,
        );

        assert_close(scheduler.get_lr(), 0.1, "Initial LR");

        for i in 1..=4 {
            let lr = scheduler.step();
            assert_close(lr, 0.1, &format!("Step {} LR", i));
        }

        let lr = scheduler.step();
        assert_close(lr, 0.05, "Step 5 LR (first decay)");

        for i in 6..=9 {
            let lr = scheduler.step();
            assert_close(lr, 0.05, &format!("Step {} LR", i));
        }

        let lr = scheduler.step();
        assert_close(lr, 0.025, "Step 10 LR (second decay)");
    }

    #[test]
    fn test_exponential_decay() {
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);

        assert_close(scheduler.get_lr(), 0.1, "Initial LR");

        let expected_lrs = vec![
            0.1 * 0.9,   // step 1
            0.1 * 0.81,  // step 2
            0.1 * 0.729, // step 3
        ];

        for (i, expected) in expected_lrs.iter().enumerate() {
            let lr = scheduler.step();
            assert_close(lr, *expected, &format!("Step {} LR", i + 1));
        }
    }

    #[test]
    fn test_cosine_annealing() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::CosineAnnealing {
                t_max: 10,
                eta_min: 0.0,
            },
            1.0,
        );

        assert_close(scheduler.get_lr(), 1.0, "Initial LR");

        for _ in 1..=5 {
            scheduler.step();
        }
        assert_close(scheduler.get_lr(), 0.5, "Mid-cycle LR (step 5)");

        for _ in 6..=9 {
            scheduler.step();
        }
        let lr_step9 = scheduler.get_lr();
        assert!(
            lr_step9 < 0.1,
            "Near end of cycle LR (step 9) should be small: {}",
            lr_step9
        );

        scheduler.step();
        assert_close(
            scheduler.get_lr(),
            1.0,
            "Restart at step 10 (cycle_step = 0)",
        );

        scheduler.step();
        assert!(
            scheduler.get_lr() < 1.0,
            "Step 11 should be less than base LR"
        );
    }

    #[test]
    fn test_warmup_linear() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::WarmupLinear {
                warmup_steps: 5,
                total_steps: 10,
            },
            1.0,
        );

        assert_close(scheduler.get_lr(), 1.0, "Initial LR");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.2, "Step 1 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.4, "Step 2 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.6, "Step 3 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.8, "Step 4 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 1.0, "Step 5 (warmup end)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.8, "Step 6 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.6, "Step 7 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.4, "Step 8 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.2, "Step 9 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Step 10 (decay end)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Step 11 (after total)");
    }

    #[test]
    fn test_reduce_on_plateau() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::ReduceOnPlateau {
                factor: 0.5,
                patience: 3,
                min_lr: 0.0001,
            },
            0.01,
        );

        assert_close(scheduler.get_lr(), 0.01, "Initial LR");

        scheduler.step_with_metric(1.0);
        assert_close(
            scheduler.get_lr(),
            0.01,
            "Step 1 (first metric, sets baseline)",
        );

        scheduler.step_with_metric(0.9);
        assert_close(scheduler.get_lr(), 0.01, "Step 2 (improving)");

        scheduler.step_with_metric(0.91);
        assert_close(scheduler.get_lr(), 0.01, "Step 3 (plateau 1)");

        scheduler.step_with_metric(0.92);
        assert_close(scheduler.get_lr(), 0.01, "Step 4 (plateau 2)");

        scheduler.step_with_metric(0.93);
        assert_close(
            scheduler.get_lr(),
            0.005,
            "Step 5 (patience exceeded, reduced)",
        );

        scheduler.step_with_metric(0.94);
        assert_close(scheduler.get_lr(), 0.005, "Step 6 (plateau 1 after reset)");

        scheduler.step_with_metric(0.95);
        assert_close(scheduler.get_lr(), 0.005, "Step 7 (plateau 2)");

        scheduler.step_with_metric(0.96);
        assert_close(scheduler.get_lr(), 0.0025, "Step 8 (reduced again)");

        for _ in 0..20 {
            scheduler.step_with_metric(1.0);
        }
        assert!(
            scheduler.get_lr() >= 0.0001,
            "LR should not go below min_lr"
        );
    }

    #[test]
    fn test_scheduler_reset() {
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);

        for _ in 0..5 {
            scheduler.step();
        }
        assert!(scheduler.get_lr() < 0.1, "LR should have decayed");

        scheduler.reset();
        assert_close(scheduler.get_lr(), 0.1, "Reset LR");
        assert_eq!(scheduler.step_count, 0, "Reset step count");
    }

    #[test]
    fn test_scheduler_cloning() {
        let scheduler1 = LearningRateScheduler::new(
            SchedulerType::StepDecay {
                step_size: 10,
                gamma: 0.5,
            },
            0.01,
        );

        let mut scheduler2 = scheduler1.clone();

        // Stepping the clone must not affect the original; step 1 is also
        // before the first decay, so both remain at the base LR.
        scheduler2.step();

        assert_close(scheduler1.get_lr(), 0.01, "Original LR");
        assert_close(scheduler2.get_lr(), 0.01, "Clone LR after step");
    }

    #[test]
    fn test_multiple_scheduler_types() {
        let schedulers = vec![
            (SchedulerType::Constant, 0.01),
            (
                SchedulerType::StepDecay {
                    step_size: 5,
                    gamma: 0.9,
                },
                0.01,
            ),
            (SchedulerType::Exponential { gamma: 0.95 }, 0.01),
            (
                SchedulerType::CosineAnnealing {
                    t_max: 10,
                    eta_min: 0.001,
                },
                0.01,
            ),
            (
                SchedulerType::WarmupLinear {
                    warmup_steps: 5,
                    total_steps: 20,
                },
                0.01,
            ),
            (
                SchedulerType::ReduceOnPlateau {
                    factor: 0.5,
                    patience: 5,
                    min_lr: 0.0001,
                },
                0.01,
            ),
        ];

        for (sched_type, base_lr) in schedulers {
            let mut scheduler = LearningRateScheduler::new(sched_type, base_lr);

            assert_close(scheduler.get_lr(), base_lr, "Initial LR for scheduler type");

            let _ = scheduler.step();
            assert!(scheduler.get_lr() >= 0.0, "LR should be non-negative");
        }
    }

    #[test]
    fn test_edge_cases() {
        // A base LR of zero stays zero under a constant schedule.
        let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.0);
        assert_close(scheduler.get_lr(), 0.0, "Zero LR");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Zero LR after step");

        // Aggressive exponential decay drives the LR toward zero
        // without ever reaching it.
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.1 }, 1.0);
        for _ in 0..10 {
            scheduler.step();
        }
        assert!(scheduler.get_lr() > 0.0, "LR should remain positive");
        assert!(scheduler.get_lr() < 1e-8, "LR should be very small");
    }
}