/// Common interface for learning-rate schedules.
///
/// Implementors compute a learning rate from an explicit step index via
/// [`get_lr`](LRScheduler::get_lr) and may keep an internal counter advanced
/// by [`step`](LRScheduler::step).
pub trait LRScheduler: Send + Sync {
    /// Returns the learning rate for the given (0-based) `step`.
    fn get_lr(&self, step: usize) -> f32;

    /// Advances the scheduler's internal step counter by one.
    fn step(&mut self);
}
75
/// Linear warmup from 0 to `base_lr`, then linear decay to zero at `total_steps`.
#[derive(Debug)]
pub struct LinearScheduler {
    base_lr: f32,        // peak learning rate, reached when warmup completes
    warmup_steps: usize, // number of linear ramp-up steps
    total_steps: usize,  // step at which the decayed LR reaches zero
    current_step: usize, // internal counter advanced by `LRScheduler::step`
}
87
88impl LinearScheduler {
89 pub fn new(base_lr: f32, warmup_steps: usize, total_steps: usize) -> Self {
90 Self {
91 base_lr,
92 warmup_steps,
93 total_steps,
94 current_step: 0,
95 }
96 }
97}
98
99impl LRScheduler for LinearScheduler {
100 fn get_lr(&self, step: usize) -> f32 {
101 if step < self.warmup_steps {
102 self.base_lr * (step as f32) / (self.warmup_steps as f32)
103 } else {
104 let progress =
105 (step - self.warmup_steps) as f32 / (self.total_steps - self.warmup_steps) as f32;
106 self.base_lr * (1.0 - progress).max(0.0)
107 }
108 }
109
110 fn step(&mut self) {
111 self.current_step += 1;
112 }
113}
114
/// Linear warmup followed by cosine annealing from `base_lr` down to `min_lr`.
#[derive(Debug)]
pub struct CosineScheduler {
    base_lr: f32,        // peak learning rate, reached when warmup completes
    warmup_steps: usize, // number of linear ramp-up steps
    total_steps: usize,  // step at which the cosine curve bottoms out
    current_step: usize, // internal counter advanced by `LRScheduler::step`
    min_lr: f32,         // floor the annealed LR converges to
}
128
129impl CosineScheduler {
130 pub fn new(base_lr: f32, warmup_steps: usize, total_steps: usize, min_lr: f32) -> Self {
131 Self {
132 base_lr,
133 warmup_steps,
134 total_steps,
135 current_step: 0,
136 min_lr,
137 }
138 }
139}
140
141impl LRScheduler for CosineScheduler {
142 fn get_lr(&self, step: usize) -> f32 {
143 use std::f32::consts::PI;
144
145 if step < self.warmup_steps {
146 self.base_lr * (step as f32) / (self.warmup_steps as f32)
147 } else {
148 let progress =
149 (step - self.warmup_steps) as f32 / (self.total_steps - self.warmup_steps) as f32;
150 let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
151 self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
152 }
153 }
154
155 fn step(&mut self) {
156 self.current_step += 1;
157 }
158}
159
/// Linear warmup, then polynomial decay `(1 - progress)^power` from `base_lr`
/// down to `min_lr` at `total_steps`.
#[derive(Debug)]
pub struct PolynomialScheduler {
    base_lr: f32,        // peak learning rate, reached when warmup completes
    warmup_steps: usize, // number of linear ramp-up steps
    total_steps: usize,  // step at which the decay reaches `min_lr`
    current_step: usize, // internal counter advanced by `LRScheduler::step`
    min_lr: f32,         // floor the decayed LR converges to
    power: f32,          // decay exponent (1.0 = linear, 2.0 = quadratic, ...)
}
176
177impl PolynomialScheduler {
178 pub fn new(
179 base_lr: f32,
180 warmup_steps: usize,
181 total_steps: usize,
182 min_lr: f32,
183 power: f32,
184 ) -> Self {
185 Self {
186 base_lr,
187 warmup_steps,
188 total_steps,
189 current_step: 0,
190 min_lr,
191 power,
192 }
193 }
194}
195
196impl LRScheduler for PolynomialScheduler {
197 fn get_lr(&self, step: usize) -> f32 {
198 if step < self.warmup_steps {
199 self.base_lr * (step as f32) / (self.warmup_steps as f32)
200 } else {
201 let progress =
202 (step - self.warmup_steps) as f32 / (self.total_steps - self.warmup_steps) as f32;
203 let decay_factor = (1.0 - progress.min(1.0)).powf(self.power);
204 self.min_lr + (self.base_lr - self.min_lr) * decay_factor
205 }
206 }
207
208 fn step(&mut self) {
209 self.current_step += 1;
210 }
211}
212
/// Linear warmup to `base_lr`, then a constant learning rate forever after.
#[derive(Debug)]
pub struct ConstantWithWarmupScheduler {
    base_lr: f32,        // the constant rate held after warmup
    warmup_steps: usize, // number of linear ramp-up steps
    current_step: usize, // internal counter advanced by `LRScheduler::step`
}
220
221impl ConstantWithWarmupScheduler {
222 pub fn new(base_lr: f32, warmup_steps: usize) -> Self {
223 Self {
224 base_lr,
225 warmup_steps,
226 current_step: 0,
227 }
228 }
229}
230
231impl LRScheduler for ConstantWithWarmupScheduler {
232 fn get_lr(&self, step: usize) -> f32 {
233 if step < self.warmup_steps {
234 self.base_lr * (step as f32) / (self.warmup_steps as f32)
235 } else {
236 self.base_lr
237 }
238 }
239
240 fn step(&mut self) {
241 self.current_step += 1;
242 }
243}
244
/// Linear warmup, then staircase exponential decay: the LR is multiplied by
/// `decay_rate` once per completed `decay_steps` interval.
#[derive(Debug)]
pub struct ExponentialScheduler {
    base_lr: f32,        // peak learning rate, reached when warmup completes
    warmup_steps: usize, // number of linear ramp-up steps
    current_step: usize, // internal counter advanced by `LRScheduler::step`
    decay_rate: f32,     // multiplicative factor applied per interval
    decay_steps: usize,  // length of one decay interval, in steps
}
254
255impl ExponentialScheduler {
256 pub fn new(base_lr: f32, warmup_steps: usize, decay_rate: f32, decay_steps: usize) -> Self {
257 Self {
258 base_lr,
259 warmup_steps,
260 current_step: 0,
261 decay_rate,
262 decay_steps,
263 }
264 }
265}
266
267impl LRScheduler for ExponentialScheduler {
268 fn get_lr(&self, step: usize) -> f32 {
269 if step < self.warmup_steps {
270 self.base_lr * (step as f32) / (self.warmup_steps as f32)
271 } else {
272 let decay_step = (step - self.warmup_steps) / self.decay_steps;
273 self.base_lr * self.decay_rate.powf(decay_step as f32)
274 }
275 }
276
277 fn step(&mut self) {
278 self.current_step += 1;
279 }
280}
281
/// Linear warmup, then step decay: the LR is multiplied by `gamma` once per
/// completed `step_size` interval.
#[derive(Debug)]
pub struct StepScheduler {
    base_lr: f32,        // peak learning rate, reached when warmup completes
    warmup_steps: usize, // number of linear ramp-up steps
    current_step: usize, // internal counter advanced by `LRScheduler::step`
    step_size: usize,    // length of one decay interval, in steps
    gamma: f32,          // multiplicative factor applied per interval
}
291
292impl StepScheduler {
293 pub fn new(base_lr: f32, warmup_steps: usize, step_size: usize, gamma: f32) -> Self {
294 Self {
295 base_lr,
296 warmup_steps,
297 current_step: 0,
298 step_size,
299 gamma,
300 }
301 }
302}
303
304impl LRScheduler for StepScheduler {
305 fn get_lr(&self, step: usize) -> f32 {
306 if step < self.warmup_steps {
307 self.base_lr * (step as f32) / (self.warmup_steps as f32)
308 } else {
309 let decay_step = (step - self.warmup_steps) / self.step_size;
310 self.base_lr * self.gamma.powf(decay_step as f32)
311 }
312 }
313
314 fn step(&mut self) {
315 self.current_step += 1;
316 }
317}
318
/// One-cycle policy: the LR ramps from `final_lr` up to `max_lr` during the
/// first `pct_start` fraction of training, then anneals back to `final_lr`.
#[derive(Debug)]
pub struct OneCycleScheduler {
    max_lr: f32,         // peak LR at the top of the cycle
    final_lr: f32,       // starting and ending LR
    total_steps: usize,  // total length of the cycle, in steps
    pct_start: f32,      // fraction of the run spent ramping up, in [0, 1]
    current_step: usize, // internal counter advanced by `LRScheduler::step`
}
332
333impl OneCycleScheduler {
334 pub fn new(max_lr: f32, total_steps: usize, pct_start: f32, final_lr: f32) -> Self {
335 Self {
336 max_lr,
337 final_lr,
338 total_steps,
339 pct_start: pct_start.clamp(0.0, 1.0),
340 current_step: 0,
341 }
342 }
343}
344
345impl LRScheduler for OneCycleScheduler {
346 fn get_lr(&self, step: usize) -> f32 {
347 use std::f32::consts::PI;
348
349 let step = step.min(self.total_steps);
350 let pct = step as f32 / self.total_steps as f32;
351
352 if pct <= self.pct_start {
353 let phase_pct = pct / self.pct_start;
355 let cosine_term = 0.5 * (1.0 - (PI * phase_pct).cos());
356 self.final_lr + (self.max_lr - self.final_lr) * cosine_term
357 } else {
358 let remaining_pct = (pct - self.pct_start) / (1.0 - self.pct_start);
360 let cosine_term = 0.5 * (1.0 + (PI * remaining_pct).cos());
361 self.final_lr + (self.max_lr - self.final_lr) * cosine_term
362 }
363 }
364
365 fn step(&mut self) {
366 self.current_step += 1;
367 }
368}
369
/// Cosine annealing with warm restarts (SGDR): each cycle anneals from
/// `base_lr` down to `min_lr`, and cycle lengths grow by `t_mult`.
#[derive(Debug)]
pub struct CosineWithRestartsScheduler {
    base_lr: f32,         // LR at the start of every cycle
    min_lr: f32,          // LR at the end of every cycle
    t_0: usize,           // length of the first cycle, in steps
    t_mult: f32,          // growth factor applied to each subsequent cycle
    current_step: usize,  // internal counter advanced by `LRScheduler::step`
    next_restart: usize,  // step of the next restart (maintained by `step`;
                          // `get_lr` recomputes cycles statelessly instead)
    current_t: usize,     // current cycle length (maintained by `step`)
}
384
385impl CosineWithRestartsScheduler {
386 pub fn new(base_lr: f32, min_lr: f32, t_0: usize, t_mult: f32) -> Self {
387 Self {
388 base_lr,
389 min_lr,
390 t_0,
391 t_mult,
392 current_step: 0,
393 next_restart: t_0,
394 current_t: t_0,
395 }
396 }
397}
398
399impl LRScheduler for CosineWithRestartsScheduler {
400 fn get_lr(&self, step: usize) -> f32 {
401 use std::f32::consts::PI;
402
403 let mut step_in_cycle = step;
404 let mut cycle_length = self.t_0;
405
406 while step_in_cycle >= cycle_length {
408 step_in_cycle -= cycle_length;
409 cycle_length = (cycle_length as f32 * self.t_mult) as usize;
410 }
411
412 let progress = step_in_cycle as f32 / cycle_length as f32;
413 let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
414
415 self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
416 }
417
418 fn step(&mut self) {
419 self.current_step += 1;
420
421 if self.current_step >= self.next_restart {
422 self.current_t = (self.current_t as f32 * self.t_mult) as usize;
423 self.next_restart += self.current_t;
424 }
425 }
426}
427
/// Cyclical learning rate: a triangular oscillation between `base_lr` and
/// `max_lr`, with amplitude shaping controlled by `mode`.
#[derive(Debug)]
pub struct CyclicalScheduler {
    base_lr: f32,          // lower bound of the cycle
    max_lr: f32,           // upper bound of the cycle
    step_size_up: usize,   // steps spent rising per cycle
    step_size_down: usize, // steps spent falling per cycle
    current_step: usize,   // internal counter advanced by `LRScheduler::step`
    mode: CyclicalMode,    // amplitude scaling policy
}
441
/// Amplitude scaling policy for [`CyclicalScheduler`].
#[derive(Debug, Clone)]
pub enum CyclicalMode {
    /// Constant amplitude on every cycle.
    Triangular,
    /// Amplitude halves on each successive cycle.
    Triangular2,
    /// Amplitude scaled by `gamma^step` (per step, not per cycle).
    ExpRange(f32),
}
448
449impl CyclicalScheduler {
450 pub fn new(
451 base_lr: f32,
452 max_lr: f32,
453 step_size_up: usize,
454 step_size_down: usize,
455 mode: CyclicalMode,
456 ) -> Self {
457 Self {
458 base_lr,
459 max_lr,
460 step_size_up,
461 step_size_down,
462 current_step: 0,
463 mode,
464 }
465 }
466}
467
468impl LRScheduler for CyclicalScheduler {
469 fn get_lr(&self, step: usize) -> f32 {
470 let cycle_length = self.step_size_up + self.step_size_down;
471 let cycle = (step / cycle_length) + 1;
472 let x = (step % cycle_length) as f32;
473
474 let (amplitude, _phase) = if x <= self.step_size_up as f32 {
475 (x / self.step_size_up as f32, 1.0)
477 } else {
478 (
480 (self.step_size_down as f32 - (x - self.step_size_up as f32))
481 / self.step_size_down as f32,
482 1.0,
483 )
484 };
485
486 let scale_factor = match &self.mode {
487 CyclicalMode::Triangular => 1.0,
488 CyclicalMode::Triangular2 => 1.0 / (2.0_f32.powi((cycle - 1) as i32)),
489 CyclicalMode::ExpRange(gamma) => gamma.powi(step as i32),
490 };
491
492 self.base_lr + (self.max_lr - self.base_lr) * amplitude * scale_factor
493 }
494
495 fn step(&mut self) {
496 self.current_step += 1;
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_scheduler() {
        let scheduler = LinearScheduler::new(1e-3, 100, 1000);

        // Warmup ramps linearly from 0 to base_lr over 100 steps.
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Decay: halfway through the 900-step decay span, then zero at the end.
        assert_eq!(scheduler.get_lr(550), 5e-4);
        assert_eq!(scheduler.get_lr(1000), 0.0);
    }

    #[test]
    fn test_cosine_scheduler() {
        let scheduler = CosineScheduler::new(1e-3, 100, 1000, 1e-5);

        // Warmup behaves the same as the linear scheduler.
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Mid-decay LR lies strictly between min_lr and base_lr.
        let mid_lr = scheduler.get_lr(550);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);

        // At total_steps the cosine bottoms out at min_lr.
        let end_lr = scheduler.get_lr(1000);
        assert!((end_lr - 1e-5).abs() < 1e-6);
    }

    #[test]
    fn test_polynomial_scheduler() {
        let scheduler = PolynomialScheduler::new(1e-3, 100, 1000, 1e-5, 2.0);

        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Quadratic decay keeps the mid-run LR inside (min_lr, base_lr).
        let mid_lr = scheduler.get_lr(550);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);
    }

    #[test]
    fn test_constant_with_warmup_scheduler() {
        let scheduler = ConstantWithWarmupScheduler::new(1e-3, 100);

        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // After warmup the LR is held constant indefinitely.
        assert_eq!(scheduler.get_lr(200), 1e-3);
        assert_eq!(scheduler.get_lr(1000), 1e-3);
    }

    #[test]
    fn test_exponential_scheduler() {
        let scheduler = ExponentialScheduler::new(1e-3, 100, 0.9, 100);

        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // One completed interval past warmup multiplies by 0.9, two by 0.81.
        assert_eq!(scheduler.get_lr(200), 1e-3 * 0.9);
        assert_eq!(scheduler.get_lr(300), 1e-3 * 0.9 * 0.9);
    }

    #[test]
    fn test_step_scheduler() {
        let scheduler = StepScheduler::new(1e-3, 100, 200, 0.5);

        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // LR halves after each full 200-step interval beyond warmup.
        assert_eq!(scheduler.get_lr(250), 1e-3);
        assert_eq!(scheduler.get_lr(300), 1e-3 * 0.5);
        assert_eq!(scheduler.get_lr(500), 1e-3 * 0.5 * 0.5);
    }

    #[test]
    fn test_onecycle_scheduler() {
        let scheduler = OneCycleScheduler::new(1e-2, 1000, 0.3, 1e-5);

        // The cycle starts at final_lr before ramping up.
        assert_eq!(scheduler.get_lr(0), 1e-5);

        // Midway through the up-phase the LR is well above its floor.
        let peak_lr = scheduler.get_lr(150);
        assert!(peak_lr > 5e-3);

        // Returns to final_lr at the end of the cycle.
        let end_lr = scheduler.get_lr(1000);
        assert!((end_lr - 1e-5).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_with_restarts_scheduler() {
        let scheduler = CosineWithRestartsScheduler::new(1e-3, 1e-5, 100, 2.0);

        // Every cycle starts at base_lr.
        assert!((scheduler.get_lr(0) - 1e-3).abs() < 1e-6);

        let mid_lr = scheduler.get_lr(50);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);

        // Near the end of the first cycle the LR approaches min_lr...
        let near_end_lr = scheduler.get_lr(99);
        assert!(near_end_lr < 2e-4);

        // ...and jumps back up at the restart boundary.
        let restart_lr = scheduler.get_lr(100);
        assert!(restart_lr > 5e-4);
    }

    #[test]
    fn test_cyclical_scheduler() {
        let scheduler = CyclicalScheduler::new(1e-4, 1e-3, 50, 50, CyclicalMode::Triangular);

        // Triangle wave: base_lr at cycle boundaries, max_lr at the peaks.
        assert!((scheduler.get_lr(0) - 1e-4).abs() < 1e-6);

        assert!((scheduler.get_lr(50) - 1e-3).abs() < 1e-6);

        assert!((scheduler.get_lr(100) - 1e-4).abs() < 1e-6);

        assert!((scheduler.get_lr(150) - 1e-3).abs() < 1e-6);
    }
}
643
/// ReduceLROnPlateau-style scheduler: cuts the LR by `factor` whenever the
/// monitored metric fails to improve for `patience` consecutive calls to
/// `step_with_metric`.
#[derive(Debug, Clone)]
pub struct AdaptiveScheduler {
    current_lr: f32,                 // LR currently in effect
    factor: f32,                     // multiplicative reduction, in (0, 1)
    patience: usize,                 // stalled epochs tolerated before reducing
    threshold: f32,                  // minimum relative improvement that counts
    min_lr: f32,                     // floor the LR can never drop below
    mode: String,                    // "min" (lower is better) or "max"
    epochs_since_improvement: usize, // stall counter, reset on improvement
    best_metric: Option<f32>,        // best metric seen so far (None until the first)
    current_step: usize,             // number of `step_with_metric` calls
}
670
671impl AdaptiveScheduler {
672 pub fn new(
691 initial_lr: f32,
692 factor: f32,
693 patience: usize,
694 threshold: f32,
695 min_lr: f32,
696 mode: &str,
697 ) -> Self {
698 assert!(
699 factor > 0.0 && factor < 1.0,
700 "Factor must be between 0 and 1"
701 );
702 assert!(patience > 0, "Patience must be positive");
703 assert!(threshold >= 0.0, "Threshold must be non-negative");
704 assert!(min_lr >= 0.0, "Min LR must be non-negative");
705 assert!(mode == "min" || mode == "max", "Mode must be min or max");
706
707 Self {
708 current_lr: initial_lr,
709 factor,
710 patience,
711 threshold,
712 min_lr,
713 mode: mode.to_string(),
714 epochs_since_improvement: 0,
715 best_metric: None,
716 current_step: 0,
717 }
718 }
719
720 pub fn step_with_metric(&mut self, metric: f32) -> (f32, bool) {
723 self.current_step += 1;
724 let mut lr_reduced = false;
725
726 let is_improvement = match self.best_metric {
727 None => {
728 self.best_metric = Some(metric);
730 true
731 },
732 Some(best) => {
733 let improvement = if self.mode == "min" {
734 (best - metric) / best.abs().max(1e-8) > self.threshold
736 } else {
737 (metric - best) / best.abs().max(1e-8) > self.threshold
739 };
740
741 if improvement {
742 self.best_metric = Some(metric);
743 }
744
745 improvement
746 },
747 };
748
749 if is_improvement {
750 self.epochs_since_improvement = 0;
751 } else {
752 self.epochs_since_improvement += 1;
753
754 if self.epochs_since_improvement >= self.patience {
755 let new_lr = (self.current_lr * self.factor).max(self.min_lr);
757 if new_lr < self.current_lr {
758 self.current_lr = new_lr;
759 lr_reduced = true;
760 self.epochs_since_improvement = 0; }
762 }
763 }
764
765 (self.current_lr, lr_reduced)
766 }
767
768 pub fn get_current_lr(&self) -> f32 {
770 self.current_lr
771 }
772
773 pub fn get_best_metric(&self) -> Option<f32> {
775 self.best_metric
776 }
777
778 pub fn get_epochs_since_improvement(&self) -> usize {
780 self.epochs_since_improvement
781 }
782
783 pub fn reset(&mut self) {
785 self.epochs_since_improvement = 0;
786 self.best_metric = None;
787 self.current_step = 0;
788 }
789
790 pub fn set_lr(&mut self, lr: f32) {
792 self.current_lr = lr;
793 }
794}
795
impl LRScheduler for AdaptiveScheduler {
    /// Returns the current metric-driven LR; `_step` is ignored because this
    /// scheduler is advanced by `step_with_metric`, not by a step count.
    fn get_lr(&self, _step: usize) -> f32 {
        self.current_lr
    }

    /// Intentionally a no-op: all progression happens in `step_with_metric`.
    fn step(&mut self) {
    }
}
806
/// Chains several schedulers: scheduler `i` is active while the global step is
/// below `step_boundaries[i]`; the last scheduler handles everything after its
/// boundary.
pub struct CompositeScheduler {
    schedulers: Vec<Box<dyn LRScheduler>>,
    step_boundaries: Vec<usize>, // ascending global-step cutoffs, one per scheduler
    current_step: usize,         // internal counter advanced by `LRScheduler::step`
    #[allow(dead_code)]
    global_step_offset: usize,   // reserved; currently unused
}
818
819impl CompositeScheduler {
820 pub fn new(schedulers: Vec<Box<dyn LRScheduler>>, step_boundaries: Vec<usize>) -> Self {
838 assert_eq!(
839 schedulers.len(),
840 step_boundaries.len(),
841 "Number of schedulers must match number of boundaries"
842 );
843 assert!(
844 !schedulers.is_empty(),
845 "Must provide at least one scheduler"
846 );
847
848 Self {
849 schedulers,
850 step_boundaries,
851 current_step: 0,
852 global_step_offset: 0,
853 }
854 }
855
856 fn get_active_scheduler_index(&self, step: usize) -> usize {
857 for (i, &boundary) in self.step_boundaries.iter().enumerate() {
858 if step < boundary {
859 return i;
860 }
861 }
862 self.schedulers.len() - 1
863 }
864
865 fn get_local_step(&self, global_step: usize, scheduler_index: usize) -> usize {
866 if scheduler_index == 0 {
867 global_step
868 } else {
869 global_step - self.step_boundaries[scheduler_index - 1]
870 }
871 }
872}
873
874impl LRScheduler for CompositeScheduler {
875 fn get_lr(&self, step: usize) -> f32 {
876 let scheduler_idx = self.get_active_scheduler_index(step);
877 let local_step = self.get_local_step(step, scheduler_idx);
878 self.schedulers[scheduler_idx].get_lr(local_step)
879 }
880
881 fn step(&mut self) {
882 self.current_step += 1;
883 let _scheduler_idx = self.get_active_scheduler_index(self.current_step);
884 }
886}
887
/// Runs a sequence of [`Phase`]s back to back, advancing phases as `step()`
/// is called.
pub struct PhaseBasedScheduler {
    phases: Vec<Phase>,
    current_phase: usize,    // index into `phases`; equals len() when finished
    current_step: usize,     // internal counter advanced by `LRScheduler::step`
    phase_start_step: usize, // global step at which the current phase began
}
897
/// A single stage of a [`PhaseBasedScheduler`] run.
pub struct Phase {
    /// Human-readable phase name (e.g. "warmup", "main").
    pub name: String,
    /// Scheduler driving the LR while this phase is active.
    pub scheduler: Box<dyn LRScheduler>,
    /// How many steps this phase lasts.
    pub duration_steps: usize,
    /// Multiplier applied on top of the phase scheduler's LR.
    pub lr_multiplier: f32,
}
904
905impl PhaseBasedScheduler {
906 pub fn new(phases: Vec<Phase>) -> Self {
935 assert!(!phases.is_empty(), "Must provide at least one phase");
936
937 Self {
938 phases,
939 current_phase: 0,
940 current_step: 0,
941 phase_start_step: 0,
942 }
943 }
944
945 pub fn get_current_phase(&self) -> &str {
947 &self.phases[self.current_phase].name
948 }
949
950 pub fn get_current_phase_index(&self) -> usize {
952 self.current_phase
953 }
954
955 pub fn is_complete(&self) -> bool {
957 self.current_phase >= self.phases.len()
958 }
959
960 fn update_phase(&mut self, step: usize) {
961 while self.current_phase < self.phases.len() {
962 let phase_end = self.phase_start_step + self.phases[self.current_phase].duration_steps;
963
964 if step < phase_end {
965 break; }
967
968 self.current_phase += 1;
970 self.phase_start_step = phase_end;
971 }
972 }
973}
974
975impl LRScheduler for PhaseBasedScheduler {
976 fn get_lr(&self, step: usize) -> f32 {
977 if self.current_phase >= self.phases.len() {
978 return 0.0; }
980
981 let phase = &self.phases[self.current_phase];
982 let phase_step = step - self.phase_start_step;
983 let base_lr = phase.scheduler.get_lr(phase_step);
984
985 base_lr * phase.lr_multiplier
986 }
987
988 fn step(&mut self) {
989 self.current_step += 1;
990 self.update_phase(self.current_step);
991 }
992}
993
/// Wraps two schedulers and switches permanently from the primary to the
/// fallback once `switch_condition` fires (based on metrics supplied via
/// `update_metric`).
pub struct DynamicScheduler {
    primary_scheduler: Box<dyn LRScheduler>,
    fallback_scheduler: Box<dyn LRScheduler>,
    current_scheduler: usize, // 0 = primary, 1 = fallback
    switch_condition: SwitchCondition,
    metrics_window: Vec<f32>, // sliding window of the most recent metrics
    window_size: usize,       // maximum length of `metrics_window`
    current_step: usize,      // internal counter advanced by `LRScheduler::step`
}
1007
/// Condition under which [`DynamicScheduler`] abandons its primary scheduler.
#[derive(Debug)]
pub enum SwitchCondition {
    /// The loss has plateaued over the last N recorded metrics.
    LossPlateauSteps(usize),
    /// Gradient norm exceeds this threshold. NOTE(review): gradient norms are
    /// never fed into the scheduler, so this condition can currently never
    /// trigger.
    GradientNormThreshold(f32),
    /// A fixed step count has been reached.
    StepThreshold(usize),
    /// The latest metric exceeds the previous one by this factor.
    LossIncreaseFactor(f32),
}
1019
1020impl DynamicScheduler {
1021 pub fn new(
1023 primary_scheduler: Box<dyn LRScheduler>,
1024 fallback_scheduler: Box<dyn LRScheduler>,
1025 switch_condition: SwitchCondition,
1026 window_size: usize,
1027 ) -> Self {
1028 Self {
1029 primary_scheduler,
1030 fallback_scheduler,
1031 current_scheduler: 0,
1032 switch_condition,
1033 metrics_window: Vec::with_capacity(window_size),
1034 window_size,
1035 current_step: 0,
1036 }
1037 }
1038
1039 pub fn update_metric(&mut self, metric: f32) {
1041 self.metrics_window.push(metric);
1042 if self.metrics_window.len() > self.window_size {
1043 self.metrics_window.remove(0);
1044 }
1045
1046 if self.current_scheduler == 0 && self.should_switch() {
1048 self.current_scheduler = 1;
1049 }
1050 }
1051
1052 fn should_switch(&self) -> bool {
1053 match &self.switch_condition {
1054 SwitchCondition::LossPlateauSteps(steps) => {
1055 if self.metrics_window.len() < *steps {
1056 return false;
1057 }
1058
1059 let recent_avg =
1060 self.metrics_window.iter().rev().take(*steps).sum::<f32>() / *steps as f32;
1061 let older_avg =
1062 self.metrics_window.iter().take(self.metrics_window.len() - steps).sum::<f32>()
1063 / (self.metrics_window.len() - steps) as f32;
1064
1065 recent_avg >= older_avg * 0.995 },
1067 SwitchCondition::StepThreshold(step) => self.current_step >= *step,
1068 SwitchCondition::LossIncreaseFactor(factor) => {
1069 if self.metrics_window.len() < 2 {
1070 return false;
1071 }
1072 let latest = self.metrics_window[self.metrics_window.len() - 1];
1073 let previous = self.metrics_window[self.metrics_window.len() - 2];
1074 latest > previous * factor
1075 },
1076 SwitchCondition::GradientNormThreshold(_) => false, }
1078 }
1079
1080 pub fn get_active_scheduler(&self) -> &str {
1082 if self.current_scheduler == 0 {
1083 "primary"
1084 } else {
1085 "fallback"
1086 }
1087 }
1088}
1089
1090impl LRScheduler for DynamicScheduler {
1091 fn get_lr(&self, step: usize) -> f32 {
1092 if self.current_scheduler == 0 {
1093 self.primary_scheduler.get_lr(step)
1094 } else {
1095 self.fallback_scheduler.get_lr(step)
1096 }
1097 }
1098
1099 fn step(&mut self) {
1100 self.current_step += 1;
1101 if self.current_scheduler == 0 {
1102 self.primary_scheduler.step();
1103 } else {
1104 self.fallback_scheduler.step();
1105 }
1106 }
1107}
1108
/// Scheduler whose concrete policy is picked from a [`TaskType`] preset.
pub struct TaskSpecificScheduler {
    scheduler: Box<dyn LRScheduler>, // the preset-selected underlying schedule
    task_type: TaskType,             // preset this scheduler was built for
    current_step: usize,             // internal counter advanced by `LRScheduler::step`
}
1115
/// Training regimes with conventional scheduler defaults.
#[derive(Debug)]
pub enum TaskType {
    /// Language-model pretraining: cosine decay with short warmup.
    LanguageModelPretraining,
    /// Fine-tuning: reduced peak LR with linear decay.
    FineTuning,
    /// Vision training: staircase step decay.
    ComputerVision,
    /// Reinforcement learning: metric-driven (reward-maximizing) reduction.
    ReinforcementLearning,
    /// GAN training: constant LR after a brief warmup.
    GANTraining,
}
1129
1130impl TaskSpecificScheduler {
1131 pub fn new(task_type: TaskType, base_lr: f32, total_steps: usize) -> Self {
1133 let scheduler: Box<dyn LRScheduler> = match task_type {
1134 TaskType::LanguageModelPretraining => {
1135 Box::new(CosineScheduler::new(
1136 base_lr,
1137 (total_steps as f32 * 0.06) as usize, total_steps,
1139 base_lr * 0.1, ))
1141 },
1142 TaskType::FineTuning => {
1143 Box::new(LinearScheduler::new(
1144 base_lr * 0.1, (total_steps as f32 * 0.1) as usize, total_steps,
1147 ))
1148 },
1149 TaskType::ComputerVision => {
1150 Box::new(StepScheduler::new(
1151 base_lr,
1152 (total_steps as f32 * 0.05) as usize, total_steps / 3, 0.1, ))
1156 },
1157 TaskType::ReinforcementLearning => {
1158 Box::new(AdaptiveScheduler::new(
1159 base_lr,
1160 0.5, 10, 1e-4, base_lr * 1e-3, "max", ))
1166 },
1167 TaskType::GANTraining => {
1168 Box::new(ConstantWithWarmupScheduler::new(
1169 base_lr,
1170 (total_steps as f32 * 0.02) as usize, ))
1172 },
1173 };
1174
1175 Self {
1176 scheduler,
1177 task_type,
1178 current_step: 0,
1179 }
1180 }
1181
1182 pub fn get_task_type(&self) -> &TaskType {
1184 &self.task_type
1185 }
1186}
1187
1188impl LRScheduler for TaskSpecificScheduler {
1189 fn get_lr(&self, step: usize) -> f32 {
1190 self.scheduler.get_lr(step)
1191 }
1192
1193 fn step(&mut self) {
1194 self.current_step += 1;
1195 self.scheduler.step();
1196 }
1197}