scirs2_optimize/stochastic/new_sgd.rs
1//! Stateful SGD optimizer variants
2//!
3//! This module provides stateful, struct-based optimizer objects that hold
4//! their own parameter state across `step()` calls. These are suitable for
5//! use in ML training loops where the same optimizer instance is updated
6//! repeatedly.
7//!
8//! # Algorithms
9//!
10//! | Type | Description |
11//! |------|-------------|
12//! | `SgdOptimizer` | SGD with optional momentum (classical & Nesterov), weight decay |
13//! | `AdaGradOptimizer` | Adaptive learning rates via accumulated squared gradients |
14//! | `AdaDeltaOptimizer` | Adaptive learning rates without a global LR (Zeiler 2012) |
15//!
16//! # References
17//!
18//! - Polyak (1964). "Some methods of speeding up the convergence of iteration methods".
19//! - Nesterov (1983). "A method of solving a convex programming problem".
20//! - Duchi et al. (2011). "Adaptive Subgradient Methods for Online Learning". *JMLR*.
21//! - Zeiler (2012). "ADADELTA: An Adaptive Learning Rate Method". arXiv:1212.5701.
22
23use crate::error::OptimizeError;
24
25// ─── SGD ─────────────────────────────────────────────────────────────────────
26
27/// Stateful SGD optimizer with optional momentum, Nesterov momentum, and
28/// L2 weight decay.
29///
30/// # Update rule (no Nesterov)
31/// ```text
32/// v_t = μ·v_{t-1} + g_t + λ·θ_{t-1}
33/// θ_t = θ_{t-1} - α·v_t
34/// ```
35///
36/// # Update rule (Nesterov)
37/// ```text
38/// v_t = μ·v_{t-1} + g_t + λ·θ_{t-1}
39/// θ_t = θ_{t-1} - α·(g_t + μ·v_t)
40/// ```
41///
42/// where α = `lr`, μ = `momentum`, λ = `weight_decay`.
43#[derive(Debug, Clone)]
44pub struct SgdOptimizer {
45 /// Learning rate
46 pub lr: f64,
47 /// Momentum coefficient (0 = vanilla SGD)
48 pub momentum: f64,
49 /// Use Nesterov momentum
50 pub nesterov: bool,
51 /// L2 weight-decay coefficient
52 pub weight_decay: f64,
53 /// Velocity buffer (accumulated momentum); populated lazily on first step
54 velocity: Vec<f64>,
55}
56
57impl SgdOptimizer {
58 /// Create a new SGD optimizer.
59 ///
60 /// # Arguments
61 /// * `lr` - Learning rate (must be > 0)
62 /// * `momentum` - Momentum factor in [0, 1)
63 /// * `nesterov` - Whether to use Nesterov lookahead momentum
64 /// * `weight_decay` - L2 regularisation strength (≥ 0)
65 pub fn new(lr: f64, momentum: f64, nesterov: bool, weight_decay: f64) -> Self {
66 Self {
67 lr,
68 momentum,
69 nesterov,
70 weight_decay,
71 velocity: Vec::new(),
72 }
73 }
74
75 /// Vanilla SGD with default hyperparameters (lr=0.01, no momentum).
76 pub fn vanilla(lr: f64) -> Self {
77 Self::new(lr, 0.0, false, 0.0)
78 }
79
80 /// Perform one SGD update step.
81 ///
82 /// # Arguments
83 /// * `params` - Mutable parameter vector; updated in-place
84 /// * `grad` - Gradient vector (same length as `params`)
85 ///
86 /// # Errors
87 /// Returns `OptimizeError::ValueError` if `params` and `grad` have
88 /// different lengths.
89 pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> Result<(), OptimizeError> {
90 let n = params.len();
91 if grad.len() != n {
92 return Err(OptimizeError::ValueError(format!(
93 "params length {} != grad length {}",
94 n,
95 grad.len()
96 )));
97 }
98
99 // Lazy initialisation of velocity buffer
100 if self.velocity.len() != n {
101 self.velocity = vec![0.0; n];
102 }
103
104 for i in 0..n {
105 // Add L2 regularisation to gradient
106 let g = grad[i] + self.weight_decay * params[i];
107
108 if self.momentum == 0.0 {
109 // Vanilla SGD
110 params[i] -= self.lr * g;
111 } else {
112 // Update velocity
113 self.velocity[i] = self.momentum * self.velocity[i] + g;
114
115 if self.nesterov {
116 // Nesterov: use the "lookahead" gradient
117 params[i] -= self.lr * (g + self.momentum * self.velocity[i]);
118 } else {
119 params[i] -= self.lr * self.velocity[i];
120 }
121 }
122 }
123 Ok(())
124 }
125
126 /// Reset velocity buffer (useful when restarting training).
127 pub fn reset(&mut self) {
128 self.velocity.clear();
129 }
130}
131
132// ─── AdaGrad ─────────────────────────────────────────────────────────────────
133
134/// AdaGrad optimizer.
135///
136/// Adapts the learning rate for each parameter by accumulating squared
137/// gradients. Parameters that receive large, frequent gradients see smaller
138/// effective learning rates.
139///
140/// # Update rule
141/// ```text
142/// G_t = G_{t-1} + g_t ⊙ g_t
143/// θ_t = θ_{t-1} - α / (√G_t + ε) ⊙ g_t
144/// ```
145///
146/// Reference: Duchi et al. (2011).
147#[derive(Debug, Clone)]
148pub struct AdaGradOptimizer {
149 /// Global learning rate
150 pub lr: f64,
151 /// Numerical stability constant
152 pub eps: f64,
153 /// Accumulated squared gradients
154 pub accum: Vec<f64>,
155}
156
157impl AdaGradOptimizer {
158 /// Create a new AdaGrad optimizer.
159 pub fn new(lr: f64, eps: f64) -> Self {
160 Self {
161 lr,
162 eps,
163 accum: Vec::new(),
164 }
165 }
166
167 /// Create with default hyperparameters (lr=0.01, eps=1e-8).
168 pub fn default_params(lr: f64) -> Self {
169 Self::new(lr, 1e-8)
170 }
171
172 /// Perform one AdaGrad update step.
173 ///
174 /// # Errors
175 /// Returns `OptimizeError::ValueError` if length mismatch.
176 pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> Result<(), OptimizeError> {
177 let n = params.len();
178 if grad.len() != n {
179 return Err(OptimizeError::ValueError(format!(
180 "params length {} != grad length {}",
181 n,
182 grad.len()
183 )));
184 }
185
186 if self.accum.len() != n {
187 self.accum = vec![0.0; n];
188 }
189
190 for i in 0..n {
191 self.accum[i] += grad[i] * grad[i];
192 params[i] -= self.lr / (self.accum[i].sqrt() + self.eps) * grad[i];
193 }
194 Ok(())
195 }
196
197 /// Reset accumulated state.
198 pub fn reset(&mut self) {
199 self.accum.clear();
200 }
201}
202
203// ─── AdaDelta ────────────────────────────────────────────────────────────────
204
205/// AdaDelta optimizer.
206///
207/// Extends AdaGrad to avoid its monotonically decreasing learning rate by
208/// using an exponentially decaying window of past squared gradients.
209/// Importantly, no global learning rate is required.
210///
211/// # Update rule
212/// ```text
213/// E[g²]_t = ρ·E[g²]_{t-1} + (1-ρ)·g_t²
214/// Δθ_t = -√(E[Δθ²]_{t-1} + ε) / √(E[g²]_t + ε) · g_t
215/// E[Δθ²]_t = ρ·E[Δθ²]_{t-1} + (1-ρ)·Δθ_t²
216/// θ_t = θ_{t-1} + Δθ_t
217/// ```
218///
219/// Reference: Zeiler (2012), "ADADELTA: An Adaptive Learning Rate Method".
220#[derive(Debug, Clone)]
221pub struct AdaDeltaOptimizer {
222 /// Decay rate for running averages
223 pub rho: f64,
224 /// Numerical stability constant
225 pub eps: f64,
226 /// Running average of squared gradients: E\[g²\]
227 pub accum_grad: Vec<f64>,
228 /// Running average of squared updates: E\[Δθ²\]
229 pub accum_update: Vec<f64>,
230}
231
232impl AdaDeltaOptimizer {
233 /// Create a new AdaDelta optimizer.
234 ///
235 /// # Arguments
236 /// * `rho` - Decay factor for exponential moving averages (typically 0.95)
237 /// * `eps` - Numerical stability (typically 1e-6)
238 pub fn new(rho: f64, eps: f64) -> Self {
239 Self {
240 rho,
241 eps,
242 accum_grad: Vec::new(),
243 accum_update: Vec::new(),
244 }
245 }
246
247 /// Create with default hyperparameters (rho=0.95, eps=1e-6).
248 pub fn default_params() -> Self {
249 Self::new(0.95, 1e-6)
250 }
251
252 /// Perform one AdaDelta update step.
253 ///
254 /// # Errors
255 /// Returns `OptimizeError::ValueError` if length mismatch.
256 pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> Result<(), OptimizeError> {
257 let n = params.len();
258 if grad.len() != n {
259 return Err(OptimizeError::ValueError(format!(
260 "params length {} != grad length {}",
261 n,
262 grad.len()
263 )));
264 }
265
266 if self.accum_grad.len() != n {
267 self.accum_grad = vec![0.0; n];
268 self.accum_update = vec![0.0; n];
269 }
270
271 for i in 0..n {
272 // Update running average of squared gradients
273 self.accum_grad[i] =
274 self.rho * self.accum_grad[i] + (1.0 - self.rho) * grad[i] * grad[i];
275
276 // Compute parameter update using RMS of past updates
277 let rms_update = (self.accum_update[i] + self.eps).sqrt();
278 let rms_grad = (self.accum_grad[i] + self.eps).sqrt();
279 let delta = -(rms_update / rms_grad) * grad[i];
280
281 // Update running average of squared updates
282 self.accum_update[i] =
283 self.rho * self.accum_update[i] + (1.0 - self.rho) * delta * delta;
284
285 params[i] += delta;
286 }
287 Ok(())
288 }
289
290 /// Reset accumulated state.
291 pub fn reset(&mut self) {
292 self.accum_grad.clear();
293 self.accum_update.clear();
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Gradient of f(x) = Σ xᵢ², whose unique minimum is at the origin.
    fn quadratic_grad(x: &[f64]) -> Vec<f64> {
        x.iter().map(|&xi| 2.0 * xi).collect()
    }

    #[test]
    fn test_sgd_vanilla_converges() {
        let mut opt = SgdOptimizer::vanilla(0.1);
        let mut params = vec![1.0, -2.0, 0.5];
        for _ in 0..200 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
            assert_abs_diff_eq!(p, 0.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_sgd_momentum_converges() {
        let mut opt = SgdOptimizer::new(0.05, 0.9, false, 0.0);
        let mut params = vec![2.0, -1.5];
        for _ in 0..300 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
            assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
        }
    }

    #[test]
    fn test_sgd_nesterov_converges() {
        let mut opt = SgdOptimizer::new(0.05, 0.9, true, 0.0);
        let mut params = vec![1.5, -1.0];
        for _ in 0..300 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
            assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
        }
    }

    #[test]
    fn test_sgd_nesterov_single_step() {
        // v = 1, θ = 1 - 0.1·(1 + 0.5·1) = 0.85
        let mut opt = SgdOptimizer::new(0.1, 0.5, true, 0.0);
        let mut params = vec![1.0];
        opt.step(&mut params, &[1.0]).expect("step failed");
        assert_abs_diff_eq!(params[0], 0.85, epsilon = 1e-12);
    }

    #[test]
    fn test_sgd_weight_decay() {
        // With weight decay, minimum shifts; check that update is applied
        let mut opt = SgdOptimizer::new(0.01, 0.0, false, 0.1);
        let mut params = vec![1.0];
        let init = params[0];
        let g = vec![0.0]; // zero gradient; only weight decay should pull
        opt.step(&mut params, &g).expect("step failed");
        assert!(params[0] < init, "weight decay should reduce param");
    }

    #[test]
    fn test_sgd_length_mismatch() {
        let mut opt = SgdOptimizer::vanilla(0.1);
        let mut params = vec![1.0, 2.0];
        let grad = vec![0.1]; // wrong length
        assert!(opt.step(&mut params, &grad).is_err());
    }

    #[test]
    fn test_adagrad_converges() {
        let mut opt = AdaGradOptimizer::default_params(0.5);
        let mut params = vec![3.0, -2.0];
        for _ in 0..500 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
            assert_abs_diff_eq!(p, 0.0, epsilon = 0.1);
        }
    }

    #[test]
    fn test_adagrad_length_mismatch() {
        let mut opt = AdaGradOptimizer::default_params(0.1);
        let mut params = vec![1.0, 2.0];
        let grad = vec![0.1]; // wrong length
        assert!(opt.step(&mut params, &grad).is_err());
    }

    #[test]
    fn test_adadelta_converges() {
        let mut opt = AdaDeltaOptimizer::default_params();
        let mut params = vec![2.0, -1.0];
        for _ in 0..2000 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
            assert_abs_diff_eq!(p, 0.0, epsilon = 0.5);
        }
    }

    #[test]
    fn test_adadelta_length_mismatch() {
        let mut opt = AdaDeltaOptimizer::default_params();
        let mut params = vec![1.0, 2.0];
        let grad = vec![0.1, 0.2, 0.3]; // wrong length
        assert!(opt.step(&mut params, &grad).is_err());
    }
}