Skip to main content

scirs2_optimize/stochastic/
new_sgd.rs

1//! Stateful SGD optimizer variants
2//!
3//! This module provides stateful, struct-based optimizer objects that hold
4//! their own parameter state across `step()` calls. These are suitable for
5//! use in ML training loops where the same optimizer instance is updated
6//! repeatedly.
7//!
8//! # Algorithms
9//!
10//! | Type | Description |
11//! |------|-------------|
12//! | `SgdOptimizer` | SGD with optional momentum (classical & Nesterov), weight decay |
13//! | `AdaGradOptimizer` | Adaptive learning rates via accumulated squared gradients |
14//! | `AdaDeltaOptimizer` | Adaptive learning rates without a global LR (Zeiler 2012) |
15//!
16//! # References
17//!
18//! - Polyak (1964). "Some methods of speeding up the convergence of iteration methods".
19//! - Nesterov (1983). "A method of solving a convex programming problem".
20//! - Duchi et al. (2011). "Adaptive Subgradient Methods for Online Learning". *JMLR*.
21//! - Zeiler (2012). "ADADELTA: An Adaptive Learning Rate Method". arXiv:1212.5701.
22
23use crate::error::OptimizeError;
24
25// ─── SGD ─────────────────────────────────────────────────────────────────────
26
27/// Stateful SGD optimizer with optional momentum, Nesterov momentum, and
28/// L2 weight decay.
29///
30/// # Update rule (no Nesterov)
31/// ```text
32/// v_t = μ·v_{t-1} + g_t + λ·θ_{t-1}
33/// θ_t = θ_{t-1} - α·v_t
34/// ```
35///
36/// # Update rule (Nesterov)
37/// ```text
38/// v_t = μ·v_{t-1} + g_t + λ·θ_{t-1}
39/// θ_t = θ_{t-1} - α·(g_t + μ·v_t)
40/// ```
41///
42/// where α = `lr`, μ = `momentum`, λ = `weight_decay`.
43#[derive(Debug, Clone)]
44pub struct SgdOptimizer {
45    /// Learning rate
46    pub lr: f64,
47    /// Momentum coefficient (0 = vanilla SGD)
48    pub momentum: f64,
49    /// Use Nesterov momentum
50    pub nesterov: bool,
51    /// L2 weight-decay coefficient
52    pub weight_decay: f64,
53    /// Velocity buffer (accumulated momentum); populated lazily on first step
54    velocity: Vec<f64>,
55}
56
57impl SgdOptimizer {
58    /// Create a new SGD optimizer.
59    ///
60    /// # Arguments
61    /// * `lr` - Learning rate (must be > 0)
62    /// * `momentum` - Momentum factor in [0, 1)
63    /// * `nesterov` - Whether to use Nesterov lookahead momentum
64    /// * `weight_decay` - L2 regularisation strength (≥ 0)
65    pub fn new(lr: f64, momentum: f64, nesterov: bool, weight_decay: f64) -> Self {
66        Self {
67            lr,
68            momentum,
69            nesterov,
70            weight_decay,
71            velocity: Vec::new(),
72        }
73    }
74
75    /// Vanilla SGD with default hyperparameters (lr=0.01, no momentum).
76    pub fn vanilla(lr: f64) -> Self {
77        Self::new(lr, 0.0, false, 0.0)
78    }
79
80    /// Perform one SGD update step.
81    ///
82    /// # Arguments
83    /// * `params` - Mutable parameter vector; updated in-place
84    /// * `grad` - Gradient vector (same length as `params`)
85    ///
86    /// # Errors
87    /// Returns `OptimizeError::ValueError` if `params` and `grad` have
88    /// different lengths.
89    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> Result<(), OptimizeError> {
90        let n = params.len();
91        if grad.len() != n {
92            return Err(OptimizeError::ValueError(format!(
93                "params length {} != grad length {}",
94                n,
95                grad.len()
96            )));
97        }
98
99        // Lazy initialisation of velocity buffer
100        if self.velocity.len() != n {
101            self.velocity = vec![0.0; n];
102        }
103
104        for i in 0..n {
105            // Add L2 regularisation to gradient
106            let g = grad[i] + self.weight_decay * params[i];
107
108            if self.momentum == 0.0 {
109                // Vanilla SGD
110                params[i] -= self.lr * g;
111            } else {
112                // Update velocity
113                self.velocity[i] = self.momentum * self.velocity[i] + g;
114
115                if self.nesterov {
116                    // Nesterov: use the "lookahead" gradient
117                    params[i] -= self.lr * (g + self.momentum * self.velocity[i]);
118                } else {
119                    params[i] -= self.lr * self.velocity[i];
120                }
121            }
122        }
123        Ok(())
124    }
125
126    /// Reset velocity buffer (useful when restarting training).
127    pub fn reset(&mut self) {
128        self.velocity.clear();
129    }
130}
131
132// ─── AdaGrad ─────────────────────────────────────────────────────────────────
133
134/// AdaGrad optimizer.
135///
136/// Adapts the learning rate for each parameter by accumulating squared
137/// gradients. Parameters that receive large, frequent gradients see smaller
138/// effective learning rates.
139///
140/// # Update rule
141/// ```text
142/// G_t = G_{t-1} + g_t ⊙ g_t
143/// θ_t = θ_{t-1} - α / (√G_t + ε) ⊙ g_t
144/// ```
145///
146/// Reference: Duchi et al. (2011).
147#[derive(Debug, Clone)]
148pub struct AdaGradOptimizer {
149    /// Global learning rate
150    pub lr: f64,
151    /// Numerical stability constant
152    pub eps: f64,
153    /// Accumulated squared gradients
154    pub accum: Vec<f64>,
155}
156
157impl AdaGradOptimizer {
158    /// Create a new AdaGrad optimizer.
159    pub fn new(lr: f64, eps: f64) -> Self {
160        Self {
161            lr,
162            eps,
163            accum: Vec::new(),
164        }
165    }
166
167    /// Create with default hyperparameters (lr=0.01, eps=1e-8).
168    pub fn default_params(lr: f64) -> Self {
169        Self::new(lr, 1e-8)
170    }
171
172    /// Perform one AdaGrad update step.
173    ///
174    /// # Errors
175    /// Returns `OptimizeError::ValueError` if length mismatch.
176    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> Result<(), OptimizeError> {
177        let n = params.len();
178        if grad.len() != n {
179            return Err(OptimizeError::ValueError(format!(
180                "params length {} != grad length {}",
181                n,
182                grad.len()
183            )));
184        }
185
186        if self.accum.len() != n {
187            self.accum = vec![0.0; n];
188        }
189
190        for i in 0..n {
191            self.accum[i] += grad[i] * grad[i];
192            params[i] -= self.lr / (self.accum[i].sqrt() + self.eps) * grad[i];
193        }
194        Ok(())
195    }
196
197    /// Reset accumulated state.
198    pub fn reset(&mut self) {
199        self.accum.clear();
200    }
201}
202
203// ─── AdaDelta ────────────────────────────────────────────────────────────────
204
205/// AdaDelta optimizer.
206///
207/// Extends AdaGrad to avoid its monotonically decreasing learning rate by
208/// using an exponentially decaying window of past squared gradients.
209/// Importantly, no global learning rate is required.
210///
211/// # Update rule
212/// ```text
213/// E[g²]_t    = ρ·E[g²]_{t-1}    + (1-ρ)·g_t²
214/// Δθ_t       = -√(E[Δθ²]_{t-1} + ε) / √(E[g²]_t + ε) · g_t
215/// E[Δθ²]_t   = ρ·E[Δθ²]_{t-1}  + (1-ρ)·Δθ_t²
216/// θ_t        = θ_{t-1} + Δθ_t
217/// ```
218///
219/// Reference: Zeiler (2012), "ADADELTA: An Adaptive Learning Rate Method".
220#[derive(Debug, Clone)]
221pub struct AdaDeltaOptimizer {
222    /// Decay rate for running averages
223    pub rho: f64,
224    /// Numerical stability constant
225    pub eps: f64,
226    /// Running average of squared gradients: E\[g²\]
227    pub accum_grad: Vec<f64>,
228    /// Running average of squared updates: E\[Δθ²\]
229    pub accum_update: Vec<f64>,
230}
231
232impl AdaDeltaOptimizer {
233    /// Create a new AdaDelta optimizer.
234    ///
235    /// # Arguments
236    /// * `rho` - Decay factor for exponential moving averages (typically 0.95)
237    /// * `eps` - Numerical stability (typically 1e-6)
238    pub fn new(rho: f64, eps: f64) -> Self {
239        Self {
240            rho,
241            eps,
242            accum_grad: Vec::new(),
243            accum_update: Vec::new(),
244        }
245    }
246
247    /// Create with default hyperparameters (rho=0.95, eps=1e-6).
248    pub fn default_params() -> Self {
249        Self::new(0.95, 1e-6)
250    }
251
252    /// Perform one AdaDelta update step.
253    ///
254    /// # Errors
255    /// Returns `OptimizeError::ValueError` if length mismatch.
256    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> Result<(), OptimizeError> {
257        let n = params.len();
258        if grad.len() != n {
259            return Err(OptimizeError::ValueError(format!(
260                "params length {} != grad length {}",
261                n,
262                grad.len()
263            )));
264        }
265
266        if self.accum_grad.len() != n {
267            self.accum_grad = vec![0.0; n];
268            self.accum_update = vec![0.0; n];
269        }
270
271        for i in 0..n {
272            // Update running average of squared gradients
273            self.accum_grad[i] =
274                self.rho * self.accum_grad[i] + (1.0 - self.rho) * grad[i] * grad[i];
275
276            // Compute parameter update using RMS of past updates
277            let rms_update = (self.accum_update[i] + self.eps).sqrt();
278            let rms_grad = (self.accum_grad[i] + self.eps).sqrt();
279            let delta = -(rms_update / rms_grad) * grad[i];
280
281            // Update running average of squared updates
282            self.accum_update[i] =
283                self.rho * self.accum_update[i] + (1.0 - self.rho) * delta * delta;
284
285            params[i] += delta;
286        }
287        Ok(())
288    }
289
290    /// Reset accumulated state.
291    pub fn reset(&mut self) {
292        self.accum_grad.clear();
293        self.accum_update.clear();
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Gradient of the separable quadratic f(x) = Σ xᵢ², i.e. 2·x.
    fn quadratic_grad(x: &[f64]) -> Vec<f64> {
        let mut doubled = Vec::with_capacity(x.len());
        for xi in x {
            doubled.push(2.0 * xi);
        }
        doubled
    }

    /// Plain SGD should drive every coordinate of the quadratic to zero.
    #[test]
    fn test_sgd_vanilla_converges() {
        let mut sgd = SgdOptimizer::vanilla(0.1);
        let mut theta = vec![1.0, -2.0, 0.5];
        (0..200).for_each(|_| {
            let gradient = quadratic_grad(&theta);
            sgd.step(&mut theta, &gradient).expect("step failed");
        });
        theta
            .iter()
            .for_each(|&value| assert_abs_diff_eq!(value, 0.0, epsilon = 1e-4));
    }

    /// Classical momentum should also reach the minimum.
    #[test]
    fn test_sgd_momentum_converges() {
        let mut sgd = SgdOptimizer::new(0.05, 0.9, false, 0.0);
        let mut theta = vec![2.0, -1.5];
        (0..300).for_each(|_| {
            let gradient = quadratic_grad(&theta);
            sgd.step(&mut theta, &gradient).expect("step failed");
        });
        theta
            .iter()
            .for_each(|&value| assert_abs_diff_eq!(value, 0.0, epsilon = 1e-3));
    }

    /// Nesterov momentum should also reach the minimum.
    #[test]
    fn test_sgd_nesterov_converges() {
        let mut sgd = SgdOptimizer::new(0.05, 0.9, true, 0.0);
        let mut theta = vec![1.5, -1.0];
        (0..300).for_each(|_| {
            let gradient = quadratic_grad(&theta);
            sgd.step(&mut theta, &gradient).expect("step failed");
        });
        theta
            .iter()
            .for_each(|&value| assert_abs_diff_eq!(value, 0.0, epsilon = 1e-3));
    }

    /// With a zero gradient, only the weight-decay term can move the
    /// parameter, and it must move it towards zero.
    #[test]
    fn test_sgd_weight_decay() {
        let mut sgd = SgdOptimizer::new(0.01, 0.0, false, 0.1);
        let mut theta = vec![1.0];
        let before = theta[0];
        let zero_grad = vec![0.0];
        sgd.step(&mut theta, &zero_grad).expect("step failed");
        assert!(theta[0] < before, "weight decay should reduce param");
    }

    /// A gradient shorter than the parameter vector must be rejected.
    #[test]
    fn test_sgd_length_mismatch() {
        let mut sgd = SgdOptimizer::vanilla(0.1);
        let mut theta = vec![1.0, 2.0];
        let short_grad = vec![0.1];
        let outcome = sgd.step(&mut theta, &short_grad);
        assert!(outcome.is_err());
    }

    /// AdaGrad should get close to the minimum within a loose tolerance.
    #[test]
    fn test_adagrad_converges() {
        let mut adagrad = AdaGradOptimizer::default_params(0.5);
        let mut theta = vec![3.0, -2.0];
        (0..500).for_each(|_| {
            let gradient = quadratic_grad(&theta);
            adagrad.step(&mut theta, &gradient).expect("step failed");
        });
        theta
            .iter()
            .for_each(|&value| assert_abs_diff_eq!(value, 0.0, epsilon = 0.1));
    }

    /// AdaDelta should get close to the minimum within a loose tolerance.
    #[test]
    fn test_adadelta_converges() {
        let mut adadelta = AdaDeltaOptimizer::default_params();
        let mut theta = vec![2.0, -1.0];
        (0..2000).for_each(|_| {
            let gradient = quadratic_grad(&theta);
            adadelta.step(&mut theta, &gradient).expect("step failed");
        });
        theta
            .iter()
            .for_each(|&value| assert_abs_diff_eq!(value, 0.0, epsilon = 0.5));
    }

    /// A gradient longer than the parameter vector must be rejected.
    #[test]
    fn test_adadelta_length_mismatch() {
        let mut adadelta = AdaDeltaOptimizer::default_params();
        let mut theta = vec![1.0, 2.0];
        let long_grad = vec![0.1, 0.2, 0.3];
        let outcome = adadelta.step(&mut theta, &long_grad);
        assert!(outcome.is_err());
    }
}