use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

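/// Common interface for the schedulers in this example. Each epoch the
/// training loop passes in the current learning rate, the epoch index, and
/// the latest loss; `step` returns the learning rate to use next.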
trait LearningRateScheduler {
    fn step(&mut self, current_lr: f32, epoch: usize, loss: f32) -> f32;
    fn name(&self) -> &str;
}

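/// Multiplies the learning rate by `gamma` whenever the epoch reaches one of
/// the `milestones`.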
struct StepDecayScheduler {
    milestones: Vec<usize>,
    gamma: f32,
}

impl StepDecayScheduler {
    fn new(milestones: Vec<usize>, gamma: f32) -> Self {
        Self { milestones, gamma }
    }
}

impl LearningRateScheduler for StepDecayScheduler {
    fn step(&mut self, current_lr: f32, epoch: usize, _loss: f32) -> f32 {
        if self.milestones.contains(&epoch) {
            current_lr * self.gamma
        } else {
            current_lr
        }
    }

    fn name(&self) -> &str {
        "Step Decay"
    }
}

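/// Multiplies the learning rate by a constant `gamma` after every epoch.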
struct ExponentialDecayScheduler {
    gamma: f32,
}

impl ExponentialDecayScheduler {
    fn new(gamma: f32) -> Self {
        Self { gamma }
    }
}

impl LearningRateScheduler for ExponentialDecayScheduler {
    fn step(&mut self, current_lr: f32, _epoch: usize, _loss: f32) -> f32 {
        current_lr * self.gamma
    }

    fn name(&self) -> &str {
        "Exponential Decay"
    }
}

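/// Anneals the learning rate from `initial_lr` down to `eta_min` over `t_max`
/// epochs, following half a cosine period:
/// lr(t) = eta_min + 0.5 * (initial_lr - eta_min) * (1 + cos(pi * t / t_max))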
struct CosineAnnealingScheduler {
    t_max: usize,
    eta_min: f32,
    initial_lr: f32,
}

impl CosineAnnealingScheduler {
    fn new(t_max: usize, eta_min: f32, initial_lr: f32) -> Self {
        Self {
            t_max,
            eta_min,
            initial_lr,
        }
    }
}

impl LearningRateScheduler for CosineAnnealingScheduler {
    fn step(&mut self, _current_lr: f32, epoch: usize, _loss: f32) -> f32 {
        let t = epoch as f32;
        let t_max = self.t_max as f32;

        self.eta_min
            + 0.5
                * (self.initial_lr - self.eta_min)
                * (1.0 + (std::f32::consts::PI * t / t_max).cos())
    }

    fn name(&self) -> &str {
        "Cosine Annealing"
    }
}

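/// Reduce-on-plateau: cuts the learning rate by `factor` (never below
/// `min_lr`) after `patience` consecutive epochs without an improvement in
/// the loss.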
struct AdaptiveScheduler {
    patience: usize,
    factor: f32,
    min_lr: f32,
    best_loss: f32,
    patience_counter: usize,
}

impl AdaptiveScheduler {
    fn new(patience: usize, factor: f32, min_lr: f32) -> Self {
        Self {
            patience,
            factor,
            min_lr,
            best_loss: f32::INFINITY,
            patience_counter: 0,
        }
    }
}

impl LearningRateScheduler for AdaptiveScheduler {
    fn step(&mut self, current_lr: f32, _epoch: usize, loss: f32) -> f32 {
        if loss < self.best_loss {
            self.best_loss = loss;
            self.patience_counter = 0;
            current_lr
        } else {
            self.patience_counter += 1;
            if self.patience_counter >= self.patience {
                let new_lr = (current_lr * self.factor).max(self.min_lr);
                self.patience_counter = 0;
                new_lr
            } else {
                current_lr
            }
        }
    }

    fn name(&self) -> &str {
        "Adaptive (Reduce on Plateau)"
    }
}

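/// Summary of a single training run.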
#[derive(Debug)]
#[allow(dead_code)]
struct TrainingStats {
    scheduler_name: String,
    final_loss: f32,
    lr_history: Vec<f32>,
    loss_history: Vec<f32>,
    convergence_epoch: usize,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Learning Rate Scheduling Example ===\n");

    demonstrate_step_decay()?;
    demonstrate_exponential_decay()?;
    demonstrate_cosine_annealing()?;
    demonstrate_adaptive_scheduling()?;
    demonstrate_scheduler_comparison()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

fn demonstrate_step_decay() -> Result<(), Box<dyn std::error::Error>> {
    println!("--- Step Decay Scheduling ---");

    let mut scheduler = StepDecayScheduler::new(vec![25, 50, 75], 0.5);
    let stats = train_with_scheduler(&mut scheduler, 100)?;

    println!("Step decay results:");
    println!(" Final loss: {:.6}", stats.final_loss);
    println!(" Convergence epoch: {}", stats.convergence_epoch);
    println!(" Learning rate schedule:");
    for (i, &lr) in stats.lr_history.iter().enumerate().step_by(10) {
        println!(" Epoch {:3}: LR = {:.6}", i, lr);
    }

    Ok(())
}

fn demonstrate_exponential_decay() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Exponential Decay Scheduling ---");

    let mut scheduler = ExponentialDecayScheduler::new(0.95);
    let stats = train_with_scheduler(&mut scheduler, 100)?;

    println!("Exponential decay results:");
    println!(" Final loss: {:.6}", stats.final_loss);
    println!(" Convergence epoch: {}", stats.convergence_epoch);
    println!(" Learning rate schedule:");
    for (i, &lr) in stats.lr_history.iter().enumerate().step_by(10) {
        println!(" Epoch {:3}: LR = {:.6}", i, lr);
    }

    Ok(())
}

fn demonstrate_cosine_annealing() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cosine Annealing Scheduling ---");

    let initial_lr = 0.1;
    let mut scheduler = CosineAnnealingScheduler::new(100, 0.001, initial_lr);
    let stats = train_with_scheduler(&mut scheduler, 100)?;

    println!("Cosine annealing results:");
    println!(" Final loss: {:.6}", stats.final_loss);
    println!(" Convergence epoch: {}", stats.convergence_epoch);
    println!(" Learning rate schedule:");
    for (i, &lr) in stats.lr_history.iter().enumerate().step_by(10) {
        println!(" Epoch {:3}: LR = {:.6}", i, lr);
    }

    Ok(())
}

fn demonstrate_adaptive_scheduling() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Adaptive Scheduling ---");

    let mut scheduler = AdaptiveScheduler::new(5, 0.5, 0.001);
    let stats = train_with_scheduler(&mut scheduler, 100)?;

    println!("Adaptive scheduling results:");
    println!(" Final loss: {:.6}", stats.final_loss);
    println!(" Convergence epoch: {}", stats.convergence_epoch);
    println!(" Learning rate schedule:");
    for (i, &lr) in stats.lr_history.iter().enumerate().step_by(10) {
        println!(" Epoch {:3}: LR = {:.6}", i, lr);
    }

    Ok(())
}

fn demonstrate_scheduler_comparison() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Scheduler Comparison ---");

    let schedulers: Vec<Box<dyn LearningRateScheduler>> = vec![
        Box::new(StepDecayScheduler::new(vec![30, 60], 0.5)),
        Box::new(ExponentialDecayScheduler::new(0.98)),
        Box::new(CosineAnnealingScheduler::new(100, 0.001, 0.05)),
        Box::new(AdaptiveScheduler::new(8, 0.7, 0.001)),
    ];

    let mut results = Vec::new();

    for mut scheduler in schedulers {
        println!("\nTesting {} scheduler:", scheduler.name());

        let stats = train_with_scheduler(scheduler.as_mut(), 100)?;
        results.push(stats);

        println!(" Final loss: {:.6}", results.last().unwrap().final_loss);
        println!(
            " Convergence: {} epochs",
            results.last().unwrap().convergence_epoch
        );
    }

    println!("\nScheduler Comparison Summary:");
    println!(
        " {:20} | {:10} | {:12} | {:12}",
        "Scheduler", "Final Loss", "Convergence", "LR Range"
    );
    println!(" {}", "-".repeat(70));

    for stats in &results {
        let lr_range = format!(
            "{:.0e} - {:.0e}",
            stats
                .lr_history
                .iter()
                .cloned()
                .fold(f32::INFINITY, f32::min),
            stats.lr_history.iter().cloned().fold(0.0, f32::max)
        );
        // Column widths match the header above so the table lines up.
        println!(
            " {:20} | {:10.6} | {:12} | {}",
            stats.scheduler_name, stats.final_loss, stats.convergence_epoch, lr_range
        );
    }

    Ok(())
}

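/// Runs `num_epochs` of Adam on a tiny linear-regression problem (the data
/// follows y = 2x + 1), letting `scheduler` adjust the learning rate each
/// epoch, and returns the collected statistics.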
fn train_with_scheduler(
    scheduler: &mut dyn LearningRateScheduler,
    num_epochs: usize,
) -> Result<TrainingStats, Box<dyn std::error::Error>> {
    // Toy dataset following y = 2x + 1.
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();

    // Trainable parameters of the linear model.
    let mut weight = Tensor::randn(vec![1, 1], Some(456)).with_requires_grad();
    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();

    let mut optimizer = Adam::with_learning_rate(0.05);
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);

    let mut losses = Vec::new();
    let mut lr_history = Vec::new();
    let mut convergence_epoch = num_epochs;

    for epoch in 0..num_epochs {
        // Forward pass and mean squared error loss.
        let y_pred = x_data.matmul(&weight) + &bias;
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        loss.backward(None);

        // Let the scheduler pick the learning rate for this epoch.
        let current_lr = optimizer.learning_rate();
        let new_lr = scheduler.step(current_lr, epoch, loss.value());

        if (new_lr - current_lr).abs() > 1e-8 {
            optimizer.set_learning_rate(new_lr);
        }

        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        let loss_value = loss.value();
        losses.push(loss_value);
        lr_history.push(new_lr);

        // Record the first epoch at which the loss drops below the threshold.
        if loss_value < 0.01 && convergence_epoch == num_epochs {
            convergence_epoch = epoch;
        }
    }

    Ok(TrainingStats {
        scheduler_name: scheduler.name().to_string(),
        final_loss: losses[losses.len() - 1],
        lr_history,
        loss_history: losses,
        convergence_epoch,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_step_decay_scheduler() {
        let mut scheduler = StepDecayScheduler::new(vec![5, 10], 0.5);
        let mut lr = 0.1;

        lr = scheduler.step(lr, 0, 0.0);
        assert_eq!(lr, 0.1);

        lr = scheduler.step(lr, 5, 0.0);
        assert_eq!(lr, 0.05);

        lr = scheduler.step(lr, 10, 0.0);
        assert_eq!(lr, 0.025);
    }

    #[test]
    fn test_exponential_decay_scheduler() {
        let mut scheduler = ExponentialDecayScheduler::new(0.9);
        let mut lr = 0.1;

        lr = scheduler.step(lr, 0, 0.0);
        assert!((lr - 0.09).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_scheduler() {
        let mut scheduler = CosineAnnealingScheduler::new(10, 0.001, 0.1);

        let lr_start = scheduler.step(0.0, 0, 0.0);
        assert!((lr_start - 0.1).abs() < 1e-6);

        let lr_mid = scheduler.step(0.0, 5, 0.0);
        assert!((lr_mid - 0.0505).abs() < 1e-3);

        // At epoch == t_max the cosine term is -1, so the LR bottoms out at eta_min.
        let lr_end = scheduler.step(0.0, 10, 0.0);
        assert!((lr_end - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adaptive_scheduler() {
        let mut scheduler = AdaptiveScheduler::new(2, 0.5, 0.001);
        let mut lr = 0.1;

        // Loss improves: LR unchanged.
        lr = scheduler.step(lr, 0, 0.5);
        assert_eq!(lr, 0.1);

        // First epoch without improvement: still within patience.
        lr = scheduler.step(lr, 1, 0.6);
        assert_eq!(lr, 0.1);

        // Patience exhausted: LR is halved.
        lr = scheduler.step(lr, 2, 0.6);
        assert_eq!(lr, 0.05);

        // Loss improves again: LR stays at the reduced value.
        lr = scheduler.step(lr, 3, 0.4);
        assert_eq!(lr, 0.05);
    }

    #[test]
    fn test_scheduler_training() {
        let mut scheduler = StepDecayScheduler::new(vec![10], 0.5);
        let stats = train_with_scheduler(&mut scheduler, 20).unwrap();

        assert!(stats.final_loss < 1.0);
        assert_eq!(stats.lr_history.len(), 20);
        assert!(stats.convergence_epoch < 20);
    }
}