learning_rate_scheduling/learning_rate_scheduling.rs

//! Learning Rate Scheduling Example
//!
//! This example demonstrates various learning rate scheduling techniques
//! for improving neural network training convergence and performance:
//! - Step decay scheduling with milestones
//! - Exponential decay scheduling
//! - Cosine annealing scheduling
//! - Linear warmup with decay
//! - Adaptive scheduling based on validation loss
//! - Performance comparison across scheduling strategies
//!
//! # Learning Objectives
//!
//! - Understand different learning rate scheduling strategies
//! - Learn how to implement custom learning rate schedules
//! - Explore the impact of scheduling on convergence
//! - Compare different scheduling techniques
//! - Implement adaptive scheduling based on training metrics
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of the Adam optimizer
//! - Familiarity with neural network training
//! - Knowledge of optimization concepts
//!
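//! # Scheduler Hook (Sketch)
//!
//! Every scheduler in this example implements a single method that maps
//! `(current_lr, epoch, loss)` to the learning rate for the next epoch.
//! A sketch of the pattern used in `train_with_scheduler` below:
//!
//! ```ignore
//! let new_lr = scheduler.step(optimizer.learning_rate(), epoch, loss.value());
//! optimizer.set_learning_rate(new_lr);
//! ```
//!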
//! # Usage
//!
//! ```bash
//! cargo run --example learning_rate_scheduling
//! ```

use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

/// Learning rate scheduler trait.
///
/// Implementations map the current learning rate, the epoch index, and the
/// most recent loss to the learning rate for the next optimizer step.
trait LearningRateScheduler {
    /// Return the learning rate to use for the next epoch.
    fn step(&mut self, current_lr: f32, epoch: usize, loss: f32) -> f32;
    /// Human-readable name for reporting.
    fn name(&self) -> &str;
}

/// Step decay scheduler: multiplies the learning rate by `gamma` whenever
/// the current epoch matches one of the configured milestones.
struct StepDecayScheduler {
    milestones: Vec<usize>,
    gamma: f32,
}

impl StepDecayScheduler {
    fn new(milestones: Vec<usize>, gamma: f32) -> Self {
        Self { milestones, gamma }
    }
}

impl LearningRateScheduler for StepDecayScheduler {
    fn step(&mut self, current_lr: f32, epoch: usize, _loss: f32) -> f32 {
        if self.milestones.contains(&epoch) {
            current_lr * self.gamma
        } else {
            current_lr
        }
    }

    fn name(&self) -> &str {
        "Step Decay"
    }
}

/// Exponential decay scheduler: multiplies the learning rate by `gamma`
/// every epoch, giving lr_t = lr_0 * gamma^t.
struct ExponentialDecayScheduler {
    gamma: f32,
}

impl ExponentialDecayScheduler {
    fn new(gamma: f32) -> Self {
        Self { gamma }
    }
}

impl LearningRateScheduler for ExponentialDecayScheduler {
    fn step(&mut self, current_lr: f32, _epoch: usize, _loss: f32) -> f32 {
        current_lr * self.gamma
    }

    fn name(&self) -> &str {
        "Exponential Decay"
    }
}

/// Cosine annealing scheduler: follows
/// lr_t = eta_min + 0.5 * (lr_0 - eta_min) * (1 + cos(pi * t / t_max)),
/// sweeping smoothly from the initial rate down to `eta_min` over `t_max` epochs.
struct CosineAnnealingScheduler {
    t_max: usize,
    eta_min: f32,
    initial_lr: f32,
}

impl CosineAnnealingScheduler {
    fn new(t_max: usize, eta_min: f32, initial_lr: f32) -> Self {
        Self {
            t_max,
            eta_min,
            initial_lr,
        }
    }
}

impl LearningRateScheduler for CosineAnnealingScheduler {
    fn step(&mut self, _current_lr: f32, epoch: usize, _loss: f32) -> f32 {
        let t = epoch as f32;
        let t_max = self.t_max as f32;

        self.eta_min
            + 0.5
                * (self.initial_lr - self.eta_min)
                * (1.0 + (std::f32::consts::PI * t / t_max).cos())
    }

    fn name(&self) -> &str {
        "Cosine Annealing"
    }
}

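/// Linear warmup with decay scheduler.
///
/// A minimal sketch of the "linear warmup with decay" strategy listed in the
/// module docs: the learning rate ramps linearly from `min_lr` to `peak_lr`
/// over the first `warmup_epochs`, then decays linearly back to `min_lr` by
/// `total_epochs`. The parameter names and values here are illustrative, not
/// a fixed API.
struct WarmupDecayScheduler {
    warmup_epochs: usize,
    total_epochs: usize,
    peak_lr: f32,
    min_lr: f32,
}

impl WarmupDecayScheduler {
    fn new(warmup_epochs: usize, total_epochs: usize, peak_lr: f32, min_lr: f32) -> Self {
        Self {
            warmup_epochs,
            total_epochs,
            peak_lr,
            min_lr,
        }
    }
}

impl LearningRateScheduler for WarmupDecayScheduler {
    fn step(&mut self, _current_lr: f32, epoch: usize, _loss: f32) -> f32 {
        if epoch < self.warmup_epochs {
            // Linear ramp: reaches peak_lr at the last warmup epoch.
            let progress = (epoch + 1) as f32 / self.warmup_epochs as f32;
            self.min_lr + (self.peak_lr - self.min_lr) * progress
        } else {
            // Linear decay: falls back to min_lr by total_epochs
            // (and stays there for any epochs beyond it).
            let decay_span = self.total_epochs.saturating_sub(self.warmup_epochs).max(1) as f32;
            let remaining = self.total_epochs.saturating_sub(epoch) as f32;
            self.min_lr + (self.peak_lr - self.min_lr) * (remaining / decay_span)
        }
    }

    fn name(&self) -> &str {
        "Linear Warmup + Decay"
    }
}
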
/// Adaptive scheduler (reduce on plateau): multiplies the learning rate by
/// `factor` when the loss has not improved for `patience` consecutive
/// epochs, never dropping below `min_lr`.
struct AdaptiveScheduler {
    patience: usize,
    factor: f32,
    min_lr: f32,
    best_loss: f32,
    patience_counter: usize,
}

impl AdaptiveScheduler {
    fn new(patience: usize, factor: f32, min_lr: f32) -> Self {
        Self {
            patience,
            factor,
            min_lr,
            best_loss: f32::INFINITY,
            patience_counter: 0,
        }
    }
}

impl LearningRateScheduler for AdaptiveScheduler {
    fn step(&mut self, current_lr: f32, _epoch: usize, loss: f32) -> f32 {
        if loss < self.best_loss {
            self.best_loss = loss;
            self.patience_counter = 0;
            current_lr
        } else {
            self.patience_counter += 1;
            if self.patience_counter >= self.patience {
                let new_lr = (current_lr * self.factor).max(self.min_lr);
                self.patience_counter = 0;
                new_lr
            } else {
                current_lr
            }
        }
    }

    fn name(&self) -> &str {
        "Adaptive (Reduce on Plateau)"
    }
}

/// Training statistics
#[derive(Debug)]
#[allow(dead_code)]
struct TrainingStats {
    scheduler_name: String,
    final_loss: f32,
    lr_history: Vec<f32>,
    loss_history: Vec<f32>,
    convergence_epoch: usize,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Learning Rate Scheduling Example ===\n");

    demonstrate_step_decay()?;
    demonstrate_exponential_decay()?;
    demonstrate_cosine_annealing()?;
    demonstrate_adaptive_scheduling()?;
    demonstrate_scheduler_comparison()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

/// Demonstrate step decay scheduling
fn demonstrate_step_decay() -> Result<(), Box<dyn std::error::Error>> {
    println!("--- Step Decay Scheduling ---");

    let mut scheduler = StepDecayScheduler::new(vec![25, 50, 75], 0.5);
    let stats = train_with_scheduler(&mut scheduler, 100)?;

    println!("Step decay results:");
    println!("  Final loss: {:.6}", stats.final_loss);
    println!("  Convergence epoch: {}", stats.convergence_epoch);
    println!("  Learning rate schedule:");
    for (i, &lr) in stats.lr_history.iter().enumerate().step_by(10) {
        println!("    Epoch {:3}: LR = {:.6}", i, lr);
    }

    Ok(())
}

/// Demonstrate exponential decay scheduling
fn demonstrate_exponential_decay() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Exponential Decay Scheduling ---");

    let mut scheduler = ExponentialDecayScheduler::new(0.95);
    let stats = train_with_scheduler(&mut scheduler, 100)?;

    println!("Exponential decay results:");
    println!("  Final loss: {:.6}", stats.final_loss);
    println!("  Convergence epoch: {}", stats.convergence_epoch);
    println!("  Learning rate schedule:");
    for (i, &lr) in stats.lr_history.iter().enumerate().step_by(10) {
        println!("    Epoch {:3}: LR = {:.6}", i, lr);
    }

    Ok(())
}

/// Demonstrate cosine annealing scheduling
fn demonstrate_cosine_annealing() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cosine Annealing Scheduling ---");

    let initial_lr = 0.1;
    let mut scheduler = CosineAnnealingScheduler::new(100, 0.001, initial_lr);
    let stats = train_with_scheduler(&mut scheduler, 100)?;

    println!("Cosine annealing results:");
    println!("  Final loss: {:.6}", stats.final_loss);
    println!("  Convergence epoch: {}", stats.convergence_epoch);
    println!("  Learning rate schedule:");
    for (i, &lr) in stats.lr_history.iter().enumerate().step_by(10) {
        println!("    Epoch {:3}: LR = {:.6}", i, lr);
    }

    Ok(())
}

/// Demonstrate adaptive scheduling
fn demonstrate_adaptive_scheduling() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Adaptive Scheduling ---");

    let mut scheduler = AdaptiveScheduler::new(5, 0.5, 0.001);
    let stats = train_with_scheduler(&mut scheduler, 100)?;

    println!("Adaptive scheduling results:");
    println!("  Final loss: {:.6}", stats.final_loss);
    println!("  Convergence epoch: {}", stats.convergence_epoch);
    println!("  Learning rate schedule:");
    for (i, &lr) in stats.lr_history.iter().enumerate().step_by(10) {
        println!("    Epoch {:3}: LR = {:.6}", i, lr);
    }

    Ok(())
}

/// Demonstrate scheduler comparison
fn demonstrate_scheduler_comparison() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Scheduler Comparison ---");

    let schedulers: Vec<Box<dyn LearningRateScheduler>> = vec![
        Box::new(StepDecayScheduler::new(vec![30, 60], 0.5)),
        Box::new(ExponentialDecayScheduler::new(0.98)),
        Box::new(CosineAnnealingScheduler::new(100, 0.001, 0.05)),
        Box::new(AdaptiveScheduler::new(8, 0.7, 0.001)),
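        // Warmup scheduler sketched above; the 10-epoch warmup to a 0.05
        // peak and 0.001 floor are illustrative values.
        Box::new(WarmupDecayScheduler::new(10, 100, 0.05, 0.001)),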
    ];

    let mut results = Vec::new();

    for mut scheduler in schedulers {
        println!("\nTesting {} scheduler:", scheduler.name());

        let stats = train_with_scheduler(scheduler.as_mut(), 100)?;
        results.push(stats);

        println!("  Final loss: {:.6}", results.last().unwrap().final_loss);
        println!(
            "  Convergence: {} epochs",
            results.last().unwrap().convergence_epoch
        );
    }

    // Comparison summary
    println!("\nScheduler Comparison Summary:");
    println!(
        "  {:20} | {:10} | {:12} | {:12}",
        "Scheduler", "Final Loss", "Convergence", "LR Range"
    );
    println!("  {}", "-".repeat(70));

    for stats in &results {
        let lr_range = format!(
            "{:.0e} - {:.0e}",
            stats
                .lr_history
                .iter()
                .cloned()
                .fold(f32::INFINITY, f32::min),
            stats
                .lr_history
                .iter()
                .cloned()
                .fold(f32::NEG_INFINITY, f32::max)
        );
        // Column widths match the header row so the table lines up.
        println!(
            "  {:20} | {:10.6} | {:12} | {}",
            stats.scheduler_name, stats.final_loss, stats.convergence_epoch, lr_range
        );
    }

    Ok(())
}

/// Helper function to train with a learning rate scheduler
fn train_with_scheduler(
    scheduler: &mut dyn LearningRateScheduler,
    num_epochs: usize,
) -> Result<TrainingStats, Box<dyn std::error::Error>> {
    // Create training data: y = 2*x + 1
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();

    // Create model parameters
    let mut weight = Tensor::randn(vec![1, 1], Some(456)).with_requires_grad();
    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();

    // Create optimizer with initial learning rate
    let mut optimizer = Adam::with_learning_rate(0.05);
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);

    // Training loop
    let mut losses = Vec::new();
    let mut lr_history = Vec::new();
    let mut convergence_epoch = num_epochs;

    for epoch in 0..num_epochs {
        // Forward pass: y_pred = x * w + b with mean squared error loss
        let y_pred = x_data.matmul(&weight) + &bias;
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Update learning rate using scheduler
        let current_lr = optimizer.learning_rate();
        let new_lr = scheduler.step(current_lr, epoch, loss.value());

        // Only touch the optimizer when the scheduler actually changed the rate
        if (new_lr - current_lr).abs() > 1e-8 {
            optimizer.set_learning_rate(new_lr);
        }

        // Optimizer step
        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        let loss_value = loss.value();
        losses.push(loss_value);
        lr_history.push(new_lr);

        // Record the first epoch at which the loss drops below the threshold
        if loss_value < 0.01 && convergence_epoch == num_epochs {
            convergence_epoch = epoch;
        }
    }

    Ok(TrainingStats {
        scheduler_name: scheduler.name().to_string(),
        final_loss: *losses.last().unwrap(),
        lr_history,
        loss_history: losses,
        convergence_epoch,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_step_decay_scheduler() {
        let mut scheduler = StepDecayScheduler::new(vec![5, 10], 0.5);
        let mut lr = 0.1;

        lr = scheduler.step(lr, 0, 0.0);
        assert_eq!(lr, 0.1);

        lr = scheduler.step(lr, 5, 0.0);
        assert_eq!(lr, 0.05);

        lr = scheduler.step(lr, 10, 0.0);
        assert_eq!(lr, 0.025);
    }

    #[test]
    fn test_exponential_decay_scheduler() {
        let mut scheduler = ExponentialDecayScheduler::new(0.9);
        let mut lr = 0.1;

        lr = scheduler.step(lr, 0, 0.0);
        assert!((lr - 0.09).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_scheduler() {
        let mut scheduler = CosineAnnealingScheduler::new(10, 0.001, 0.1);

        let lr_start = scheduler.step(0.0, 0, 0.0);
        assert!((lr_start - 0.1).abs() < 1e-6);

        let lr_mid = scheduler.step(0.0, 5, 0.0);
        assert!((lr_mid - 0.0505).abs() < 1e-3); // Approximately halfway

        // At t = t_max the cosine term hits -1, so the LR lands on eta_min.
        // (At epoch 9 the LR is still ~0.0034, well above the floor.)
        let lr_end = scheduler.step(0.0, 10, 0.0);
        assert!((lr_end - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adaptive_scheduler() {
        let mut scheduler = AdaptiveScheduler::new(2, 0.5, 0.001);
        let mut lr = 0.1;

        // Improving loss - should not change LR
        lr = scheduler.step(lr, 0, 0.5);
        assert_eq!(lr, 0.1);

        // Worse loss - should not change LR yet (patience = 2)
        lr = scheduler.step(lr, 1, 0.6);
        assert_eq!(lr, 0.1);

        lr = scheduler.step(lr, 2, 0.6);
        assert_eq!(lr, 0.05); // Should reduce after patience

        // Improving again - should not change LR
        lr = scheduler.step(lr, 3, 0.4);
        assert_eq!(lr, 0.05);
    }

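    #[test]
    fn test_warmup_decay_scheduler() {
        // Covers the warmup scheduler sketched above: ramp to the peak over
        // the first 5 epochs, then decay to the floor by epoch 20.
        let mut scheduler = WarmupDecayScheduler::new(5, 20, 0.1, 0.001);

        // Last warmup epoch should reach the peak learning rate.
        let lr_peak = scheduler.step(0.0, 4, 0.0);
        assert!((lr_peak - 0.1).abs() < 1e-6);

        // End of the schedule should land on the floor.
        let lr_end = scheduler.step(0.0, 20, 0.0);
        assert!((lr_end - 0.001).abs() < 1e-6);
    }
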
    #[test]
    fn test_scheduler_training() {
        let mut scheduler = StepDecayScheduler::new(vec![10], 0.5);
        let stats = train_with_scheduler(&mut scheduler, 20).unwrap();

        // A 20-epoch run is short, so assert that training made progress
        // rather than demanding full convergence to the 0.01 threshold.
        assert!(stats.final_loss < stats.loss_history[0]);
        assert_eq!(stats.lr_history.len(), 20);
    }
}