//! LR range-test ("LR finder") utilities (`yscv_model/lr_finder.rs`).

/// Result of an LR range test.
///
/// Produced by [`lr_range_test`]. `lr_values` and `loss_values` are parallel
/// vectors with one entry per step of the schedule.
#[derive(Debug, Clone)]
pub struct LrFinderResult {
    /// LR values used at each step.
    pub lr_values: Vec<f32>,
    /// Loss values recorded at each step (smoothed with an exponential
    /// moving average).
    pub loss_values: Vec<f32>,
    /// Suggested LR (steepest loss decrease): the schedule value at which
    /// the smoothed loss dropped most between consecutive steps.
    pub suggested_lr: f32,
}
11
/// Configuration for LR range test.
#[derive(Debug, Clone)]
pub struct LrFinderConfig {
    /// Starting learning rate (first value of the schedule).
    pub start_lr: f32,
    /// Ending learning rate (last value of the schedule).
    pub end_lr: f32,
    /// Number of steps; [`lr_range_test`] requires at least 2.
    pub num_steps: usize,
    /// Whether to use exponential (`true`) or linear (`false`) schedule.
    pub log_scale: bool,
    /// Smoothing factor for loss (exponential moving average). This is the
    /// weight given to the newest raw loss, so smaller values smooth more.
    pub smoothing: f32,
}
26
27impl Default for LrFinderConfig {
28    fn default() -> Self {
29        Self {
30            start_lr: 1e-7,
31            end_lr: 10.0,
32            num_steps: 100,
33            log_scale: true,
34            smoothing: 0.05,
35        }
36    }
37}
38
39/// Run an LR range test.
40///
41/// This generates a schedule of learning rates and calls `compute_loss` for each
42/// one. The user-supplied closure should set the optimiser LR and run one training
43/// step, returning the resulting loss.
44///
45/// The function applies exponential-moving-average smoothing, then finds the LR
46/// at which the smoothed loss decreased most steeply.
47pub fn lr_range_test<F>(config: &LrFinderConfig, mut compute_loss: F) -> LrFinderResult
48where
49    F: FnMut(f32) -> f32,
50{
51    assert!(config.num_steps >= 2, "num_steps must be at least 2");
52
53    // Generate LR values.
54    let lr_values: Vec<f32> = (0..config.num_steps)
55        .map(|i| {
56            let t = i as f32 / (config.num_steps - 1) as f32;
57            if config.log_scale {
58                let log_start = config.start_lr.ln();
59                let log_end = config.end_lr.ln();
60                (log_start + t * (log_end - log_start)).exp()
61            } else {
62                config.start_lr + t * (config.end_lr - config.start_lr)
63            }
64        })
65        .collect();
66
67    // Collect raw losses.
68    let raw_losses: Vec<f32> = lr_values.iter().map(|&lr| compute_loss(lr)).collect();
69
70    // Apply exponential moving average smoothing.
71    let beta = config.smoothing;
72    let mut smoothed = Vec::with_capacity(config.num_steps);
73    smoothed.push(raw_losses[0]);
74    for i in 1..config.num_steps {
75        let s = beta * raw_losses[i] + (1.0 - beta) * smoothed[i - 1];
76        smoothed.push(s);
77    }
78
79    // Find steepest loss decrease (max negative derivative).
80    let mut best_idx = 0usize;
81    let mut best_drop = f32::INFINITY; // most negative = best
82    for i in 1..config.num_steps {
83        let drop = smoothed[i] - smoothed[i - 1];
84        if drop < best_drop {
85            best_drop = drop;
86            best_idx = i;
87        }
88    }
89
90    LrFinderResult {
91        suggested_lr: lr_values[best_idx],
92        lr_values,
93        loss_values: smoothed,
94    }
95}