scirs2_optimize/stochastic/
schedules.rs1use std::f64::consts::PI;
8
/// A learning-rate schedule.
///
/// Given an epoch (or step) index and a base learning rate, an implementation
/// returns the learning rate to use at that index.
///
/// The `Send + Sync` bound allows a schedule to be shared between threads.
pub trait LrSchedule
where
    Self: Send + Sync,
{
    /// Returns the learning rate for `epoch`, computed from `base_lr`.
    ///
    /// Implementations may ignore `base_lr`. For example, [`CyclicLr`] uses its
    /// own `base_lr`/`max_lr` fields instead.
    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64;
}
24
25#[derive(Debug, Clone)]
31pub struct StepDecay {
32 pub step_size: usize,
34 pub gamma: f64,
36}
37
38impl StepDecay {
39 pub fn new(step_size: usize, gamma: f64) -> Self {
45 Self { step_size, gamma }
46 }
47}
48
49impl LrSchedule for StepDecay {
50 fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
51 let steps = epoch / self.step_size.max(1);
52 base_lr * self.gamma.powi(steps as i32)
53 }
54}
55
56#[derive(Debug, Clone)]
66pub struct CosineAnnealing {
67 pub t_max: usize,
69 pub eta_min: f64,
71}
72
73impl CosineAnnealing {
74 pub fn new(t_max: usize, eta_min: f64) -> Self {
80 Self { t_max, eta_min }
81 }
82}
83
84impl LrSchedule for CosineAnnealing {
85 fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
86 let t_max = self.t_max.max(1) as f64;
87 let cos_val = (PI * epoch as f64 / t_max).cos();
88 self.eta_min + 0.5 * (base_lr - self.eta_min) * (1.0 + cos_val)
89 }
90}
91
/// Interpolation shape used by [`OneCycle`] inside each of its two phases.
///
/// Fieldless and therefore `Eq + Hash`, so it can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnnealStrategy {
    /// Cosine interpolation: `end + (start - end) * (1 + cos(pi * p)) / 2`.
    Cos,
    /// Linear interpolation: `start + (end - start) * p`.
    Linear,
}
102
103#[derive(Debug, Clone)]
112pub struct OneCycle {
113 pub max_lr: f64,
115 pub pct_start: f64,
117 pub anneal_strategy: AnnealStrategy,
119 pub total_epochs: usize,
121 pub div_factor: f64,
123 pub final_div_factor: f64,
125}
126
127impl OneCycle {
128 pub fn new(
136 max_lr: f64,
137 pct_start: f64,
138 anneal_strategy: AnnealStrategy,
139 total_epochs: usize,
140 ) -> Self {
141 Self {
142 max_lr,
143 pct_start: pct_start.clamp(0.0, 1.0),
144 anneal_strategy,
145 total_epochs,
146 div_factor: 25.0,
147 final_div_factor: 1e4,
148 }
149 }
150
151 fn anneal(&self, start: f64, end: f64, pct: f64) -> f64 {
153 let p = pct.clamp(0.0, 1.0);
154 match self.anneal_strategy {
155 AnnealStrategy::Cos => end + (start - end) / 2.0 * (1.0 + (PI * p).cos()),
156 AnnealStrategy::Linear => start + (end - start) * p,
157 }
158 }
159}
160
161impl LrSchedule for OneCycle {
162 fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
163 let total = self.total_epochs.max(1) as f64;
164 let pct = epoch as f64 / total;
165 let init_lr = base_lr / self.div_factor;
166 let final_lr = init_lr / self.final_div_factor;
167
168 if pct <= self.pct_start {
169 let phase_pct = if self.pct_start > 0.0 {
171 pct / self.pct_start
172 } else {
173 1.0
174 };
175 self.anneal(init_lr, self.max_lr, phase_pct)
176 } else {
177 let phase_pct = (pct - self.pct_start) / (1.0 - self.pct_start).max(1e-9);
179 self.anneal(self.max_lr, final_lr, phase_pct)
180 }
181 }
182}
183
184#[derive(Debug, Clone)]
194pub struct WarmupCosine {
195 pub warmup_steps: usize,
197 pub total_steps: usize,
199 pub min_lr: f64,
201}
202
203impl WarmupCosine {
204 pub fn new(warmup_steps: usize, total_steps: usize, min_lr: f64) -> Self {
211 Self {
212 warmup_steps,
213 total_steps,
214 min_lr,
215 }
216 }
217}
218
219impl LrSchedule for WarmupCosine {
220 fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
221 if epoch < self.warmup_steps {
222 let warmup = self.warmup_steps.max(1) as f64;
224 base_lr * epoch as f64 / warmup
225 } else {
226 let decay_steps = (self.total_steps.saturating_sub(self.warmup_steps)).max(1) as f64;
228 let step = (epoch - self.warmup_steps) as f64;
229 let cos_val = (PI * step / decay_steps).cos();
230 self.min_lr + 0.5 * (base_lr - self.min_lr) * (1.0 + cos_val)
231 }
232 }
233}
234
235#[derive(Debug, Clone)]
241pub struct ExponentialDecay {
242 pub gamma: f64,
244}
245
246impl ExponentialDecay {
247 pub fn new(gamma: f64) -> Self {
249 Self { gamma }
250 }
251}
252
253impl LrSchedule for ExponentialDecay {
254 fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
255 base_lr * self.gamma.powi(epoch as i32)
256 }
257}
258
259#[derive(Debug, Clone, Default)]
263pub struct ConstantLr;
264
265impl LrSchedule for ConstantLr {
266 fn get_lr(&self, _epoch: usize, base_lr: f64) -> f64 {
267 base_lr
268 }
269}
270
271#[derive(Debug, Clone)]
277pub struct PolynomialDecay {
278 pub total_epochs: usize,
280 pub power: f64,
282 pub end_lr: f64,
284}
285
286impl PolynomialDecay {
287 pub fn new(total_epochs: usize, power: f64, end_lr: f64) -> Self {
289 Self {
290 total_epochs,
291 power,
292 end_lr,
293 }
294 }
295}
296
297impl LrSchedule for PolynomialDecay {
298 fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
299 let total = self.total_epochs.max(1);
300 if epoch >= total {
301 return self.end_lr;
302 }
303 let decay = (1.0 - epoch as f64 / total as f64).powf(self.power);
304 let lr = (base_lr - self.end_lr) * decay + self.end_lr;
305 lr.max(self.end_lr)
306 }
307}
308
309#[derive(Debug, Clone)]
316pub struct CyclicLr {
317 pub base_lr: f64,
319 pub max_lr: f64,
321 pub step_size: usize,
323}
324
325impl CyclicLr {
326 pub fn new(base_lr: f64, max_lr: f64, step_size: usize) -> Self {
328 Self {
329 base_lr,
330 max_lr,
331 step_size: step_size.max(1),
332 }
333 }
334}
335
336impl LrSchedule for CyclicLr {
337 fn get_lr(&self, epoch: usize, _base_lr: f64) -> f64 {
338 let cycle = epoch / (2 * self.step_size);
339 let x = (epoch as f64 / self.step_size as f64) - 2.0 * cycle as f64 - 1.0;
340 let scale = (1.0 - x.abs()).max(0.0);
341 self.base_lr + (self.max_lr - self.base_lr) * scale
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_step_decay() {
        let sched = StepDecay::new(10, 0.5);
        assert_abs_diff_eq!(sched.get_lr(0, 0.1), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(9, 0.1), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 0.1), 0.05, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(20, 0.1), 0.025, epsilon = 1e-12);
    }

    #[test]
    fn test_cosine_annealing() {
        let sched = CosineAnnealing::new(100, 0.0);
        let lr_start = sched.get_lr(0, 1.0);
        let lr_mid = sched.get_lr(50, 1.0);
        let lr_end = sched.get_lr(100, 1.0);
        assert_abs_diff_eq!(lr_start, 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(lr_mid, 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(lr_end, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_one_cycle_warmup_peak() {
        let sched = OneCycle::new(0.1, 0.3, AnnealStrategy::Cos, 100);
        // Epoch 0 starts at base_lr / div_factor.
        let lr_start = sched.get_lr(0, 0.01);
        // Epoch 30 is exactly the end of the rising phase (pct_start = 0.3).
        let lr_peak = sched.get_lr(30, 0.01);
        assert!(lr_peak >= lr_start, "peak must exceed start");
        assert_abs_diff_eq!(lr_peak, sched.max_lr, epsilon = 1e-10);
    }

    #[test]
    fn test_one_cycle_linear_reaches_final_lr() {
        // base_lr = 1.0, so init_lr = 1 / 25 = 0.04 and final_lr = 0.04 / 1e4 = 4e-6.
        let sched = OneCycle::new(1.0, 0.5, AnnealStrategy::Linear, 10);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 0.04, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(2, 1.0), 0.424, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(5, 1.0), 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 1.0), 4e-6, epsilon = 1e-12);
        // Past the end of the cycle the final rate is held.
        assert_abs_diff_eq!(sched.get_lr(50, 1.0), 4e-6, epsilon = 1e-12);
    }

    #[test]
    fn test_warmup_cosine() {
        let sched = WarmupCosine::new(10, 100, 0.0);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 0.0, epsilon = 1e-12);
        // Halfway through the linear warm-up.
        assert_abs_diff_eq!(sched.get_lr(5, 1.0), 0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 1.0), 1.0, epsilon = 1e-12);
        // Epoch 55 is halfway through the 90-step cosine decay.
        let lr_after = sched.get_lr(55, 1.0);
        assert!(lr_after < 1.0, "should decay after warmup");
        assert!(lr_after >= 0.0, "should not go below min_lr");
    }

    #[test]
    fn test_exponential_decay() {
        let sched = ExponentialDecay::new(0.9);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(1, 1.0), 0.9, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(2, 1.0), 0.81, epsilon = 1e-12);
    }

    #[test]
    fn test_constant_lr() {
        let sched = ConstantLr;
        for epoch in 0..100 {
            assert_abs_diff_eq!(sched.get_lr(epoch, 0.01), 0.01, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_polynomial_decay() {
        let sched = PolynomialDecay::new(10, 2.0, 0.0);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 1.0, epsilon = 1e-12);
        // (1 - 5/10)^2 = 0.25
        assert_abs_diff_eq!(sched.get_lr(5, 1.0), 0.25, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 1.0), 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(15, 1.0), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_cyclic_lr() {
        let sched = CyclicLr::new(0.001, 0.01, 5);
        // Each cycle starts at base_lr.
        let lr0 = sched.get_lr(0, 0.0);
        // After one half-cycle (step_size epochs) the rate peaks at max_lr.
        let lr5 = sched.get_lr(5, 0.0);
        assert_abs_diff_eq!(lr5, sched.max_lr, epsilon = 1e-10);
        // A full cycle later it is back at base_lr.
        let lr10 = sched.get_lr(10, 0.0);
        assert_abs_diff_eq!(lr10, sched.base_lr, epsilon = 1e-10);
        assert!(lr5 > lr0, "peak should exceed start");
    }

    #[test]
    fn test_schedules_as_trait_objects() {
        // Every schedule must be usable behind `dyn LrSchedule`.
        let schedules: Vec<Box<dyn LrSchedule>> = vec![
            Box::new(ConstantLr),
            Box::new(StepDecay::new(1, 0.5)),
            Box::new(ExponentialDecay::new(0.5)),
        ];
        let lrs: Vec<f64> = schedules.iter().map(|s| s.get_lr(1, 1.0)).collect();
        assert_abs_diff_eq!(lrs[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(lrs[1], 0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(lrs[2], 0.5, epsilon = 1e-12);
    }
}