/// Outcome of [`lr_range_test`]: the swept learning rates, the smoothed loss
/// observed at each of them, and the suggested learning rate.
#[derive(Debug, Clone, PartialEq)]
pub struct LrFinderResult {
    /// Learning rates that were evaluated, in sweep order.
    pub lr_values: Vec<f32>,
    /// Exponentially smoothed loss for each entry of `lr_values` (same length and order).
    pub loss_values: Vec<f32>,
    /// Learning rate at which the smoothed loss fell most steeply between two
    /// consecutive steps of the sweep.
    pub suggested_lr: f32,
}
11
/// Parameters of a learning-rate range test (see [`lr_range_test`]).
#[derive(Debug, Clone, PartialEq)]
pub struct LrFinderConfig {
    /// First learning rate of the sweep. With `log_scale`, this must be
    /// strictly positive because the sweep interpolates its natural logarithm.
    pub start_lr: f32,
    /// Last learning rate of the sweep (inclusive). With `log_scale`, this must
    /// be strictly positive as well.
    pub end_lr: f32,
    /// Number of learning rates evaluated, both endpoints included (at least 2).
    pub num_steps: usize,
    /// `true` spaces the learning rates geometrically (uniformly in log space);
    /// `false` spaces them linearly.
    pub log_scale: bool,
    /// Weight given to each *new* loss in the exponential moving average:
    /// `1.0` disables smoothing, values close to `0.0` smooth heavily.
    pub smoothing: f32,
}
26
27impl Default for LrFinderConfig {
28 fn default() -> Self {
29 Self {
30 start_lr: 1e-7,
31 end_lr: 10.0,
32 num_steps: 100,
33 log_scale: true,
34 smoothing: 0.05,
35 }
36 }
37}
38
39pub fn lr_range_test<F>(config: &LrFinderConfig, mut compute_loss: F) -> LrFinderResult
48where
49 F: FnMut(f32) -> f32,
50{
51 assert!(config.num_steps >= 2, "num_steps must be at least 2");
52
53 let lr_values: Vec<f32> = (0..config.num_steps)
55 .map(|i| {
56 let t = i as f32 / (config.num_steps - 1) as f32;
57 if config.log_scale {
58 let log_start = config.start_lr.ln();
59 let log_end = config.end_lr.ln();
60 (log_start + t * (log_end - log_start)).exp()
61 } else {
62 config.start_lr + t * (config.end_lr - config.start_lr)
63 }
64 })
65 .collect();
66
67 let raw_losses: Vec<f32> = lr_values.iter().map(|&lr| compute_loss(lr)).collect();
69
70 let beta = config.smoothing;
72 let mut smoothed = Vec::with_capacity(config.num_steps);
73 smoothed.push(raw_losses[0]);
74 for i in 1..config.num_steps {
75 let s = beta * raw_losses[i] + (1.0 - beta) * smoothed[i - 1];
76 smoothed.push(s);
77 }
78
79 let mut best_idx = 0usize;
81 let mut best_drop = f32::INFINITY; for i in 1..config.num_steps {
83 let drop = smoothed[i] - smoothed[i - 1];
84 if drop < best_drop {
85 best_drop = drop;
86 best_idx = i;
87 }
88 }
89
90 LrFinderResult {
91 suggested_lr: lr_values[best_idx],
92 lr_values,
93 loss_values: smoothed,
94 }
95}