ruvector_gnn/scheduler.rs

//! Learning rate scheduling for Graph Neural Networks
//!
//! Provides various learning rate scheduling strategies to prevent catastrophic
//! forgetting and optimize training dynamics in continual learning scenarios.
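//!
//! # Example
//!
//! A minimal sketch of driving a scheduler from a training loop (the body of
//! the epoch itself is elided):
//!
//! ```
//! use ruvector_gnn::scheduler::{LearningRateScheduler, SchedulerType};
//!
//! let mut scheduler =
//!     LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.95 }, 0.01);
//!
//! for _epoch in 0..100 {
//!     // ... train one epoch using scheduler.get_lr() ...
//!     let lr = scheduler.step();
//!     assert!(lr <= 0.01);
//! }
//! ```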

use std::f32::consts::PI;

/// Learning rate scheduling strategies
#[derive(Debug, Clone)]
pub enum SchedulerType {
    /// Constant learning rate throughout training
    Constant,

    /// Step decay: multiply learning rate by gamma every step_size epochs
    /// Formula: lr = base_lr * gamma^floor(epoch / step_size)
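    /// e.g. with base_lr = 0.1, step_size = 5, gamma = 0.5 the rate is 0.1
    /// for epochs 0-4, 0.05 for epochs 5-9, and 0.025 for epochs 10-14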
    StepDecay { step_size: usize, gamma: f32 },

    /// Exponential decay: multiply learning rate by gamma each epoch
    /// Formula: lr = base_lr * gamma^epoch
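    /// e.g. with base_lr = 0.1 and gamma = 0.9 the rate is 0.09 after one
    /// epoch and 0.081 after two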
    Exponential { gamma: f32 },

    /// Cosine annealing with warm restarts
    /// Formula: lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * (epoch % t_max) / t_max))
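    /// e.g. with base_lr = 1.0, eta_min = 0.0, t_max = 10 the rate is 0.5 at
    /// epoch 5, about 0.024 at epoch 9, and restarts at 1.0 at epoch 10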
    CosineAnnealing { t_max: usize, eta_min: f32 },

    /// Warmup phase followed by linear decay
    /// Linearly increases lr from 0 to base_lr over warmup_steps,
    /// then linearly decreases to 0 over remaining steps
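    /// e.g. with base_lr = 1.0, warmup_steps = 5, total_steps = 10 the rate
    /// climbs 0.2, 0.4, 0.6, 0.8, 1.0 over steps 1-5, then falls 0.8, 0.6,
    /// 0.4, 0.2, 0.0 over steps 6-10 (note: before the first step, `get_lr`
    /// still reports base_lr)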
    WarmupLinear {
        warmup_steps: usize,
        total_steps: usize,
    },

    /// Reduce the learning rate when a monitored metric plateaus: after
    /// `patience` consecutive steps without improvement, the rate is
    /// multiplied by `factor`, but never reduced below `min_lr`.
    /// Useful for online learning scenarios.
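    /// e.g. with factor = 0.5 and patience = 3, three consecutive
    /// non-improving metrics halve the rate and reset the patience counter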
    ReduceOnPlateau {
        factor: f32,
        patience: usize,
        min_lr: f32,
    },
}

/// Learning rate scheduler for GNN training
///
/// Implements various scheduling strategies to control the learning rate
/// during training, helping prevent catastrophic forgetting and
/// improve convergence.
#[derive(Debug, Clone)]
pub struct LearningRateScheduler {
    scheduler_type: SchedulerType,
    base_lr: f32,
    current_lr: f32,
    step_count: usize,
    /// Lowest metric seen so far (used only by ReduceOnPlateau)
    best_metric: f32,
    /// Consecutive non-improving steps (used only by ReduceOnPlateau)
    patience_counter: usize,
}

impl LearningRateScheduler {
    /// Creates a new learning rate scheduler
    ///
    /// # Arguments
    /// * `scheduler_type` - The scheduling strategy to use
    /// * `base_lr` - The initial/base learning rate
    ///
    /// # Example
    /// ```
    /// use ruvector_gnn::scheduler::{LearningRateScheduler, SchedulerType};
    ///
    /// let scheduler = LearningRateScheduler::new(
    ///     SchedulerType::StepDecay { step_size: 10, gamma: 0.9 },
    ///     0.001
    /// );
    /// ```
    pub fn new(scheduler_type: SchedulerType, base_lr: f32) -> Self {
        Self {
            scheduler_type,
            base_lr,
            current_lr: base_lr,
            step_count: 0,
            best_metric: f32::INFINITY,
            patience_counter: 0,
        }
    }

    /// Advances the scheduler by one step and returns the new learning rate
    ///
    /// For most schedulers, this should be called once per epoch.
    /// For ReduceOnPlateau, use `step_with_metric` instead.
    ///
    /// # Returns
    /// The updated learning rate
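    ///
    /// # Example
    /// A quick check of the exponential schedule after one step:
    /// ```
    /// use ruvector_gnn::scheduler::{LearningRateScheduler, SchedulerType};
    ///
    /// let mut scheduler =
    ///     LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);
    /// let lr = scheduler.step(); // 0.1 * 0.9^1
    /// assert!((lr - 0.09).abs() < 1e-6);
    /// ```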
    pub fn step(&mut self) -> f32 {
        self.step_count += 1;
        self.current_lr = self.calculate_lr();
        self.current_lr
    }

    /// Advances the scheduler with a metric value (for ReduceOnPlateau)
    ///
    /// # Arguments
    /// * `metric` - The metric value to monitor (e.g., validation loss)
    ///
    /// # Returns
    /// The updated learning rate
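    ///
    /// # Example
    /// With patience = 1, a single non-improving metric triggers a reduction:
    /// ```
    /// use ruvector_gnn::scheduler::{LearningRateScheduler, SchedulerType};
    ///
    /// let mut scheduler = LearningRateScheduler::new(
    ///     SchedulerType::ReduceOnPlateau { factor: 0.5, patience: 1, min_lr: 1e-4 },
    ///     0.01,
    /// );
    /// scheduler.step_with_metric(1.0); // establishes the baseline metric
    /// let lr = scheduler.step_with_metric(1.5); // no improvement: lr halved
    /// assert!((lr - 0.005).abs() < 1e-6);
    /// ```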
    pub fn step_with_metric(&mut self, metric: f32) -> f32 {
        self.step_count += 1;

        match &self.scheduler_type {
            SchedulerType::ReduceOnPlateau {
                factor,
                patience,
                min_lr,
            } => {
                // Check whether the metric improved beyond a small tolerance
                if metric < self.best_metric - 1e-8 {
                    self.best_metric = metric;
                    self.patience_counter = 0;
                } else {
                    self.patience_counter += 1;

                    // Reduce the learning rate once patience is exhausted,
                    // then reset the counter
                    if self.patience_counter >= *patience {
                        self.current_lr = (self.current_lr * factor).max(*min_lr);
                        self.patience_counter = 0;
                    }
                }
            }
            _ => {
                // Non-plateau schedulers ignore the metric and behave like step()
                self.current_lr = self.calculate_lr();
            }
        }

        self.current_lr
    }

    /// Gets the current learning rate without advancing the scheduler
    pub fn get_lr(&self) -> f32 {
        self.current_lr
    }

    /// Resets the scheduler to its initial state
    pub fn reset(&mut self) {
        self.current_lr = self.base_lr;
        self.step_count = 0;
        self.best_metric = f32::INFINITY;
        self.patience_counter = 0;
    }

    /// Calculates the learning rate based on the current step and scheduler type
    fn calculate_lr(&self) -> f32 {
        match &self.scheduler_type {
            SchedulerType::Constant => self.base_lr,

            SchedulerType::StepDecay { step_size, gamma } => {
                let decay_factor = (*gamma).powi((self.step_count / step_size) as i32);
                self.base_lr * decay_factor
            }

            SchedulerType::Exponential { gamma } => {
                let decay_factor = (*gamma).powi(self.step_count as i32);
                self.base_lr * decay_factor
            }

            SchedulerType::CosineAnnealing { t_max, eta_min } => {
                let cycle_step = self.step_count % t_max;
                let cos_term = (PI * cycle_step as f32 / *t_max as f32).cos();
                eta_min + 0.5 * (self.base_lr - eta_min) * (1.0 + cos_term)
            }

            SchedulerType::WarmupLinear {
                warmup_steps,
                total_steps,
            } => {
                if self.step_count < *warmup_steps {
                    // Warmup phase: linear increase
                    self.base_lr * (self.step_count as f32 / *warmup_steps as f32)
                } else if self.step_count < *total_steps {
                    // Decay phase: linear decrease
                    let remaining_steps = *total_steps - self.step_count;
                    let total_decay_steps = *total_steps - *warmup_steps;
                    self.base_lr * (remaining_steps as f32 / total_decay_steps as f32)
                } else {
                    // After total_steps, keep at 0
                    0.0
                }
            }

            SchedulerType::ReduceOnPlateau { .. } => {
                // For plateau scheduler, lr is updated in step_with_metric
                self.current_lr
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-6;

    fn assert_close(a: f32, b: f32, msg: &str) {
        assert!((a - b).abs() < EPSILON, "{}: {} != {}", msg, a, b);
    }

    #[test]
    fn test_constant_scheduler() {
        let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.01);

        assert_close(scheduler.get_lr(), 0.01, "Initial LR");

        for i in 1..=10 {
            let lr = scheduler.step();
            assert_close(lr, 0.01, &format!("Step {} LR", i));
        }
    }

    #[test]
    fn test_step_decay() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::StepDecay {
                step_size: 5,
                gamma: 0.5,
            },
            0.1,
        );

        assert_close(scheduler.get_lr(), 0.1, "Initial LR");

        // Steps 1-4: no decay
        for i in 1..=4 {
            let lr = scheduler.step();
            assert_close(lr, 0.1, &format!("Step {} LR", i));
        }

        // Step 5: first decay (0.1 * 0.5)
        let lr = scheduler.step();
        assert_close(lr, 0.05, "Step 5 LR (first decay)");

        // Steps 6-9: maintain decayed rate
        for i in 6..=9 {
            let lr = scheduler.step();
            assert_close(lr, 0.05, &format!("Step {} LR", i));
        }

        // Step 10: second decay (0.1 * 0.5^2)
        let lr = scheduler.step();
        assert_close(lr, 0.025, "Step 10 LR (second decay)");
    }

    #[test]
    fn test_exponential_decay() {
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);

        assert_close(scheduler.get_lr(), 0.1, "Initial LR");

        let expected_lrs = vec![
            0.1 * 0.9,   // Step 1
            0.1 * 0.81,  // Step 2 (0.9^2)
            0.1 * 0.729, // Step 3 (0.9^3)
        ];

        for (i, expected) in expected_lrs.iter().enumerate() {
            let lr = scheduler.step();
            assert_close(lr, *expected, &format!("Step {} LR", i + 1));
        }
    }

    #[test]
    fn test_cosine_annealing() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::CosineAnnealing {
                t_max: 10,
                eta_min: 0.0,
            },
            1.0,
        );

        assert_close(scheduler.get_lr(), 1.0, "Initial LR");

        // Cosine annealing formula: lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * cycle_step / t_max))
        // cycle_step = step_count % t_max
        // At step 5: cycle_step = 5, cos(pi * 5/10) = cos(pi/2) = 0, lr = 0 + 0.5 * 1 * (1 + 0) = 0.5
        // At step 10: cycle_step = 0 (wrapped), cos(0) = 1, lr = 0 + 0.5 * 1 * (1 + 1) = 1.0 (restart)

        for _ in 1..=5 {
            scheduler.step();
        }
        assert_close(scheduler.get_lr(), 0.5, "Mid-cycle LR (step 5)");

        // At step 9: cycle_step = 9, cos(pi * 9/10) ≈ -0.951, lr ≈ 0.025
        for _ in 6..=9 {
            scheduler.step();
        }
        let lr_step9 = scheduler.get_lr();
        assert!(
            lr_step9 < 0.1,
            "Near end of cycle LR (step 9) should be small: {}",
            lr_step9
        );

        // At step 10: warm restart (cycle_step = 0), LR goes back to base
        scheduler.step();
        assert_close(
            scheduler.get_lr(),
            1.0,
            "Restart at step 10 (cycle_step = 0)",
        );

        // Continue new cycle
        scheduler.step();
        assert!(
            scheduler.get_lr() < 1.0,
            "Step 11 should be less than base LR"
        );
    }

    #[test]
    fn test_warmup_linear() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::WarmupLinear {
                warmup_steps: 5,
                total_steps: 10,
            },
            1.0,
        );

        assert_close(scheduler.get_lr(), 1.0, "Initial LR");

        // Warmup phase: linear increase
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.2, "Step 1 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.4, "Step 2 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.6, "Step 3 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.8, "Step 4 (warmup)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 1.0, "Step 5 (warmup end)");

        // Decay phase: linear decrease
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.8, "Step 6 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.6, "Step 7 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.4, "Step 8 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.2, "Step 9 (decay)");

        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Step 10 (decay end)");

        // After total_steps
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Step 11 (after total)");
    }

    #[test]
    fn test_reduce_on_plateau() {
        let mut scheduler = LearningRateScheduler::new(
            SchedulerType::ReduceOnPlateau {
                factor: 0.5,
                patience: 3,
                min_lr: 0.0001,
            },
            0.01,
        );

        assert_close(scheduler.get_lr(), 0.01, "Initial LR");

        // Improving metrics: no reduction (sets best_metric, resets patience)
        scheduler.step_with_metric(1.0);
        assert_close(
            scheduler.get_lr(),
            0.01,
            "Step 1 (first metric, sets baseline)",
        );

        scheduler.step_with_metric(0.9);
        assert_close(scheduler.get_lr(), 0.01, "Step 2 (improving)");

        // Plateau: metric not improving (patience counter: 1, 2, 3)
        scheduler.step_with_metric(0.91);
        assert_close(scheduler.get_lr(), 0.01, "Step 3 (plateau 1)");

        scheduler.step_with_metric(0.92);
        assert_close(scheduler.get_lr(), 0.01, "Step 4 (plateau 2)");

        // patience=3 means after 3 non-improvements, reduce LR
        // Step 5 is the 3rd non-improvement, so LR gets reduced
        scheduler.step_with_metric(0.93);
        assert_close(
            scheduler.get_lr(),
            0.005,
            "Step 5 (patience exceeded, reduced)",
        );

        // Counter is reset after reduction, so we need 3 more non-improvements
        scheduler.step_with_metric(0.94); // plateau 1 after reset
        assert_close(scheduler.get_lr(), 0.005, "Step 6 (plateau 1 after reset)");

        scheduler.step_with_metric(0.95); // plateau 2
        assert_close(scheduler.get_lr(), 0.005, "Step 7 (plateau 2)");

        scheduler.step_with_metric(0.96); // plateau 3 - triggers reduction
        assert_close(scheduler.get_lr(), 0.0025, "Step 8 (reduced again)");

        // Test min_lr floor
        for _ in 0..20 {
            scheduler.step_with_metric(1.0);
        }
        assert!(
            scheduler.get_lr() >= 0.0001,
            "LR should not go below min_lr"
        );
    }

    #[test]
    fn test_scheduler_reset() {
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);

        // Run for several steps
        for _ in 0..5 {
            scheduler.step();
        }
        assert!(scheduler.get_lr() < 0.1, "LR should have decayed");

        // Reset and verify
        scheduler.reset();
        assert_close(scheduler.get_lr(), 0.1, "Reset LR");
        assert_eq!(scheduler.step_count, 0, "Reset step count");
    }

    #[test]
    fn test_scheduler_cloning() {
        let scheduler1 = LearningRateScheduler::new(
            SchedulerType::StepDecay {
                step_size: 10,
                gamma: 0.5,
            },
            0.01,
        );

        let mut scheduler2 = scheduler1.clone();

        // Advance the clone; the original must stay untouched
        scheduler2.step();

        assert_close(scheduler1.get_lr(), 0.01, "Original LR");
        // With step_size = 10, the clone's first decay has not happened yet
        assert_close(scheduler2.get_lr(), 0.01, "Clone LR after step");
    }

    #[test]
    fn test_multiple_scheduler_types() {
        let schedulers = vec![
            (SchedulerType::Constant, 0.01),
            (
                SchedulerType::StepDecay {
                    step_size: 5,
                    gamma: 0.9,
                },
                0.01,
            ),
            (SchedulerType::Exponential { gamma: 0.95 }, 0.01),
            (
                SchedulerType::CosineAnnealing {
                    t_max: 10,
                    eta_min: 0.001,
                },
                0.01,
            ),
            (
                SchedulerType::WarmupLinear {
                    warmup_steps: 5,
                    total_steps: 20,
                },
                0.01,
            ),
            (
                SchedulerType::ReduceOnPlateau {
                    factor: 0.5,
                    patience: 5,
                    min_lr: 0.0001,
                },
                0.01,
            ),
        ];

        for (sched_type, base_lr) in schedulers {
            let mut scheduler = LearningRateScheduler::new(sched_type, base_lr);

            // All schedulers should start at base_lr
            assert_close(scheduler.get_lr(), base_lr, "Initial LR for scheduler type");

            // All schedulers should be able to step
            let _ = scheduler.step();
            assert!(scheduler.get_lr() >= 0.0, "LR should be non-negative");
        }
    }

    #[test]
    fn test_edge_cases() {
        // Zero learning rate
        let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.0);
        assert_close(scheduler.get_lr(), 0.0, "Zero LR");
        scheduler.step();
        assert_close(scheduler.get_lr(), 0.0, "Zero LR after step");

        // Very small gamma
        let mut scheduler =
            LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.1 }, 1.0);
        for _ in 0..10 {
            scheduler.step();
        }
        assert!(scheduler.get_lr() > 0.0, "LR should remain positive");
        assert!(scheduler.get_lr() < 1e-8, "LR should be very small");
    }
}