
tensorlogic_train/callbacks/lr_finder.rs

//! Learning rate finder callback using the LR range test.

use crate::callbacks::core::Callback;
use crate::{TrainResult, TrainingState};

/// Learning rate finder callback using the LR range test.
///
/// This callback implements the learning rate range test proposed by Leslie N. Smith.
/// It gradually increases the learning rate from a minimum to a maximum value over
/// a specified number of steps (batches) and tracks the loss at each step.
///
/// The optimal learning rate is typically found just before the loss starts to increase.
///
/// # Example
/// ```rust,ignore
/// use tensorlogic_train::{LearningRateFinder, CallbackList};
///
/// let mut callbacks = CallbackList::new();
/// callbacks.add(Box::new(LearningRateFinder::new(
///     1e-7,   // start_lr
///     10.0,   // end_lr
///     100,    // num_steps
/// )));
/// ```
pub struct LearningRateFinder {
    /// Starting learning rate.
    start_lr: f64,
    /// Ending learning rate.
    end_lr: f64,
    /// Number of steps to test.
    num_steps: usize,
    /// Current step.
    current_step: usize,
    /// History of (lr, loss) pairs.
    pub history: Vec<(f64, f64)>,
    /// Whether to use exponential or linear scaling.
    exponential: bool,
    /// Smoothing factor for loss (0.0 = no smoothing, 0.9 = heavy smoothing).
    smoothing: f64,
    /// Smoothed loss.
    smoothed_loss: Option<f64>,
}

impl LearningRateFinder {
    /// Create a new learning rate finder.
    ///
    /// # Arguments
    /// * `start_lr` - Starting learning rate (e.g., 1e-7)
    /// * `end_lr` - Ending learning rate (e.g., 10.0)
    /// * `num_steps` - Number of steps to test
    pub fn new(start_lr: f64, end_lr: f64, num_steps: usize) -> Self {
        Self {
            start_lr,
            end_lr,
            num_steps,
            current_step: 0,
            history: Vec::with_capacity(num_steps),
            exponential: true, // Exponential scaling is recommended
            smoothing: 0.0,    // No smoothing by default
            smoothed_loss: None,
        }
    }

    /// Enable exponential scaling (recommended, default).
    pub fn with_exponential_scaling(mut self) -> Self {
        self.exponential = true;
        self
    }

    /// Enable linear scaling.
    pub fn with_linear_scaling(mut self) -> Self {
        self.exponential = false;
        self
    }

    /// Set loss smoothing factor (0.0-1.0).
    ///
    /// Recommended: 0.9 for noisy losses, 0.0 for smooth losses.
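    ///
    /// # Example
    /// A minimal sketch of builder-style configuration (mirrors the struct-level example):
    /// ```rust,ignore
    /// // Heavier smoothing for a noisy per-batch loss curve.
    /// let finder = LearningRateFinder::new(1e-7, 10.0, 100).with_smoothing(0.9);
    /// ```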
    pub fn with_smoothing(mut self, smoothing: f64) -> Self {
        self.smoothing = smoothing.clamp(0.0, 1.0);
        self
    }

    /// Compute the current learning rate based on step.
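    ///
    /// For example, with `start_lr = 1e-7`, `end_lr = 10.0`, and `num_steps = 100`,
    /// exponential scaling multiplies the learning rate by a constant factor of
    /// `(end_lr / start_lr)^(1 / (num_steps - 1))` ≈ 1.20 per step, sweeping eight
    /// orders of magnitude over the test.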
    fn compute_lr(&self) -> f64 {
        if self.num_steps <= 1 {
            return self.start_lr;
        }

        let step_ratio = self.current_step as f64 / (self.num_steps - 1) as f64;

        if self.exponential {
            // Exponential scaling: lr = start_lr * (end_lr/start_lr)^step_ratio
            self.start_lr * (self.end_lr / self.start_lr).powf(step_ratio)
        } else {
            // Linear scaling: lr = start_lr + (end_lr - start_lr) * step_ratio
            self.start_lr + (self.end_lr - self.start_lr) * step_ratio
        }
    }

    /// Update and return the smoothed loss.
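    ///
    /// Uses an exponential moving average seeded with the first raw loss value:
    /// `smoothed = smoothing * prev + (1.0 - smoothing) * loss` (no bias correction).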
    fn smooth_loss(&mut self, loss: f64) -> f64 {
        if self.smoothing == 0.0 {
            return loss;
        }

        match self.smoothed_loss {
            None => {
                self.smoothed_loss = Some(loss);
                loss
            }
            Some(prev) => {
                let smoothed = self.smoothing * prev + (1.0 - self.smoothing) * loss;
                self.smoothed_loss = Some(smoothed);
                smoothed
            }
        }
    }

    /// Find the suggested optimal learning rate.
    ///
    /// Returns the learning rate with the steepest negative gradient (fastest decrease in loss).
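    ///
    /// # Example
    /// A sketch of reading the suggestion after a finder run (`finder` is assumed to
    /// already hold a populated `history`):
    /// ```rust,ignore
    /// if let Some(lr) = finder.suggest_lr() {
    ///     println!("Suggested LR: {:.2e}", lr);
    /// }
    /// ```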
    pub fn suggest_lr(&self) -> Option<f64> {
        if self.history.len() < 3 {
            return None;
        }

        let mut best_lr = None;
        let mut best_gradient = f64::INFINITY;

        // Compute gradients and find steepest descent
        for i in 1..self.history.len() {
            let (lr1, loss1) = self.history[i - 1];
            let (lr2, loss2) = self.history[i];

            let gradient = (loss2 - loss1) / (lr2 - lr1);

            if gradient < best_gradient {
                best_gradient = gradient;
                best_lr = Some(lr2);
            }
        }

        best_lr
    }

    /// Print the LR finder results.
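    ///
    /// Prints the tested LR range, the suggested LR (when `suggest_lr` returns one)
    /// along with a recommended usage range, and the full `(lr, loss)` history as
    /// CSV-style lines.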
    pub fn print_results(&self) {
        println!("\n=== Learning Rate Finder Results ===");
        println!(
            "Tested {} learning rates from {:.2e} to {:.2e}",
            self.history.len(),
            self.start_lr,
            self.end_lr
        );

        if let Some(suggested_lr) = self.suggest_lr() {
            println!("Suggested optimal LR: {:.2e}", suggested_lr);
            println!(
                "Consider using LR between {:.2e} and {:.2e}",
                suggested_lr / 10.0,
                suggested_lr
            );
        }

        println!("\nLR, Loss:");
        for (lr, loss) in &self.history {
            println!("{:.6e}, {:.6}", lr, loss);
        }
        println!("===================================\n");
    }
}

impl Callback for LearningRateFinder {
    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
        if self.current_step >= self.num_steps {
            return Ok(());
        }

        // Get current loss and smooth it
        let loss = self.smooth_loss(state.batch_loss);

        // Record (lr, loss) pair
        let lr = self.compute_lr();
        self.history.push((lr, loss));

        self.current_step += 1;

        // Note: the actual LR update happens via the trainer's optimizer;
        // this callback only records the (lr, loss) relationship.
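        // Expected integration (an assumption, not enforced here): the trainer or its
        // LR scheduler sweeps the optimizer's learning rate over the same
        // start_lr..end_lr range, and suggest_lr()/print_results() are consulted once
        // the sweep completes.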

        Ok(())
    }

    fn should_stop(&self) -> bool {
        // Stop after testing all LR values
        self.current_step >= self.num_steps
    }
}