// trustformers_optim/scheduler.rs
1//! # Learning Rate Schedulers
2//!
3//! This module provides various learning rate scheduling strategies for optimizers.
4//! Learning rate scheduling is crucial for achieving good convergence in deep learning.
5//!
6//! ## Available Schedulers
7//!
8//! - **LinearScheduler**: Linear warmup followed by linear decay
9//! - **CosineScheduler**: Linear warmup followed by cosine annealing
10//! - **PolynomialScheduler**: Polynomial decay with configurable power
11//! - **ConstantWithWarmupScheduler**: Constant LR after warmup
12//! - **ExponentialScheduler**: Exponential decay
13//! - **StepScheduler**: Step-wise decay at specified milestones
14//!
15//! ## Usage Example
16//!
17//! ```rust,no_run
18//! use trustformers_optim::{AdamW, CosineScheduler, LRScheduler};
19//! use trustformers_core::traits::Optimizer;
20//!
21//! let base_lr = 5e-4;
22//! let mut optimizer = AdamW::new(base_lr, (0.9, 0.999), 1e-8, 0.01);
23//!
24//! let mut scheduler = CosineScheduler::new(
25//!     base_lr,
26//!     1000,   // Linear warmup for 1000 steps
27//!     10000,  // Total training steps
28//!     1e-5,   // Minimum learning rate
29//! );
30//!
31//! // Training loop
32//! for step in 0..10000 {
33//!     // Get current learning rate
34//!     let lr = scheduler.get_lr(step);
35//!     optimizer.set_lr(lr);
36//!
37//!     // Training step...
38//!
39//!     scheduler.step();
40//! }
41//! ```
42//!
43//! ## Choosing a Scheduler
44//!
45//! ### For Transformer Pre-training
46//! - **CosineScheduler**: Most common, smooth decay
47//! - **LinearScheduler**: Simple and effective
48//!
49//! ### For Fine-tuning
50//! - **ConstantWithWarmupScheduler**: Stable for small datasets
51//! - **LinearScheduler**: With small decay rate
52//!
53//! ### For Computer Vision
54//! - **StepScheduler**: Traditional for CNNs
55//! - **CosineScheduler**: Modern alternative
56//!
57//! ## Warmup Importance
58//!
59//! Warmup is crucial for:
60//! - Stabilizing training with large learning rates
61//! - Preventing early divergence
62//! - Allowing adaptive optimizers to estimate statistics
63//!
64//! Typical warmup steps:
65//! - 2-10% of total training steps
66//! - 500-2000 steps for most tasks
67
/// Trait for learning rate schedulers.
///
/// Implementations map a 0-based training step index to a learning rate via
/// [`LRScheduler::get_lr`]. The `Send + Sync` bounds allow a scheduler to be
/// shared across training threads.
pub trait LRScheduler: Send + Sync {
    /// Get the learning rate for a given step.
    ///
    /// Takes `&self`, so querying never mutates scheduler state.
    fn get_lr(&self, step: usize) -> f32;
    /// Advance the scheduler by one step.
    ///
    /// Stateful schedulers update internal counters here; stateless ones may
    /// treat this as a no-op.
    fn step(&mut self);
}
75
/// Linear learning rate scheduler with warmup.
///
/// Implements linear warmup from 0 to base_lr, followed by linear decay to 0.
/// This is commonly used for transformer pre-training.
#[derive(Debug)]
pub struct LinearScheduler {
    /// Peak learning rate reached at the end of warmup.
    base_lr: f32,
    /// Number of steps spent ramping up from 0 to `base_lr`.
    warmup_steps: usize,
    /// Total number of training steps (warmup + decay).
    total_steps: usize,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
}
87
88impl LinearScheduler {
89    pub fn new(base_lr: f32, warmup_steps: usize, total_steps: usize) -> Self {
90        Self {
91            base_lr,
92            warmup_steps,
93            total_steps,
94            current_step: 0,
95        }
96    }
97}
98
99impl LRScheduler for LinearScheduler {
100    fn get_lr(&self, step: usize) -> f32 {
101        if step < self.warmup_steps {
102            self.base_lr * (step as f32) / (self.warmup_steps as f32)
103        } else {
104            let progress =
105                (step - self.warmup_steps) as f32 / (self.total_steps - self.warmup_steps) as f32;
106            self.base_lr * (1.0 - progress).max(0.0)
107        }
108    }
109
110    fn step(&mut self) {
111        self.current_step += 1;
112    }
113}
114
/// Cosine annealing learning rate scheduler with warmup.
///
/// Implements linear warmup followed by cosine decay to min_lr.
/// This provides a smoother decay than linear scheduling and often
/// leads to better final performance.
#[derive(Debug)]
pub struct CosineScheduler {
    /// Peak learning rate reached at the end of warmup.
    base_lr: f32,
    /// Number of linear warmup steps.
    warmup_steps: usize,
    /// Total number of training steps (warmup + cosine decay).
    total_steps: usize,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
    /// Learning rate floor reached at the end of the cosine decay.
    min_lr: f32,
}
128
129impl CosineScheduler {
130    pub fn new(base_lr: f32, warmup_steps: usize, total_steps: usize, min_lr: f32) -> Self {
131        Self {
132            base_lr,
133            warmup_steps,
134            total_steps,
135            current_step: 0,
136            min_lr,
137        }
138    }
139}
140
141impl LRScheduler for CosineScheduler {
142    fn get_lr(&self, step: usize) -> f32 {
143        use std::f32::consts::PI;
144
145        if step < self.warmup_steps {
146            self.base_lr * (step as f32) / (self.warmup_steps as f32)
147        } else {
148            let progress =
149                (step - self.warmup_steps) as f32 / (self.total_steps - self.warmup_steps) as f32;
150            let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
151            self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
152        }
153    }
154
155    fn step(&mut self) {
156        self.current_step += 1;
157    }
158}
159
/// Polynomial decay scheduler with configurable power.
///
/// Decays learning rate according to: lr = (base_lr - min_lr) * (1 - t)^power + min_lr
/// where t is the progress ratio. Common powers:
/// - power = 1.0: Linear decay
/// - power = 0.5: Square root decay
/// - power = 2.0: Quadratic decay
#[derive(Debug)]
pub struct PolynomialScheduler {
    /// Peak learning rate reached at the end of warmup.
    base_lr: f32,
    /// Number of linear warmup steps.
    warmup_steps: usize,
    /// Total number of training steps (warmup + decay).
    total_steps: usize,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
    /// Learning rate floor reached when decay completes.
    min_lr: f32,
    /// Exponent of the decay curve (1.0 = linear, 2.0 = quadratic).
    power: f32,
}
176
177impl PolynomialScheduler {
178    pub fn new(
179        base_lr: f32,
180        warmup_steps: usize,
181        total_steps: usize,
182        min_lr: f32,
183        power: f32,
184    ) -> Self {
185        Self {
186            base_lr,
187            warmup_steps,
188            total_steps,
189            current_step: 0,
190            min_lr,
191            power,
192        }
193    }
194}
195
196impl LRScheduler for PolynomialScheduler {
197    fn get_lr(&self, step: usize) -> f32 {
198        if step < self.warmup_steps {
199            self.base_lr * (step as f32) / (self.warmup_steps as f32)
200        } else {
201            let progress =
202                (step - self.warmup_steps) as f32 / (self.total_steps - self.warmup_steps) as f32;
203            let decay_factor = (1.0 - progress.min(1.0)).powf(self.power);
204            self.min_lr + (self.base_lr - self.min_lr) * decay_factor
205        }
206    }
207
208    fn step(&mut self) {
209        self.current_step += 1;
210    }
211}
212
/// Constant learning rate with warmup
///
/// Ramps linearly from 0 to `base_lr` over `warmup_steps`, then holds flat.
#[derive(Debug)]
pub struct ConstantWithWarmupScheduler {
    /// Learning rate held after warmup completes.
    base_lr: f32,
    /// Number of linear warmup steps.
    warmup_steps: usize,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
}
220
221impl ConstantWithWarmupScheduler {
222    pub fn new(base_lr: f32, warmup_steps: usize) -> Self {
223        Self {
224            base_lr,
225            warmup_steps,
226            current_step: 0,
227        }
228    }
229}
230
231impl LRScheduler for ConstantWithWarmupScheduler {
232    fn get_lr(&self, step: usize) -> f32 {
233        if step < self.warmup_steps {
234            self.base_lr * (step as f32) / (self.warmup_steps as f32)
235        } else {
236            self.base_lr
237        }
238    }
239
240    fn step(&mut self) {
241        self.current_step += 1;
242    }
243}
244
/// Exponential decay scheduler
///
/// After warmup, multiplies the LR by `decay_rate` once per completed
/// `decay_steps` interval.
#[derive(Debug)]
pub struct ExponentialScheduler {
    /// Learning rate at the end of warmup, before any decay.
    base_lr: f32,
    /// Number of linear warmup steps.
    warmup_steps: usize,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
    /// Multiplier applied once per decay interval.
    decay_rate: f32,
    /// Number of steps per decay interval.
    decay_steps: usize,
}
254
255impl ExponentialScheduler {
256    pub fn new(base_lr: f32, warmup_steps: usize, decay_rate: f32, decay_steps: usize) -> Self {
257        Self {
258            base_lr,
259            warmup_steps,
260            current_step: 0,
261            decay_rate,
262            decay_steps,
263        }
264    }
265}
266
267impl LRScheduler for ExponentialScheduler {
268    fn get_lr(&self, step: usize) -> f32 {
269        if step < self.warmup_steps {
270            self.base_lr * (step as f32) / (self.warmup_steps as f32)
271        } else {
272            let decay_step = (step - self.warmup_steps) / self.decay_steps;
273            self.base_lr * self.decay_rate.powf(decay_step as f32)
274        }
275    }
276
277    fn step(&mut self) {
278        self.current_step += 1;
279    }
280}
281
/// Step decay scheduler (reduce LR at specific steps)
///
/// After warmup, multiplies the LR by `gamma` every `step_size` steps.
#[derive(Debug)]
pub struct StepScheduler {
    /// Learning rate at the end of warmup, before any decay.
    base_lr: f32,
    /// Number of linear warmup steps.
    warmup_steps: usize,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
    /// Interval (in steps) between LR reductions.
    step_size: usize,
    /// Multiplier applied at each milestone.
    gamma: f32,
}
291
292impl StepScheduler {
293    pub fn new(base_lr: f32, warmup_steps: usize, step_size: usize, gamma: f32) -> Self {
294        Self {
295            base_lr,
296            warmup_steps,
297            current_step: 0,
298            step_size,
299            gamma,
300        }
301    }
302}
303
304impl LRScheduler for StepScheduler {
305    fn get_lr(&self, step: usize) -> f32 {
306        if step < self.warmup_steps {
307            self.base_lr * (step as f32) / (self.warmup_steps as f32)
308        } else {
309            let decay_step = (step - self.warmup_steps) / self.step_size;
310            self.base_lr * self.gamma.powf(decay_step as f32)
311        }
312    }
313
314    fn step(&mut self) {
315        self.current_step += 1;
316    }
317}
318
/// OneCycle learning rate scheduler.
///
/// Implements the OneCycle policy: ramp up LR to max_lr over pct_start of training,
/// then decay to final_lr for the remainder. This scheduler often enables training
/// with much higher learning rates.
#[derive(Debug)]
pub struct OneCycleScheduler {
    /// Peak learning rate at the top of the cycle.
    max_lr: f32,
    /// Learning rate at the start and end of the cycle.
    final_lr: f32,
    /// Total number of training steps in the cycle.
    total_steps: usize,
    /// Fraction of training spent ramping up (clamped to [0, 1] in `new`).
    pct_start: f32,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
}
332
333impl OneCycleScheduler {
334    pub fn new(max_lr: f32, total_steps: usize, pct_start: f32, final_lr: f32) -> Self {
335        Self {
336            max_lr,
337            final_lr,
338            total_steps,
339            pct_start: pct_start.clamp(0.0, 1.0),
340            current_step: 0,
341        }
342    }
343}
344
345impl LRScheduler for OneCycleScheduler {
346    fn get_lr(&self, step: usize) -> f32 {
347        use std::f32::consts::PI;
348
349        let step = step.min(self.total_steps);
350        let pct = step as f32 / self.total_steps as f32;
351
352        if pct <= self.pct_start {
353            // Ramp up phase
354            let phase_pct = pct / self.pct_start;
355            let cosine_term = 0.5 * (1.0 - (PI * phase_pct).cos());
356            self.final_lr + (self.max_lr - self.final_lr) * cosine_term
357        } else {
358            // Decay phase
359            let remaining_pct = (pct - self.pct_start) / (1.0 - self.pct_start);
360            let cosine_term = 0.5 * (1.0 + (PI * remaining_pct).cos());
361            self.final_lr + (self.max_lr - self.final_lr) * cosine_term
362        }
363    }
364
365    fn step(&mut self) {
366        self.current_step += 1;
367    }
368}
369
/// Cosine annealing with warm restarts (SGDR).
///
/// Periodically restarts the learning rate schedule. This can help escape
/// local minima and often improves final performance.
#[derive(Debug)]
pub struct CosineWithRestartsScheduler {
    /// Learning rate at the start of each cycle.
    base_lr: f32,
    /// Learning rate floor reached at the end of each cycle.
    min_lr: f32,
    /// Length of the first cycle in steps.
    t_0: usize,
    /// Cycle-length multiplier applied after each restart.
    t_mult: f32,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
    /// Step at which the next restart occurs (bookkeeping for `step()`;
    /// `get_lr` recomputes cycles from scratch and does not read it).
    next_restart: usize,
    /// Length of the current cycle (bookkeeping for `step()`).
    current_t: usize,
}
384
385impl CosineWithRestartsScheduler {
386    pub fn new(base_lr: f32, min_lr: f32, t_0: usize, t_mult: f32) -> Self {
387        Self {
388            base_lr,
389            min_lr,
390            t_0,
391            t_mult,
392            current_step: 0,
393            next_restart: t_0,
394            current_t: t_0,
395        }
396    }
397}
398
399impl LRScheduler for CosineWithRestartsScheduler {
400    fn get_lr(&self, step: usize) -> f32 {
401        use std::f32::consts::PI;
402
403        let mut step_in_cycle = step;
404        let mut cycle_length = self.t_0;
405
406        // Find which cycle we're in
407        while step_in_cycle >= cycle_length {
408            step_in_cycle -= cycle_length;
409            cycle_length = (cycle_length as f32 * self.t_mult) as usize;
410        }
411
412        let progress = step_in_cycle as f32 / cycle_length as f32;
413        let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
414
415        self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
416    }
417
418    fn step(&mut self) {
419        self.current_step += 1;
420
421        if self.current_step >= self.next_restart {
422            self.current_t = (self.current_t as f32 * self.t_mult) as usize;
423            self.next_restart += self.current_t;
424        }
425    }
426}
427
/// Cyclical learning rate scheduler.
///
/// Cycles the learning rate between base_lr and max_lr over step_size_up + step_size_down steps.
/// This can help find better learning rates and escape local minima.
#[derive(Debug)]
pub struct CyclicalScheduler {
    /// Lower bound of the cycle.
    base_lr: f32,
    /// Upper bound of the cycle.
    max_lr: f32,
    /// Steps spent rising from `base_lr` to `max_lr`.
    step_size_up: usize,
    /// Steps spent falling from `max_lr` back to `base_lr`.
    step_size_down: usize,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
    /// Amplitude scaling policy across cycles.
    mode: CyclicalMode,
}
441
/// Amplitude scaling policy for [`CyclicalScheduler`].
#[derive(Debug, Clone)]
pub enum CyclicalMode {
    /// Constant amplitude every cycle.
    Triangular,
    /// Amplitude is halved each completed cycle.
    Triangular2,
    /// Amplitude decays by `gamma^step`.
    ExpRange(f32), // gamma parameter
}
448
449impl CyclicalScheduler {
450    pub fn new(
451        base_lr: f32,
452        max_lr: f32,
453        step_size_up: usize,
454        step_size_down: usize,
455        mode: CyclicalMode,
456    ) -> Self {
457        Self {
458            base_lr,
459            max_lr,
460            step_size_up,
461            step_size_down,
462            current_step: 0,
463            mode,
464        }
465    }
466}
467
468impl LRScheduler for CyclicalScheduler {
469    fn get_lr(&self, step: usize) -> f32 {
470        let cycle_length = self.step_size_up + self.step_size_down;
471        let cycle = (step / cycle_length) + 1;
472        let x = (step % cycle_length) as f32;
473
474        let (amplitude, _phase) = if x <= self.step_size_up as f32 {
475            // Ascending phase
476            (x / self.step_size_up as f32, 1.0)
477        } else {
478            // Descending phase
479            (
480                (self.step_size_down as f32 - (x - self.step_size_up as f32))
481                    / self.step_size_down as f32,
482                1.0,
483            )
484        };
485
486        let scale_factor = match &self.mode {
487            CyclicalMode::Triangular => 1.0,
488            CyclicalMode::Triangular2 => 1.0 / (2.0_f32.powi((cycle - 1) as i32)),
489            CyclicalMode::ExpRange(gamma) => gamma.powi(step as i32),
490        };
491
492        self.base_lr + (self.max_lr - self.base_lr) * amplitude * scale_factor
493    }
494
495    fn step(&mut self) {
496        self.current_step += 1;
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    // All schedulers are queried statelessly through `get_lr(step)`,
    // so the tests never need to call `step()`.

    #[test]
    fn test_linear_scheduler() {
        let scheduler = LinearScheduler::new(1e-3, 100, 1000);

        // Test warmup
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Test decay
        assert_eq!(scheduler.get_lr(550), 5e-4);
        assert_eq!(scheduler.get_lr(1000), 0.0);
    }

    #[test]
    fn test_cosine_scheduler() {
        let scheduler = CosineScheduler::new(1e-3, 100, 1000, 1e-5);

        // Test warmup
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Test cosine decay - should be smooth
        let mid_lr = scheduler.get_lr(550);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);

        // Should approach min_lr at the end
        let end_lr = scheduler.get_lr(1000);
        assert!((end_lr - 1e-5).abs() < 1e-6);
    }

    #[test]
    fn test_polynomial_scheduler() {
        // power = 2.0: quadratic decay from 1e-3 down to 1e-5.
        let scheduler = PolynomialScheduler::new(1e-3, 100, 1000, 1e-5, 2.0);

        // Test warmup
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Test polynomial decay
        let mid_lr = scheduler.get_lr(550);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);
    }

    #[test]
    fn test_constant_with_warmup_scheduler() {
        let scheduler = ConstantWithWarmupScheduler::new(1e-3, 100);

        // Test warmup
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Test constant after warmup
        assert_eq!(scheduler.get_lr(200), 1e-3);
        assert_eq!(scheduler.get_lr(1000), 1e-3);
    }

    #[test]
    fn test_exponential_scheduler() {
        // LR is multiplied by 0.9 once per 100-step interval after warmup.
        let scheduler = ExponentialScheduler::new(1e-3, 100, 0.9, 100);

        // Test warmup
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Test exponential decay
        assert_eq!(scheduler.get_lr(200), 1e-3 * 0.9);
        assert_eq!(scheduler.get_lr(300), 1e-3 * 0.9 * 0.9);
    }

    #[test]
    fn test_step_scheduler() {
        // LR is halved every 200 steps after the 100-step warmup.
        let scheduler = StepScheduler::new(1e-3, 100, 200, 0.5);

        // Test warmup
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);

        // Test step decay
        assert_eq!(scheduler.get_lr(250), 1e-3); // Still first step
        assert_eq!(scheduler.get_lr(300), 1e-3 * 0.5); // Second step
        assert_eq!(scheduler.get_lr(500), 1e-3 * 0.5 * 0.5); // Third step
    }

    #[test]
    fn test_onecycle_scheduler() {
        let scheduler = OneCycleScheduler::new(1e-2, 1000, 0.3, 1e-5);

        // Test start
        assert_eq!(scheduler.get_lr(0), 1e-5);

        // Test peak (around 30% of training)
        let peak_lr = scheduler.get_lr(150);
        assert!(peak_lr > 5e-3);

        // Test end
        let end_lr = scheduler.get_lr(1000);
        assert!((end_lr - 1e-5).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_with_restarts_scheduler() {
        // First cycle lasts 100 steps; each restart doubles the cycle length.
        let scheduler = CosineWithRestartsScheduler::new(1e-3, 1e-5, 100, 2.0);

        // Test initial learning rate
        assert!((scheduler.get_lr(0) - 1e-3).abs() < 1e-6);

        // Test mid-cycle (should be between min and max)
        let mid_lr = scheduler.get_lr(50);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);

        // Test near end of first cycle (should be close to minimum)
        let near_end_lr = scheduler.get_lr(99);
        assert!(near_end_lr < 2e-4);

        // Test restart (should be back to max)
        let restart_lr = scheduler.get_lr(100);
        assert!(restart_lr > 5e-4);
    }

    #[test]
    fn test_cyclical_scheduler() {
        // Symmetric 100-step triangle: 50 steps up, 50 steps down.
        let scheduler = CyclicalScheduler::new(1e-4, 1e-3, 50, 50, CyclicalMode::Triangular);

        // Test base learning rate
        assert!((scheduler.get_lr(0) - 1e-4).abs() < 1e-6);

        // Test peak learning rate
        assert!((scheduler.get_lr(50) - 1e-3).abs() < 1e-6);

        // Test return to base
        assert!((scheduler.get_lr(100) - 1e-4).abs() < 1e-6);

        // Test second cycle
        assert!((scheduler.get_lr(150) - 1e-3).abs() < 1e-6);
    }
}
643
/// Adaptive learning rate scheduler that reduces LR when a metric has stopped improving.
///
/// This scheduler monitors a metric (typically validation loss) and reduces the learning rate
/// when the metric plateaus for a certain number of epochs. Similar to ReduceLROnPlateau
/// in PyTorch, this provides adaptive learning rate scheduling based on actual training progress.
///
/// Drive it with [`AdaptiveScheduler::step_with_metric`]; the plain
/// [`LRScheduler::step`] implementation is intentionally a no-op.
#[derive(Debug, Clone)]
pub struct AdaptiveScheduler {
    /// Current learning rate
    current_lr: f32,
    /// Factor by which to reduce learning rate (new_lr = lr * factor)
    factor: f32,
    /// Number of epochs with no improvement after which LR will be reduced
    patience: usize,
    /// Threshold for measuring the new optimum (relative improvement)
    threshold: f32,
    /// Minimum learning rate (will not go below this)
    min_lr: f32,
    /// Mode: "min" for minimizing (loss), "max" for maximizing (accuracy)
    mode: String,
    /// Number of epochs since last improvement
    epochs_since_improvement: usize,
    /// Best metric value seen so far
    best_metric: Option<f32>,
    /// Step counter
    current_step: usize,
}
670
671impl AdaptiveScheduler {
672    /// Creates a new adaptive scheduler.
673    ///
674    /// # Arguments
675    ///
676    /// * `initial_lr` - Initial learning rate
677    /// * `factor` - Factor by which to reduce LR (typical: 0.1 to 0.5)
678    /// * `patience` - Number of epochs to wait before reducing LR (typical: 5-10)
679    /// * `threshold` - Minimum improvement threshold (typical: 1e-4)
680    /// * `min_lr` - Minimum learning rate (typical: 1e-8)
681    /// * `mode` - "min" for loss, "max" for accuracy
682    ///
683    /// # Example
684    ///
685    /// ```
686    /// use trustformers_optim::AdaptiveScheduler;
687    ///
688    /// let scheduler = AdaptiveScheduler::new(1e-3, 0.1, 5, 1e-4, 1e-8, "min");
689    /// ```
690    pub fn new(
691        initial_lr: f32,
692        factor: f32,
693        patience: usize,
694        threshold: f32,
695        min_lr: f32,
696        mode: &str,
697    ) -> Self {
698        assert!(
699            factor > 0.0 && factor < 1.0,
700            "Factor must be between 0 and 1"
701        );
702        assert!(patience > 0, "Patience must be positive");
703        assert!(threshold >= 0.0, "Threshold must be non-negative");
704        assert!(min_lr >= 0.0, "Min LR must be non-negative");
705        assert!(mode == "min" || mode == "max", "Mode must be min or max");
706
707        Self {
708            current_lr: initial_lr,
709            factor,
710            patience,
711            threshold,
712            min_lr,
713            mode: mode.to_string(),
714            epochs_since_improvement: 0,
715            best_metric: None,
716            current_step: 0,
717        }
718    }
719
720    /// Update the scheduler with a new metric value.
721    /// Returns the new learning rate and whether it was reduced.
722    pub fn step_with_metric(&mut self, metric: f32) -> (f32, bool) {
723        self.current_step += 1;
724        let mut lr_reduced = false;
725
726        let is_improvement = match self.best_metric {
727            None => {
728                // First metric, set as best
729                self.best_metric = Some(metric);
730                true
731            },
732            Some(best) => {
733                let improvement = if self.mode == "min" {
734                    // For minimizing (loss), improvement is when metric decreases
735                    (best - metric) / best.abs().max(1e-8) > self.threshold
736                } else {
737                    // For maximizing (accuracy), improvement is when metric increases
738                    (metric - best) / best.abs().max(1e-8) > self.threshold
739                };
740
741                if improvement {
742                    self.best_metric = Some(metric);
743                }
744
745                improvement
746            },
747        };
748
749        if is_improvement {
750            self.epochs_since_improvement = 0;
751        } else {
752            self.epochs_since_improvement += 1;
753
754            if self.epochs_since_improvement >= self.patience {
755                // Reduce learning rate
756                let new_lr = (self.current_lr * self.factor).max(self.min_lr);
757                if new_lr < self.current_lr {
758                    self.current_lr = new_lr;
759                    lr_reduced = true;
760                    self.epochs_since_improvement = 0; // Reset patience counter
761                }
762            }
763        }
764
765        (self.current_lr, lr_reduced)
766    }
767
768    /// Get current learning rate without updating.
769    pub fn get_current_lr(&self) -> f32 {
770        self.current_lr
771    }
772
773    /// Get the best metric seen so far.
774    pub fn get_best_metric(&self) -> Option<f32> {
775        self.best_metric
776    }
777
778    /// Get epochs since last improvement.
779    pub fn get_epochs_since_improvement(&self) -> usize {
780        self.epochs_since_improvement
781    }
782
783    /// Reset the scheduler state.
784    pub fn reset(&mut self) {
785        self.epochs_since_improvement = 0;
786        self.best_metric = None;
787        self.current_step = 0;
788    }
789
790    /// Set the learning rate manually.
791    pub fn set_lr(&mut self, lr: f32) {
792        self.current_lr = lr;
793    }
794}
795
796impl LRScheduler for AdaptiveScheduler {
797    fn get_lr(&self, _step: usize) -> f32 {
798        self.current_lr
799    }
800
801    fn step(&mut self) {
802        // For adaptive scheduler, stepping is done via step_with_metric
803        // This method is kept for compatibility with the LRScheduler trait
804    }
805}
806
/// A composite scheduler that chains multiple schedulers together.
///
/// This allows combining different scheduling strategies, e.g., warmup + cosine + linear decay.
/// Each scheduler is active for a specified number of steps.
pub struct CompositeScheduler {
    /// The chained schedulers, in activation order.
    schedulers: Vec<Box<dyn LRScheduler>>,
    /// Global step at which each scheduler's span ends (parallel to `schedulers`).
    step_boundaries: Vec<usize>,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
    /// Currently unused; retained for possible step re-basing.
    #[allow(dead_code)]
    global_step_offset: usize,
}
818
819impl CompositeScheduler {
820    /// Creates a new composite scheduler.
821    ///
822    /// # Arguments
823    /// * `schedulers` - Vector of schedulers to chain
824    /// * `step_boundaries` - Steps at which to switch to the next scheduler
825    ///
826    /// # Example
827    /// ```rust,no_run
828    /// use trustformers_optim::{LinearScheduler, CosineScheduler, CompositeScheduler, LRScheduler};
829    ///
830    /// let warmup = Box::new(LinearScheduler::new(1e-4, 1000, 1000));
831    /// let main = Box::new(CosineScheduler::new(1e-4, 0, 9000, 1e-6));
832    /// let composite = CompositeScheduler::new(
833    ///     vec![warmup, main],
834    ///     vec![1000, 10000]
835    /// );
836    /// ```
837    pub fn new(schedulers: Vec<Box<dyn LRScheduler>>, step_boundaries: Vec<usize>) -> Self {
838        assert_eq!(
839            schedulers.len(),
840            step_boundaries.len(),
841            "Number of schedulers must match number of boundaries"
842        );
843        assert!(
844            !schedulers.is_empty(),
845            "Must provide at least one scheduler"
846        );
847
848        Self {
849            schedulers,
850            step_boundaries,
851            current_step: 0,
852            global_step_offset: 0,
853        }
854    }
855
856    fn get_active_scheduler_index(&self, step: usize) -> usize {
857        for (i, &boundary) in self.step_boundaries.iter().enumerate() {
858            if step < boundary {
859                return i;
860            }
861        }
862        self.schedulers.len() - 1
863    }
864
865    fn get_local_step(&self, global_step: usize, scheduler_index: usize) -> usize {
866        if scheduler_index == 0 {
867            global_step
868        } else {
869            global_step - self.step_boundaries[scheduler_index - 1]
870        }
871    }
872}
873
874impl LRScheduler for CompositeScheduler {
875    fn get_lr(&self, step: usize) -> f32 {
876        let scheduler_idx = self.get_active_scheduler_index(step);
877        let local_step = self.get_local_step(step, scheduler_idx);
878        self.schedulers[scheduler_idx].get_lr(local_step)
879    }
880
881    fn step(&mut self) {
882        self.current_step += 1;
883        let _scheduler_idx = self.get_active_scheduler_index(self.current_step);
884        // Note: Individual schedulers manage their own state
885    }
886}
887
/// A phase-based scheduler that applies different scheduling strategies during training phases.
///
/// This is useful for complex training regimes like pre-training -> fine-tuning -> evaluation.
pub struct PhaseBasedScheduler {
    /// Ordered training phases.
    phases: Vec<Phase>,
    /// Index of the active phase (== `phases.len()` once training is complete).
    current_phase: usize,
    /// Internal step counter advanced by `step()`.
    current_step: usize,
    /// Global step at which the active phase began.
    phase_start_step: usize,
}
897
/// One segment of a [`PhaseBasedScheduler`] training plan.
pub struct Phase {
    /// Human-readable phase name (e.g. "warmup", "fine_tuning").
    pub name: String,
    /// Scheduler driving the LR during this phase (queried with phase-local steps).
    pub scheduler: Box<dyn LRScheduler>,
    /// How many global steps this phase lasts.
    pub duration_steps: usize,
    /// Factor applied on top of the phase scheduler's LR.
    pub lr_multiplier: f32,
}
904
905impl PhaseBasedScheduler {
906    /// Creates a new phase-based scheduler.
907    ///
908    /// # Example
909    /// ```rust,no_run
910    /// use trustformers_optim::{Phase, LinearScheduler, CosineScheduler, ConstantWithWarmupScheduler, PhaseBasedScheduler};
911    ///
912    /// let phases = vec![
913    ///     Phase {
914    ///         name: "warmup".to_string(),
915    ///         scheduler: Box::new(LinearScheduler::new(1e-4, 1000, 1000)),
916    ///         duration_steps: 1000,
917    ///         lr_multiplier: 1.0,
918    ///     },
919    ///     Phase {
920    ///         name: "main_training".to_string(),
921    ///         scheduler: Box::new(CosineScheduler::new(1e-4, 0, 9000, 1e-6)),
922    ///         duration_steps: 9000,
923    ///         lr_multiplier: 1.0,
924    ///     },
925    ///     Phase {
926    ///         name: "fine_tuning".to_string(),
927    ///         scheduler: Box::new(ConstantWithWarmupScheduler::new(1e-5, 0)),
928    ///         duration_steps: 1000,
929    ///         lr_multiplier: 0.1,
930    ///     },
931    /// ];
932    /// let scheduler = PhaseBasedScheduler::new(phases);
933    /// ```
934    pub fn new(phases: Vec<Phase>) -> Self {
935        assert!(!phases.is_empty(), "Must provide at least one phase");
936
937        Self {
938            phases,
939            current_phase: 0,
940            current_step: 0,
941            phase_start_step: 0,
942        }
943    }
944
945    /// Get the current phase name.
946    pub fn get_current_phase(&self) -> &str {
947        &self.phases[self.current_phase].name
948    }
949
950    /// Get the current phase index.
951    pub fn get_current_phase_index(&self) -> usize {
952        self.current_phase
953    }
954
955    /// Check if training is complete (all phases finished).
956    pub fn is_complete(&self) -> bool {
957        self.current_phase >= self.phases.len()
958    }
959
960    fn update_phase(&mut self, step: usize) {
961        while self.current_phase < self.phases.len() {
962            let phase_end = self.phase_start_step + self.phases[self.current_phase].duration_steps;
963
964            if step < phase_end {
965                break; // Still in current phase
966            }
967
968            // Move to next phase
969            self.current_phase += 1;
970            self.phase_start_step = phase_end;
971        }
972    }
973}
974
975impl LRScheduler for PhaseBasedScheduler {
976    fn get_lr(&self, step: usize) -> f32 {
977        if self.current_phase >= self.phases.len() {
978            return 0.0; // Training complete
979        }
980
981        let phase = &self.phases[self.current_phase];
982        let phase_step = step - self.phase_start_step;
983        let base_lr = phase.scheduler.get_lr(phase_step);
984
985        base_lr * phase.lr_multiplier
986    }
987
988    fn step(&mut self) {
989        self.current_step += 1;
990        self.update_phase(self.current_step);
991    }
992}
993
/// A dynamic scheduler that adjusts its behavior based on training metrics.
///
/// This scheduler can dynamically switch between different scheduling strategies
/// based on training progress, loss trends, or other metrics.
pub struct DynamicScheduler {
    /// Strategy used while training proceeds normally.
    primary_scheduler: Box<dyn LRScheduler>,
    /// Strategy switched to once `switch_condition` triggers.
    fallback_scheduler: Box<dyn LRScheduler>,
    current_scheduler: usize, // 0 = primary, 1 = fallback
    /// Condition that triggers the one-way switch to the fallback.
    switch_condition: SwitchCondition,
    /// Sliding window of recent metric values (most recent last).
    metrics_window: Vec<f32>,
    /// Maximum number of metrics retained in `metrics_window`.
    window_size: usize,
    /// Internal step counter.
    current_step: usize,
}
1007
/// Criteria under which a [`DynamicScheduler`] switches from its primary to
/// its fallback scheduler.
#[derive(Debug)]
pub enum SwitchCondition {
    /// Switch when loss stops improving for N steps
    LossPlateauSteps(usize),
    /// Switch when gradient norm exceeds threshold
    ///
    /// NOTE(review): `DynamicScheduler` never receives gradient norms, so its
    /// `should_switch` currently always evaluates this variant to `false`.
    GradientNormThreshold(f32),
    /// Switch at specific step
    StepThreshold(usize),
    /// Switch when loss increases by factor
    LossIncreaseFactor(f32),
}
1019
1020impl DynamicScheduler {
1021    /// Creates a new dynamic scheduler.
1022    pub fn new(
1023        primary_scheduler: Box<dyn LRScheduler>,
1024        fallback_scheduler: Box<dyn LRScheduler>,
1025        switch_condition: SwitchCondition,
1026        window_size: usize,
1027    ) -> Self {
1028        Self {
1029            primary_scheduler,
1030            fallback_scheduler,
1031            current_scheduler: 0,
1032            switch_condition,
1033            metrics_window: Vec::with_capacity(window_size),
1034            window_size,
1035            current_step: 0,
1036        }
1037    }
1038
1039    /// Update with a new metric (e.g., loss value).
1040    pub fn update_metric(&mut self, metric: f32) {
1041        self.metrics_window.push(metric);
1042        if self.metrics_window.len() > self.window_size {
1043            self.metrics_window.remove(0);
1044        }
1045
1046        // Check switch condition
1047        if self.current_scheduler == 0 && self.should_switch() {
1048            self.current_scheduler = 1;
1049        }
1050    }
1051
1052    fn should_switch(&self) -> bool {
1053        match &self.switch_condition {
1054            SwitchCondition::LossPlateauSteps(steps) => {
1055                if self.metrics_window.len() < *steps {
1056                    return false;
1057                }
1058
1059                let recent_avg =
1060                    self.metrics_window.iter().rev().take(*steps).sum::<f32>() / *steps as f32;
1061                let older_avg =
1062                    self.metrics_window.iter().take(self.metrics_window.len() - steps).sum::<f32>()
1063                        / (self.metrics_window.len() - steps) as f32;
1064
1065                recent_avg >= older_avg * 0.995 // Less than 0.5% improvement
1066            },
1067            SwitchCondition::StepThreshold(step) => self.current_step >= *step,
1068            SwitchCondition::LossIncreaseFactor(factor) => {
1069                if self.metrics_window.len() < 2 {
1070                    return false;
1071                }
1072                let latest = self.metrics_window[self.metrics_window.len() - 1];
1073                let previous = self.metrics_window[self.metrics_window.len() - 2];
1074                latest > previous * factor
1075            },
1076            SwitchCondition::GradientNormThreshold(_) => false, // Requires external gradient norm input
1077        }
1078    }
1079
1080    /// Get which scheduler is currently active.
1081    pub fn get_active_scheduler(&self) -> &str {
1082        if self.current_scheduler == 0 {
1083            "primary"
1084        } else {
1085            "fallback"
1086        }
1087    }
1088}
1089
1090impl LRScheduler for DynamicScheduler {
1091    fn get_lr(&self, step: usize) -> f32 {
1092        if self.current_scheduler == 0 {
1093            self.primary_scheduler.get_lr(step)
1094        } else {
1095            self.fallback_scheduler.get_lr(step)
1096        }
1097    }
1098
1099    fn step(&mut self) {
1100        self.current_step += 1;
1101        if self.current_scheduler == 0 {
1102            self.primary_scheduler.step();
1103        } else {
1104            self.fallback_scheduler.step();
1105        }
1106    }
1107}
1108
/// A task-specific scheduler optimized for different ML tasks.
///
/// Wraps a concrete scheduler chosen from the [`TaskType`] in
/// [`TaskSpecificScheduler::new`].
pub struct TaskSpecificScheduler {
    scheduler: Box<dyn LRScheduler>, // Concrete strategy selected per task.
    task_type: TaskType, // Retained for introspection via `get_task_type`.
    current_step: usize, // Incremented by `step`; not read elsewhere in this view.
}
1115
/// The kind of training task, used to pick sensible scheduler defaults in
/// [`TaskSpecificScheduler::new`].
#[derive(Debug)]
pub enum TaskType {
    /// Language model pre-training (warmup + cosine decay)
    LanguageModelPretraining,
    /// Fine-tuning (low LR, minimal decay)
    FineTuning,
    /// Computer vision (step decay)
    ComputerVision,
    /// Reinforcement learning (adaptive)
    ReinforcementLearning,
    /// GAN training (alternating or constant)
    GANTraining,
}
1129
impl TaskSpecificScheduler {
    /// Creates a task-specific scheduler with optimal defaults.
    ///
    /// `base_lr` is the peak learning rate and `total_steps` the planned
    /// training length; warmup lengths and decay milestones are derived as
    /// fixed fractions of `total_steps` (see the per-task comments below).
    ///
    /// NOTE(review): with very small `total_steps` the derived warmup lengths
    /// round down to 0, and `ComputerVision` passes `total_steps / 3` as the
    /// step size — confirm downstream schedulers tolerate zero-valued
    /// durations before using tiny step budgets.
    pub fn new(task_type: TaskType, base_lr: f32, total_steps: usize) -> Self {
        let scheduler: Box<dyn LRScheduler> = match task_type {
            TaskType::LanguageModelPretraining => {
                Box::new(CosineScheduler::new(
                    base_lr,
                    (total_steps as f32 * 0.06) as usize, // 6% warmup
                    total_steps,
                    base_lr * 0.1, // Decay to 10% of base LR
                ))
            },
            TaskType::FineTuning => {
                Box::new(LinearScheduler::new(
                    base_lr * 0.1,                       // Lower LR for fine-tuning
                    (total_steps as f32 * 0.1) as usize, // 10% warmup
                    total_steps,
                ))
            },
            TaskType::ComputerVision => {
                Box::new(StepScheduler::new(
                    base_lr,
                    (total_steps as f32 * 0.05) as usize, // 5% warmup
                    total_steps / 3,                      // Step every 1/3 of training
                    0.1,                                  // Decay by factor of 10
                ))
            },
            TaskType::ReinforcementLearning => {
                Box::new(AdaptiveScheduler::new(
                    base_lr,
                    0.5,            // Moderate reduction factor
                    10,             // Patience
                    1e-4,           // Threshold
                    base_lr * 1e-3, // Min LR
                    "max",          // Maximize reward
                ))
            },
            TaskType::GANTraining => {
                Box::new(ConstantWithWarmupScheduler::new(
                    base_lr,
                    (total_steps as f32 * 0.02) as usize, // 2% warmup
                ))
            },
        };

        Self {
            scheduler,
            task_type,
            current_step: 0,
        }
    }

    /// Get the task type this scheduler was configured for.
    pub fn get_task_type(&self) -> &TaskType {
        &self.task_type
    }
}
1187
1188impl LRScheduler for TaskSpecificScheduler {
1189    fn get_lr(&self, step: usize) -> f32 {
1190        self.scheduler.get_lr(step)
1191    }
1192
1193    fn step(&mut self) {
1194        self.current_step += 1;
1195        self.scheduler.step();
1196    }
1197}