Skip to main content

scirs2_optimize/stochastic/
optimizers.rs

1//! Stateful first-order stochastic optimizers with learning rate schedules.
2//!
3//! Provides clean, struct-based implementations of:
4//! - **SGD** — vanilla, momentum, Nesterov momentum, weight-decay
5//! - **Adam** — adaptive moment estimation
6//! - **AdaGrad** — cumulative gradient-square scaling
7//! - **RMSprop** — exponential moving-average second moment
8//! - **AdamW** — Adam with decoupled weight decay
9//! - **SVRG** — Stochastic Variance Reduced Gradient
10//! - **LrSchedule** — a rich enum of learning rate schedules
11//!
12//! # Usage pattern
13//!
14//! ```rust
15//! use scirs2_optimize::stochastic::optimizers::{Sgd, LrSchedule};
16//!
17//! let mut opt = Sgd::new(0.01, 0.9);
18//! let mut params = vec![1.0_f64, -2.0];
19//! for _ in 0..100 {
20//!     let grad = params.iter().map(|&p| 2.0 * p).collect::<Vec<_>>();
21//!     opt.step(&mut params, &grad).expect("valid input");
22//! }
23//! ```
24
25use crate::error::{OptimizeError, OptimizeResult};
26use std::f64::consts::PI;
27
28// ─────────────────────────────────────────────────────────────────────────────
29// Learning rate schedule
30// ─────────────────────────────────────────────────────────────────────────────
31
/// Learning rate schedule variants.
///
/// Every variant is a pure function of the step index: [`LrSchedule::lr_at`]
/// keeps no mutable state and returns the learning rate to use at that step.
#[derive(Debug, Clone)]
pub enum LrSchedule {
    /// Constant learning rate.
    Constant(f64),

    /// Exponential decay: `lr = initial * decay^step`.
    ExponentialDecay {
        /// Initial learning rate (the value at step 0).
        initial: f64,
        /// Per-step multiplicative decay factor (0 < decay < 1).
        decay: f64,
    },

    /// Cosine annealing: lr oscillates between `lr_max` and `lr_min` with a
    /// full period of `2 * t_max` steps.
    CosineAnnealing {
        /// Maximum (initial) learning rate.
        lr_max: f64,
        /// Minimum learning rate, reached at step `t_max`.
        lr_min: f64,
        /// Half-period in steps; `0` is treated as `1`.
        t_max: usize,
    },

    /// Linear warmup followed by a cosine decay that holds at `lr_min` once
    /// `total_steps` has been reached.
    WarmupCosine {
        /// Number of linear warmup steps (lr 0 → lr_peak).
        warmup_steps: usize,
        /// Peak learning rate reached after warmup.
        lr_peak: f64,
        /// Minimum learning rate at (and after) the end.
        lr_min: f64,
        /// Total number of steps (warmup + cosine phase).
        total_steps: usize,
    },

    /// Step decay: lr is multiplied by `gamma` every `step_size` steps.
    StepLr {
        /// Initial learning rate.
        initial: f64,
        /// Number of steps between reductions; `0` is treated as `1`.
        step_size: usize,
        /// Multiplicative reduction factor applied at each reduction.
        gamma: f64,
    },
}

impl LrSchedule {
    /// Convert a step count into a `powi` exponent.
    ///
    /// Counts above `i32::MAX` saturate instead of wrapping: a plain `as i32`
    /// cast can turn a large step into a negative exponent, which would make
    /// a decaying schedule blow up.
    fn exponent(count: usize) -> i32 {
        i32::try_from(count).unwrap_or(i32::MAX)
    }

    /// Return the effective learning rate at the given step index.
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_optimize::stochastic::optimizers::LrSchedule;
    /// let sched = LrSchedule::Constant(0.01);
    /// assert!((sched.lr_at(42) - 0.01).abs() < 1e-14);
    /// ```
    pub fn lr_at(&self, step: usize) -> f64 {
        match *self {
            LrSchedule::Constant(lr) => lr,

            LrSchedule::ExponentialDecay { initial, decay } => {
                initial * decay.powi(Self::exponent(step))
            }

            LrSchedule::CosineAnnealing {
                lr_max,
                lr_min,
                t_max,
            } => {
                // The same guarded half-period feeds both the modulus and the
                // divisor, so `t_max == 0` can no longer produce 0/0 = NaN.
                let half_period = t_max.max(1);
                let phase = step % half_period.saturating_mul(2);
                let angle = PI * phase as f64 / half_period as f64;
                lr_min + 0.5 * (lr_max - lr_min) * (1.0 + angle.cos())
            }

            LrSchedule::WarmupCosine {
                warmup_steps,
                lr_peak,
                lr_min,
                total_steps,
            } => {
                if step < warmup_steps {
                    // Linear warmup; `warmup_steps > step >= 0`, so the
                    // divisor is at least 1.
                    lr_peak * step as f64 / warmup_steps as f64
                } else {
                    // Cosine decay from lr_peak to lr_min over the remaining
                    // steps (at least one, even if total_steps <= warmup_steps).
                    let decay_span = total_steps.saturating_sub(warmup_steps).max(1);
                    // Clamp so the rate stays at lr_min past `total_steps`
                    // rather than swinging back up the cosine.
                    let progress =
                        ((step - warmup_steps) as f64 / decay_span as f64).min(1.0);
                    lr_min + 0.5 * (lr_peak - lr_min) * (1.0 + (PI * progress).cos())
                }
            }

            LrSchedule::StepLr {
                initial,
                step_size,
                gamma,
            } => {
                let n_decays = step / step_size.max(1);
                initial * gamma.powi(Self::exponent(n_decays))
            }
        }
    }
}
138
139// ─────────────────────────────────────────────────────────────────────────────
140// SGD
141// ─────────────────────────────────────────────────────────────────────────────
142
143/// Stochastic Gradient Descent with optional momentum, Nesterov lookahead,
144/// and L2 weight decay.
145///
146/// Update rule (with momentum):
147///   v ← momentum · v + lr · (grad + weight_decay · params)
148///   params ← params - v        (vanilla momentum)
149///   params ← params - (momentum · v + lr · grad)   (Nesterov)
150#[derive(Debug, Clone)]
151pub struct Sgd {
152    /// Base learning rate.
153    pub learning_rate: f64,
154    /// Momentum coefficient (0 = vanilla SGD).
155    pub momentum: f64,
156    /// L2 weight-decay coefficient.
157    pub weight_decay: f64,
158    /// Use Nesterov momentum instead of classical momentum.
159    pub nesterov: bool,
160    /// Velocity buffer (initialised lazily).
161    velocity: Vec<f64>,
162}
163
164impl Sgd {
165    /// Create a new SGD optimizer.
166    pub fn new(learning_rate: f64, momentum: f64) -> Self {
167        Self {
168            learning_rate,
169            momentum,
170            weight_decay: 0.0,
171            nesterov: false,
172            velocity: Vec::new(),
173        }
174    }
175
176    /// Perform one SGD update step.
177    ///
178    /// # Errors
179    /// Returns [`OptimizeError::InvalidInput`] if `params` and `grad` have
180    /// different lengths.
181    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
182        if params.len() != grad.len() {
183            return Err(OptimizeError::InvalidInput(format!(
184                "params length {} != grad length {}",
185                params.len(),
186                grad.len()
187            )));
188        }
189
190        let n = params.len();
191        if self.velocity.len() != n {
192            self.velocity = vec![0.0; n];
193        }
194
195        let lr = self.learning_rate;
196        let mu = self.momentum;
197        let wd = self.weight_decay;
198
199        if self.nesterov {
200            for i in 0..n {
201                let g = grad[i] + wd * params[i];
202                self.velocity[i] = mu * self.velocity[i] + g;
203                params[i] -= lr * (mu * self.velocity[i] + g);
204            }
205        } else {
206            for i in 0..n {
207                let g = grad[i] + wd * params[i];
208                self.velocity[i] = mu * self.velocity[i] + g;
209                params[i] -= lr * self.velocity[i];
210            }
211        }
212        Ok(())
213    }
214
215    /// Reset velocity buffer to zeros (for new training run).
216    pub fn zero_velocity(&mut self, n: usize) {
217        self.velocity = vec![0.0; n];
218    }
219}
220
221// ─────────────────────────────────────────────────────────────────────────────
222// Adam
223// ─────────────────────────────────────────────────────────────────────────────
224
225/// Adam optimizer (Kingma & Ba 2015).
226///
227/// Update rule:
228///   m ← β₁ m + (1-β₁) g
229///   v ← β₂ v + (1-β₂) g²
230///   m̂ = m / (1-β₁ᵗ),  v̂ = v / (1-β₂ᵗ)
231///   params ← params - lr · m̂ / (√v̂ + ε)
232#[derive(Debug, Clone)]
233pub struct Adam {
234    /// Learning rate.
235    pub lr: f64,
236    /// First moment decay (default 0.9).
237    pub beta1: f64,
238    /// Second moment decay (default 0.999).
239    pub beta2: f64,
240    /// Numerical stability constant (default 1e-8).
241    pub eps: f64,
242    /// L2 weight-decay (added to gradient, not decoupled).
243    pub weight_decay: f64,
244    /// First moment estimate.
245    m: Vec<f64>,
246    /// Second moment estimate.
247    v: Vec<f64>,
248    /// Step counter (1-indexed).
249    t: usize,
250}
251
252impl Adam {
253    /// Create a new Adam optimizer with default β₁=0.9, β₂=0.999, ε=1e-8.
254    pub fn new(lr: f64) -> Self {
255        Self {
256            lr,
257            beta1: 0.9,
258            beta2: 0.999,
259            eps: 1e-8,
260            weight_decay: 0.0,
261            m: Vec::new(),
262            v: Vec::new(),
263            t: 0,
264        }
265    }
266
267    /// Perform one Adam update step.
268    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
269        if params.len() != grad.len() {
270            return Err(OptimizeError::InvalidInput(format!(
271                "params length {} != grad length {}",
272                params.len(),
273                grad.len()
274            )));
275        }
276
277        let n = params.len();
278        if self.m.len() != n {
279            self.m = vec![0.0; n];
280            self.v = vec![0.0; n];
281        }
282
283        self.t += 1;
284        let t = self.t as f64;
285        let bias_corr1 = 1.0 - self.beta1.powf(t);
286        let bias_corr2 = 1.0 - self.beta2.powf(t);
287
288        for i in 0..n {
289            let g = grad[i] + self.weight_decay * params[i];
290            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
291            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
292            let m_hat = self.m[i] / bias_corr1;
293            let v_hat = self.v[i] / bias_corr2;
294            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
295        }
296        Ok(())
297    }
298
299    /// Reset all state (moments and step counter).
300    pub fn reset_state(&mut self) {
301        self.m.clear();
302        self.v.clear();
303        self.t = 0;
304    }
305}
306
307// ─────────────────────────────────────────────────────────────────────────────
308// AdaGrad
309// ─────────────────────────────────────────────────────────────────────────────
310
311/// AdaGrad optimizer (Duchi et al. 2011).
312///
313/// Accumulates the sum of squared gradients per parameter and scales the
314/// learning rate accordingly:
315///   G ← G + g²
316///   params ← params - lr · g / (√G + ε)
317#[derive(Debug, Clone)]
318pub struct AdaGrad {
319    /// Base learning rate.
320    pub lr: f64,
321    /// Numerical stability constant.
322    pub eps: f64,
323    /// L2 weight-decay (coupled).
324    pub weight_decay: f64,
325    /// Accumulated squared-gradient sum.
326    sum_sq_grad: Vec<f64>,
327}
328
329impl AdaGrad {
330    /// Create a new AdaGrad optimizer.
331    pub fn new(lr: f64) -> Self {
332        Self {
333            lr,
334            eps: 1e-8,
335            weight_decay: 0.0,
336            sum_sq_grad: Vec::new(),
337        }
338    }
339
340    /// Perform one AdaGrad update step.
341    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
342        if params.len() != grad.len() {
343            return Err(OptimizeError::InvalidInput(format!(
344                "params/grad length mismatch: {} vs {}",
345                params.len(),
346                grad.len()
347            )));
348        }
349
350        let n = params.len();
351        if self.sum_sq_grad.len() != n {
352            self.sum_sq_grad = vec![0.0; n];
353        }
354
355        for i in 0..n {
356            let g = grad[i] + self.weight_decay * params[i];
357            self.sum_sq_grad[i] += g * g;
358            params[i] -= self.lr * g / (self.sum_sq_grad[i].sqrt() + self.eps);
359        }
360        Ok(())
361    }
362
363    /// Reset accumulated state.
364    pub fn reset_state(&mut self) {
365        self.sum_sq_grad.clear();
366    }
367}
368
369// ─────────────────────────────────────────────────────────────────────────────
370// RMSprop
371// ─────────────────────────────────────────────────────────────────────────────
372
373/// RMSprop optimizer (Hinton 2012).
374///
375/// Maintains an exponential moving average of squared gradients:
376///   E\[g²\] ← α·E\[g²\] + (1-α)·g²
377///   params ← params - lr · g / (√E\[g²\] + ε)
378///
379/// With momentum > 0:
380///   vel ← momentum·vel + lr · g / (√E\[g²\] + ε)
381///   params ← params - vel
382#[derive(Debug, Clone)]
383pub struct RmsProp {
384    /// Base learning rate.
385    pub lr: f64,
386    /// Decay factor for the moving average (default 0.99).
387    pub alpha: f64,
388    /// Numerical stability constant.
389    pub eps: f64,
390    /// Momentum coefficient (0 = no momentum).
391    pub momentum: f64,
392    /// Moving-average of squared gradients.
393    sq_avg: Vec<f64>,
394    /// Momentum buffer.
395    velocity: Vec<f64>,
396}
397
398impl RmsProp {
399    /// Create a new RMSprop optimizer with default α=0.99.
400    pub fn new(lr: f64) -> Self {
401        Self {
402            lr,
403            alpha: 0.99,
404            eps: 1e-8,
405            momentum: 0.0,
406            sq_avg: Vec::new(),
407            velocity: Vec::new(),
408        }
409    }
410
411    /// Perform one RMSprop update step.
412    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
413        if params.len() != grad.len() {
414            return Err(OptimizeError::InvalidInput(format!(
415                "params/grad length mismatch: {} vs {}",
416                params.len(),
417                grad.len()
418            )));
419        }
420
421        let n = params.len();
422        if self.sq_avg.len() != n {
423            self.sq_avg = vec![0.0; n];
424            self.velocity = vec![0.0; n];
425        }
426
427        for i in 0..n {
428            let g = grad[i];
429            self.sq_avg[i] = self.alpha * self.sq_avg[i] + (1.0 - self.alpha) * g * g;
430            let denom = self.sq_avg[i].sqrt() + self.eps;
431            if self.momentum > 0.0 {
432                self.velocity[i] = self.momentum * self.velocity[i] + self.lr * g / denom;
433                params[i] -= self.velocity[i];
434            } else {
435                params[i] -= self.lr * g / denom;
436            }
437        }
438        Ok(())
439    }
440
441    /// Reset accumulated state.
442    pub fn reset_state(&mut self) {
443        self.sq_avg.clear();
444        self.velocity.clear();
445    }
446}
447
448// ─────────────────────────────────────────────────────────────────────────────
449// AdamW
450// ─────────────────────────────────────────────────────────────────────────────
451
452/// AdamW optimizer (Loshchilov & Hutter 2017) — Adam with decoupled weight decay.
453///
454/// Unlike Adam (which couples weight decay into the gradient), AdamW applies
455/// weight decay directly to the parameters before the gradient update:
456///   params ← params - lr · weight_decay · params   (L2 shrinkage)
457///   then Adam update on the raw gradient
458#[derive(Debug, Clone)]
459pub struct AdamW {
460    /// Learning rate.
461    pub lr: f64,
462    /// First moment decay.
463    pub beta1: f64,
464    /// Second moment decay.
465    pub beta2: f64,
466    /// Numerical stability constant.
467    pub eps: f64,
468    /// Decoupled L2 weight-decay coefficient.
469    pub weight_decay: f64,
470    /// First moment.
471    m: Vec<f64>,
472    /// Second moment.
473    v: Vec<f64>,
474    /// Step counter.
475    t: usize,
476}
477
478impl AdamW {
479    /// Create a new AdamW optimizer with default β₁=0.9, β₂=0.999, ε=1e-8,
480    /// weight_decay=0.01.
481    pub fn new(lr: f64) -> Self {
482        Self {
483            lr,
484            beta1: 0.9,
485            beta2: 0.999,
486            eps: 1e-8,
487            weight_decay: 0.01,
488            m: Vec::new(),
489            v: Vec::new(),
490            t: 0,
491        }
492    }
493
494    /// Perform one AdamW update step.
495    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
496        if params.len() != grad.len() {
497            return Err(OptimizeError::InvalidInput(format!(
498                "params/grad length mismatch: {} vs {}",
499                params.len(),
500                grad.len()
501            )));
502        }
503
504        let n = params.len();
505        if self.m.len() != n {
506            self.m = vec![0.0; n];
507            self.v = vec![0.0; n];
508        }
509
510        self.t += 1;
511        let t = self.t as f64;
512        let bc1 = 1.0 - self.beta1.powf(t);
513        let bc2 = 1.0 - self.beta2.powf(t);
514
515        for i in 0..n {
516            // Decoupled weight decay: shrink parameters first
517            params[i] *= 1.0 - self.lr * self.weight_decay;
518
519            let g = grad[i];
520            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
521            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
522            let m_hat = self.m[i] / bc1;
523            let v_hat = self.v[i] / bc2;
524            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
525        }
526        Ok(())
527    }
528
529    /// Reset all state.
530    pub fn reset_state(&mut self) {
531        self.m.clear();
532        self.v.clear();
533        self.t = 0;
534    }
535}
536
537// ─────────────────────────────────────────────────────────────────────────────
538// SVRG
539// ─────────────────────────────────────────────────────────────────────────────
540
541/// Stochastic Variance Reduced Gradient (SVRG; Johnson & Zhang 2013).
542///
543/// SVRG maintains a snapshot of parameters and the corresponding full gradient,
544/// and uses a control variate to reduce variance of the stochastic gradient
545/// estimate:
546///
547///   g̃ = g_i(x) - g_i(x̃) + ∇f(x̃)
548///   x ← x - lr · g̃
549///
550/// The snapshot (x̃, ∇f(x̃)) must be updated every `update_freq` inner steps
551/// by calling [`Svrg::update_snapshot`].
552#[derive(Debug, Clone)]
553pub struct Svrg {
554    /// Learning rate.
555    pub lr: f64,
556    /// Dataset size (used for normalization in documentation, not internally).
557    pub n: usize,
558    /// Number of inner steps between snapshot updates.
559    pub update_freq: usize,
560    /// Snapshot of parameters x̃.
561    snapshot_params: Vec<f64>,
562    /// Full gradient at snapshot ∇f(x̃).
563    snapshot_grad: Vec<f64>,
564    /// Inner iteration counter.
565    inner_t: usize,
566}
567
568impl Svrg {
569    /// Create a new SVRG optimizer.
570    ///
571    /// # Arguments
572    /// * `lr`          – Learning rate.
573    /// * `n`           – Dataset size.
574    /// * `update_freq` – Inner-loop length (snapshot updated every this many steps).
575    pub fn new(lr: f64, n: usize, update_freq: usize) -> Self {
576        Self {
577            lr,
578            n,
579            update_freq,
580            snapshot_params: Vec::new(),
581            snapshot_grad: Vec::new(),
582            inner_t: 0,
583        }
584    }
585
586    /// Perform one SVRG inner-loop update.
587    ///
588    /// # Arguments
589    /// * `params`           – Current parameters (modified in place).
590    /// * `stochastic_grad`  – Mini-batch gradient at current `params`.
591    /// * `snapshot_grad_i`  – Mini-batch gradient at the snapshot params (same mini-batch).
592    ///
593    /// # Errors
594    /// Returns [`OptimizeError::InvalidInput`] on length mismatches, or if
595    /// `update_snapshot` has not been called first.
596    pub fn step(
597        &mut self,
598        params: &mut Vec<f64>,
599        stochastic_grad: &[f64],
600        snapshot_grad_i: &[f64],
601    ) -> OptimizeResult<()> {
602        let n = params.len();
603
604        if stochastic_grad.len() != n || snapshot_grad_i.len() != n {
605            return Err(OptimizeError::InvalidInput(format!(
606                "SVRG gradient/param length mismatch: params={}, sg={}, sgi={}",
607                n,
608                stochastic_grad.len(),
609                snapshot_grad_i.len()
610            )));
611        }
612
613        if self.snapshot_grad.len() != n {
614            return Err(OptimizeError::InvalidInput(
615                "SVRG: snapshot not initialised — call update_snapshot first".to_string(),
616            ));
617        }
618
619        // Variance-reduced gradient estimate
620        for i in 0..n {
621            let g_tilde = stochastic_grad[i] - snapshot_grad_i[i] + self.snapshot_grad[i];
622            params[i] -= self.lr * g_tilde;
623        }
624
625        self.inner_t += 1;
626        Ok(())
627    }
628
629    /// Update the snapshot with current parameters and full gradient.
630    ///
631    /// Should be called at the start of each outer epoch (every `update_freq`
632    /// inner steps).
633    pub fn update_snapshot(&mut self, params: &[f64], full_grad: &[f64]) {
634        self.snapshot_params = params.to_vec();
635        self.snapshot_grad = full_grad.to_vec();
636        self.inner_t = 0;
637    }
638
639    /// Whether the inner loop has completed and a snapshot update is due.
640    pub fn needs_snapshot_update(&self) -> bool {
641        self.inner_t >= self.update_freq
642    }
643
644    /// Current snapshot parameters.
645    pub fn snapshot_params(&self) -> &[f64] {
646        &self.snapshot_params
647    }
648}
649
650// ─────────────────────────────────────────────────────────────────────────────
651// Tests
652// ─────────────────────────────────────────────────────────────────────────────
653
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Gradient of f(p) = Σ pᵢ², whose unique minimum is the origin.
    fn quadratic_grad(params: &[f64]) -> Vec<f64> {
        params.iter().map(|&p| 2.0 * p).collect()
    }

    // ── LrSchedule ───────────────────────────────────────────────────────────

    #[test]
    fn test_constant_schedule() {
        let s = LrSchedule::Constant(0.01);
        assert_abs_diff_eq!(s.lr_at(0), 0.01, epsilon = 1e-14);
        assert_abs_diff_eq!(s.lr_at(1000), 0.01, epsilon = 1e-14);
    }

    #[test]
    fn test_exponential_decay_schedule() {
        let s = LrSchedule::ExponentialDecay {
            initial: 0.1,
            decay: 0.9,
        };
        assert_abs_diff_eq!(s.lr_at(0), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(s.lr_at(1), 0.09, epsilon = 1e-10);
        assert_abs_diff_eq!(s.lr_at(10), 0.1 * 0.9_f64.powi(10), epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_annealing_at_zero() {
        let s = LrSchedule::CosineAnnealing {
            lr_max: 0.1,
            lr_min: 0.0,
            t_max: 100,
        };
        // At step 0: cos(0) = 1 → lr = lr_min + 0.5*(lr_max-lr_min)*2 = lr_max
        assert_abs_diff_eq!(s.lr_at(0), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_annealing_at_t_max() {
        let s = LrSchedule::CosineAnnealing {
            lr_max: 0.1,
            lr_min: 0.001,
            t_max: 50,
        };
        // At step t_max: cos(π) = -1 → lr = lr_min
        assert_abs_diff_eq!(s.lr_at(50), 0.001, epsilon = 1e-10);
        // A full period (2 * t_max) later the schedule is back at lr_max.
        assert_abs_diff_eq!(s.lr_at(100), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_warmup_cosine_warmup_phase() {
        let s = LrSchedule::WarmupCosine {
            warmup_steps: 10,
            lr_peak: 0.1,
            lr_min: 0.0,
            total_steps: 110,
        };
        // Warmup starts from zero.
        assert_abs_diff_eq!(s.lr_at(0), 0.0, epsilon = 1e-14);
        // Step 5 of 10 warmup → lr = 0.1 * 5/10 = 0.05
        assert_abs_diff_eq!(s.lr_at(5), 0.05, epsilon = 1e-10);
        // First step of the cosine phase: progress 0 → cos(0) = 1 → lr_peak
        let lr10 = s.lr_at(10);
        assert!(
            lr10 >= 0.09 && lr10 <= 0.1 + 1e-9,
            "lr at warmup end ≈ peak, got {}",
            lr10
        );
        // Last scheduled step: progress 1 → cos(π) = -1 → lr_min
        assert_abs_diff_eq!(s.lr_at(110), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_step_lr_schedule() {
        let s = LrSchedule::StepLr {
            initial: 0.1,
            step_size: 10,
            gamma: 0.5,
        };
        assert_abs_diff_eq!(s.lr_at(0), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(s.lr_at(9), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(s.lr_at(10), 0.05, epsilon = 1e-12);
        assert_abs_diff_eq!(s.lr_at(20), 0.025, epsilon = 1e-12);
    }

    // ── SGD ──────────────────────────────────────────────────────────────────

    #[test]
    fn test_sgd_converges_quadratic() {
        let mut opt = Sgd::new(0.1, 0.0);
        let mut p = vec![1.0, -2.0];
        for _ in 0..200 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert_abs_diff_eq!(p[0], 0.0, epsilon = 1e-4);
        assert_abs_diff_eq!(p[1], 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_sgd_momentum_converges() {
        let mut opt = Sgd::new(0.05, 0.9);
        let mut p = vec![2.0, -1.5];
        for _ in 0..500 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert_abs_diff_eq!(p[0], 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(p[1], 0.0, epsilon = 1e-3);
    }

    #[test]
    fn test_sgd_nesterov() {
        let mut opt = Sgd {
            nesterov: true,
            ..Sgd::new(0.05, 0.9)
        };
        let mut p = vec![1.0, 1.0];
        for _ in 0..500 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert_abs_diff_eq!(p[0], 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(p[1], 0.0, epsilon = 1e-3);
    }

    #[test]
    fn test_sgd_weight_decay() {
        let mut opt = Sgd {
            weight_decay: 0.1,
            ..Sgd::new(0.01, 0.0)
        };
        let mut p = vec![1.0];
        opt.step(&mut p, &[0.0]).expect("step failed");
        // Weight decay pulls toward 0: p_new = p - lr * wd * p = p * (1 - lr*wd)
        assert!(p[0] < 1.0, "weight decay should shrink param");
        assert_abs_diff_eq!(p[0], 1.0 - 0.01 * 0.1, epsilon = 1e-12);
    }

    #[test]
    fn test_sgd_length_mismatch() {
        let mut opt = Sgd::new(0.01, 0.0);
        let mut p = vec![1.0, 2.0];
        assert!(opt.step(&mut p, &[0.1]).is_err());
        // A rejected step must leave the parameters untouched.
        assert_eq!(p, vec![1.0, 2.0]);
    }

    #[test]
    fn test_sgd_zero_velocity() {
        let mut opt = Sgd::new(0.01, 0.9);
        opt.zero_velocity(5);
        assert_eq!(opt.velocity.len(), 5);
        assert!(opt.velocity.iter().all(|&v| v == 0.0));
    }

    // ── Adam ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_adam_converges() {
        let mut opt = Adam::new(0.01);
        let mut p = vec![3.0, -3.0];
        for _ in 0..1000 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert_abs_diff_eq!(p[0], 0.0, epsilon = 1e-2);
        assert_abs_diff_eq!(p[1], 0.0, epsilon = 1e-2);
    }

    #[test]
    fn test_adam_reset_state() {
        let mut opt = Adam::new(0.01);
        let mut p = vec![1.0];
        opt.step(&mut p, &[0.5]).expect("step failed");
        assert_eq!(opt.t, 1);
        opt.reset_state();
        assert_eq!(opt.t, 0);
        assert!(opt.m.is_empty());
        assert!(opt.v.is_empty());
    }

    #[test]
    fn test_adam_weight_decay_coupled() {
        let mut opt = Adam {
            weight_decay: 0.01,
            ..Adam::new(0.001)
        };
        let mut p = vec![1.0];
        let p_before = p[0];
        opt.step(&mut p, &[0.0]).expect("step failed");
        // With pure weight decay (grad=0), param should shrink
        assert!(p[0] < p_before, "weight decay should reduce param");
    }

    #[test]
    fn test_adam_length_mismatch() {
        let mut opt = Adam::new(0.01);
        let mut p = vec![1.0, 2.0];
        assert!(opt.step(&mut p, &[0.1]).is_err());
    }

    // ── AdaGrad ──────────────────────────────────────────────────────────────

    #[test]
    fn test_adagrad_converges() {
        let mut opt = AdaGrad::new(0.5);
        let mut p = vec![3.0, -2.0];
        for _ in 0..2000 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert!(p[0].abs() < 0.5, "adagrad should converge, p[0]={}", p[0]);
        assert!(p[1].abs() < 0.5, "adagrad should converge, p[1]={}", p[1]);
    }

    #[test]
    fn test_adagrad_reset() {
        let mut opt = AdaGrad::new(0.1);
        let mut p = vec![1.0];
        opt.step(&mut p, &[1.0]).expect("step failed");
        assert_eq!(opt.sum_sq_grad.len(), 1);
        opt.reset_state();
        assert!(opt.sum_sq_grad.is_empty());
    }

    #[test]
    fn test_adagrad_length_mismatch() {
        let mut opt = AdaGrad::new(0.1);
        let mut p = vec![1.0, 2.0];
        assert!(opt.step(&mut p, &[0.1]).is_err());
    }

    // ── RMSprop ──────────────────────────────────────────────────────────────

    #[test]
    fn test_rmsprop_converges() {
        let mut opt = RmsProp::new(0.01);
        let mut p = vec![2.0, -2.0];
        for _ in 0..1000 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert!(p[0].abs() < 0.1, "rmsprop p[0]={}", p[0]);
        assert!(p[1].abs() < 0.1, "rmsprop p[1]={}", p[1]);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let mut opt = RmsProp {
            momentum: 0.9,
            ..RmsProp::new(0.01)
        };
        let mut p = vec![1.0, 1.0];
        for _ in 0..500 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert!(p[0].abs() < 0.5, "rmsprop+momentum p[0]={}", p[0]);
        assert!(p[1].abs() < 0.5, "rmsprop+momentum p[1]={}", p[1]);
    }

    #[test]
    fn test_rmsprop_length_mismatch() {
        let mut opt = RmsProp::new(0.01);
        let mut p = vec![1.0, 2.0];
        assert!(opt.step(&mut p, &[0.1]).is_err());
    }

    // ── AdamW ────────────────────────────────────────────────────────────────

    #[test]
    fn test_adamw_decoupled_wd() {
        let mut opt = AdamW {
            weight_decay: 0.1,
            ..AdamW::new(0.001)
        };
        let mut p = vec![1.0];
        let p_before = p[0];
        opt.step(&mut p, &[0.0]).expect("step failed");
        // Decoupled WD: p ← p * (1 - lr * wd) then Adam on grad=0
        assert!(p[0] < p_before, "decoupled WD should shrink param");
        // With a zero gradient the Adam term vanishes, so only the shrink remains.
        assert_abs_diff_eq!(p[0], 1.0 - 0.001 * 0.1, epsilon = 1e-12);
    }

    #[test]
    fn test_adamw_converges() {
        let mut opt = AdamW {
            weight_decay: 0.0,
            ..AdamW::new(0.01)
        };
        let mut p = vec![2.0, -2.0];
        for _ in 0..1000 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert!(p[0].abs() < 0.1, "adamw p[0]={}", p[0]);
        assert!(p[1].abs() < 0.1, "adamw p[1]={}", p[1]);
    }

    #[test]
    fn test_adamw_reset() {
        let mut opt = AdamW::new(0.001);
        let mut p = vec![1.0];
        opt.step(&mut p, &[0.5]).expect("step failed");
        assert_eq!(opt.t, 1);
        opt.reset_state();
        assert_eq!(opt.t, 0);
        assert!(opt.m.is_empty());
        assert!(opt.v.is_empty());
    }

    #[test]
    fn test_adamw_length_mismatch() {
        let mut opt = AdamW::new(0.001);
        let mut p = vec![1.0, 2.0];
        assert!(opt.step(&mut p, &[0.1]).is_err());
    }

    // ── SVRG ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_svrg_needs_snapshot() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut p = vec![1.0, 2.0];
        let sg = vec![0.1, 0.2];
        let sgi = vec![0.05, 0.1];
        // No snapshot: should error
        assert!(svrg.step(&mut p, &sg, &sgi).is_err());
        // A rejected step must leave the parameters untouched.
        assert_eq!(p, vec![1.0, 2.0]);
    }

    #[test]
    fn test_svrg_step_after_snapshot() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut p = vec![1.0, 1.0];
        let full_grad = vec![2.0, 2.0]; // grad at snapshot
        svrg.update_snapshot(&p, &full_grad);

        let sg = vec![2.1, 1.9];
        let sgi = vec![2.0, 2.0];
        svrg.step(&mut p, &sg, &sgi).expect("step failed");
        // Effective grad = sg - sgi + full_grad = [2.1-2+2, 1.9-2+2] = [2.1, 1.9]
        // p[0] ← 1 - 0.01*2.1 = 0.979, p[1] ← 1 - 0.01*1.9 = 0.981
        assert_abs_diff_eq!(p[0], 1.0 - 0.01 * 2.1, epsilon = 1e-12);
        assert_abs_diff_eq!(p[1], 1.0 - 0.01 * 1.9, epsilon = 1e-12);
    }

    #[test]
    fn test_svrg_update_freq() {
        let mut svrg = Svrg::new(0.01, 100, 3);
        let mut p = vec![1.0];
        svrg.update_snapshot(&p, &[0.0]);
        assert!(!svrg.needs_snapshot_update());

        for _ in 0..3 {
            svrg.step(&mut p, &[0.0], &[0.0]).expect("step");
        }
        assert!(svrg.needs_snapshot_update());

        // Taking a new snapshot restarts the inner loop.
        svrg.update_snapshot(&p, &[0.0]);
        assert!(!svrg.needs_snapshot_update());
    }

    #[test]
    fn test_svrg_snapshot_params() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let snap = vec![3.0, 4.0];
        svrg.update_snapshot(&snap, &[0.0, 0.0]);
        assert_eq!(svrg.snapshot_params(), &[3.0, 4.0]);
    }

    #[test]
    fn test_svrg_length_mismatch() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut p = vec![1.0, 2.0];
        svrg.update_snapshot(&p, &[0.0, 0.0]);
        // Wrong length stochastic grad
        assert!(svrg.step(&mut p, &[0.1], &[0.0, 0.0]).is_err());
        // Wrong length snapshot mini-batch grad
        assert!(svrg.step(&mut p, &[0.1, 0.1], &[0.0]).is_err());
    }
}