tensorlogic_train/callbacks/lr_finder.rs
//! Learning rate finder callback using the LR range test.

use crate::callbacks::core::Callback;
use crate::{TrainResult, TrainingState};

/// Learning rate finder callback using the LR range test.
///
/// This callback implements the learning rate range test proposed by Leslie N. Smith.
/// It gradually increases the learning rate from a minimum to a maximum value over
/// a specified number of steps (one per batch) and tracks the loss at each step.
///
/// The optimal learning rate is typically found just before the loss starts to increase.
///
/// # Example
/// ```rust,ignore
/// use tensorlogic_train::{LearningRateFinder, CallbackList};
///
/// let mut callbacks = CallbackList::new();
/// callbacks.add(Box::new(LearningRateFinder::new(
///     1e-7, // start_lr
///     10.0, // end_lr
///     100,  // num_steps
/// )));
/// ```
pub struct LearningRateFinder {
    /// Starting learning rate.
    start_lr: f64,
    /// Ending learning rate.
    end_lr: f64,
    /// Number of steps to test.
    num_steps: usize,
    /// Current step.
    current_step: usize,
    /// History of (lr, loss) pairs.
    pub history: Vec<(f64, f64)>,
    /// Whether to use exponential or linear scaling.
    exponential: bool,
    /// Smoothing factor for loss (0.0 = no smoothing, 0.9 = heavy smoothing).
    smoothing: f64,
    /// Smoothed loss.
    smoothed_loss: Option<f64>,
}

impl LearningRateFinder {
    /// Create a new learning rate finder.
    ///
    /// # Arguments
    /// * `start_lr` - Starting learning rate (e.g., 1e-7)
    /// * `end_lr` - Ending learning rate (e.g., 10.0)
    /// * `num_steps` - Number of steps to test
    pub fn new(start_lr: f64, end_lr: f64, num_steps: usize) -> Self {
        Self {
            start_lr,
            end_lr,
            num_steps,
            current_step: 0,
            history: Vec::with_capacity(num_steps),
            exponential: true, // Exponential scaling is recommended
            smoothing: 0.0,    // No smoothing by default
            smoothed_loss: None,
        }
    }

    /// Enable exponential scaling (recommended, default).
    pub fn with_exponential_scaling(mut self) -> Self {
        self.exponential = true;
        self
    }

    /// Enable linear scaling.
    pub fn with_linear_scaling(mut self) -> Self {
        self.exponential = false;
        self
    }

    /// Set loss smoothing factor (0.0-1.0).
    ///
    /// Recommended: 0.9 for noisy losses, 0.0 for smooth losses.
    pub fn with_smoothing(mut self, smoothing: f64) -> Self {
        self.smoothing = smoothing.clamp(0.0, 1.0);
        self
    }

    /// Compute the current learning rate based on step.
    fn compute_lr(&self) -> f64 {
        if self.num_steps <= 1 {
            return self.start_lr;
        }

        let step_ratio = self.current_step as f64 / (self.num_steps - 1) as f64;

        if self.exponential {
            // Exponential scaling: lr = start_lr * (end_lr / start_lr)^step_ratio
            self.start_lr * (self.end_lr / self.start_lr).powf(step_ratio)
        } else {
            // Linear scaling: lr = start_lr + (end_lr - start_lr) * step_ratio
            self.start_lr + (self.end_lr - self.start_lr) * step_ratio
        }
    }
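
    // Worked example of the schedules above (illustrative numbers only): with
    // start_lr = 1e-7, end_lr = 10.0 and num_steps = 100, the exponential
    // schedule multiplies the LR by (1e8)^(1/99), about 1.20, at every step,
    // whereas the linear schedule adds roughly 0.101 per step and therefore
    // spends almost the entire test above lr = 0.1.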

    /// Apply exponential smoothing to the loss and return the smoothed value.
    fn smooth_loss(&mut self, loss: f64) -> f64 {
        if self.smoothing == 0.0 {
            return loss;
        }

        match self.smoothed_loss {
            None => {
                self.smoothed_loss = Some(loss);
                loss
            }
            Some(prev) => {
                let smoothed = self.smoothing * prev + (1.0 - self.smoothing) * loss;
                self.smoothed_loss = Some(smoothed);
                smoothed
            }
        }
    }

    /// Find the suggested optimal learning rate.
    ///
    /// Returns the learning rate with the steepest negative gradient (fastest decrease in loss).
    pub fn suggest_lr(&self) -> Option<f64> {
        if self.history.len() < 3 {
            return None;
        }

        let mut best_lr = None;
        let mut best_gradient = f64::INFINITY;

        // Compute gradients and find steepest descent
        for i in 1..self.history.len() {
            let (lr1, loss1) = self.history[i - 1];
            let (lr2, loss2) = self.history[i];

            let gradient = (loss2 - loss1) / (lr2 - lr1);

            if gradient < best_gradient {
                best_gradient = gradient;
                best_lr = Some(lr2);
            }
        }

        best_lr
    }

    /// Print the LR finder results.
    pub fn print_results(&self) {
        println!("\n=== Learning Rate Finder Results ===");
        println!(
            "Tested {} learning rates from {:.2e} to {:.2e}",
            self.history.len(),
            self.start_lr,
            self.end_lr
        );

        if let Some(suggested_lr) = self.suggest_lr() {
            println!("Suggested optimal LR: {:.2e}", suggested_lr);
            println!(
                "Consider using LR between {:.2e} and {:.2e}",
                suggested_lr / 10.0,
                suggested_lr
            );
        }

        println!("\nLR, Loss:");
        for (lr, loss) in &self.history {
            println!("{:.6e}, {:.6}", lr, loss);
        }
        println!("===================================\n");
    }
}

impl Callback for LearningRateFinder {
    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
        if self.current_step >= self.num_steps {
            return Ok(());
        }

        // Get current loss and smooth it
        let loss = self.smooth_loss(state.batch_loss);

        // Record (lr, loss) pair
        let lr = self.compute_lr();
        self.history.push((lr, loss));

        self.current_step += 1;

        // Note: The actual LR update happens via the trainer's optimizer.
        // This callback just tracks the relationship.

        Ok(())
    }

    fn should_stop(&self) -> bool {
        // Stop after testing all LR values
        self.current_step >= self.num_steps
    }
}