Skip to main content

scirs2_optimize/stochastic/
variance_reduction.rs

1//! Variance Reduction Methods for Stochastic Gradient Descent
2//!
//! Standard SGD with a constant step size cannot converge to the exact
//! minimiser, because the variance of mini-batch gradients does not vanish at
//! the optimum.  Variance reduction methods achieve *linear* convergence on
//! strongly convex problems by periodically re-anchoring the stochastic
//! gradient to an exact full-gradient computation, which removes that noise.
7//!
8//! # Algorithms
9//!
10//! | Method | Reference | Gradient evals per pass |
11//! |--------|-----------|--------------------------|
12//! | SVRG   | Johnson & Zhang (2013) | 2n (snapshot + inner) |
13//! | SARAH  | Nguyen et al. (2017)   | n + m·b               |
14//! | SPIDER | Fang et al. (2018)     | n + m·b               |
15//!
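//! SVRG corrects each stochastic gradient with a periodically refreshed
//! full-gradient snapshot, while SARAH and SPIDER update a recursive gradient
//! estimator; both corrections cancel the noise introduced by mini-batch
//! subsampling.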
18//!
19//! # References
20//!
21//! - Johnson, R. & Zhang, T. (2013). "Accelerating Stochastic Gradient Descent
22//!   using Predictive Variance Reduction". *NeurIPS*.
23//! - Nguyen, L.M. et al. (2017). "SARAH: A Novel Method for Machine Learning
24//!   Problems Using Stochastic Recursive Gradient". *ICML*.
25//! - Fang, C. et al. (2018). "SPIDER: Near-Optimal Non-Convex Optimization
26//!   via Stochastic Path-Integrated Differential Estimator". *NeurIPS*.
27
28use crate::error::{OptimizeError, OptimizeResult};
29use scirs2_core::ndarray::{Array1, ArrayView1};
30
31// ─── Shared helpers ──────────────────────────────────────────────────────────
32
33/// Compute a finite-difference gradient estimate for a single data point.
34#[inline]
35fn finite_diff_grad<F>(
36    f: &mut F,
37    x: &ArrayView1<f64>,
38    sample: &ArrayView1<f64>,
39    h: f64,
40) -> Array1<f64>
41where
42    F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
43{
44    let n = x.len();
45    let f0 = f(x, sample);
46    let mut grad = Array1::<f64>::zeros(n);
47    let mut x_fwd = x.to_owned();
48    for i in 0..n {
49        x_fwd[i] += h;
50        grad[i] = (f(&x_fwd.view(), sample) - f0) / h;
51        x_fwd[i] = x[i];
52    }
53    grad
54}
55
56/// Average gradient over a set of samples (full batch).
57fn full_grad<F>(f: &mut F, x: &ArrayView1<f64>, samples: &[Array1<f64>], h: f64) -> Array1<f64>
58where
59    F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
60{
61    let n = x.len();
62    if samples.is_empty() {
63        return Array1::zeros(n);
64    }
65    let mut avg = Array1::<f64>::zeros(n);
66    for s in samples {
67        let g = finite_diff_grad(f, x, &s.view(), h);
68        for i in 0..n {
69            avg[i] += g[i];
70        }
71    }
72    let inv_m = 1.0 / samples.len() as f64;
73    avg.mapv_inplace(|v| v * inv_m);
74    avg
75}
76
77// ─── SVRG ────────────────────────────────────────────────────────────────────
78
/// Tuning parameters for [`svrg`].
#[derive(Debug, Clone)]
pub struct SvrgOptions {
    /// Outer iterations; each one refreshes the snapshot and its full gradient.
    pub n_epochs: usize,
    /// Variance-reduced stochastic steps taken between snapshot refreshes.
    pub inner_steps: usize,
    /// Step size η applied to every inner update.
    pub step_size: f64,
    /// Stop as soon as the snapshot's full-gradient norm drops below this.
    pub tol: f64,
    /// Perturbation `h` used by the forward-difference gradient approximation.
    pub fd_step: f64,
}

impl Default for SvrgOptions {
    /// 50 epochs of 100 inner steps, η = 1e-3, tol = 1e-6, h = 1e-5.
    fn default() -> Self {
        SvrgOptions {
            fd_step: 1e-5,
            tol: 1e-6,
            step_size: 1e-3,
            inner_steps: 100,
            n_epochs: 50,
        }
    }
}
105
106/// Result from SVRG optimisation.
107#[derive(Debug, Clone)]
108pub struct SvrgResult {
109    /// Approximate minimiser.
110    pub x: Array1<f64>,
111    /// Final full-gradient norm.
112    pub grad_norm: f64,
113    /// Total number of gradient evaluations.
114    pub n_grad_evals: usize,
115    /// Whether tolerance was met.
116    pub converged: bool,
117}
118
119/// Stochastic Variance Reduced Gradient (SVRG) optimizer.
120///
121/// At the start of each epoch, a full-gradient μ̃ = ∇f(x̃) is computed at a
122/// snapshot x̃.  Each inner step uses the variance-reduced direction:
123///
124/// v = ∇fᵢ(x) - ∇fᵢ(x̃) + μ̃
125///
126/// which has zero variance when x = x̃.
127///
128/// # Arguments
129///
130/// * `f`       – per-sample loss: (x, sample) → f64
131/// * `x0`      – starting point
132/// * `samples` – full dataset (each element is one sample parameter vector)
133/// * `opts`    – SVRG options
134pub fn svrg<F>(
135    f: &mut F,
136    x0: &ArrayView1<f64>,
137    samples: &[Array1<f64>],
138    opts: &SvrgOptions,
139) -> OptimizeResult<SvrgResult>
140where
141    F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
142{
143    let n = x0.len();
144    if n == 0 {
145        return Err(OptimizeError::ValueError(
146            "x0 must be non-empty".to_string(),
147        ));
148    }
149    if samples.is_empty() {
150        return Err(OptimizeError::ValueError(
151            "samples must be non-empty".to_string(),
152        ));
153    }
154
155    let m = samples.len();
156    let mut x = x0.to_owned();
157    let mut converged = false;
158    let mut total_evals: usize = 0;
159    // LCG for sample selection
160    let mut rng: u64 = 987654321;
161
162    for _ in 0..opts.n_epochs {
163        // Snapshot: x̃ ← x, compute full gradient μ̃
164        let x_tilde = x.clone();
165        let mu_tilde = full_grad(f, &x_tilde.view(), samples, opts.fd_step);
166        total_evals += m * (n + 1);
167
168        let grad_norm = mu_tilde.iter().map(|v| v * v).sum::<f64>().sqrt();
169        if grad_norm < opts.tol {
170            converged = true;
171            return Ok(SvrgResult {
172                x,
173                grad_norm,
174                n_grad_evals: total_evals,
175                converged,
176            });
177        }
178
179        // Inner loop
180        for _ in 0..opts.inner_steps {
181            // Pick random sample index via LCG
182            rng = rng
183                .wrapping_mul(6364136223846793005)
184                .wrapping_add(1442695040888963407);
185            let idx = (rng >> 33) as usize % m;
186            let s = &samples[idx];
187
188            let g_x = finite_diff_grad(f, &x.view(), &s.view(), opts.fd_step);
189            let g_tilde = finite_diff_grad(f, &x_tilde.view(), &s.view(), opts.fd_step);
190            total_evals += 2 * (n + 1);
191
192            // SVRG direction: v = g(x) - g(x̃) + μ̃
193            for i in 0..n {
194                x[i] -= opts.step_size * (g_x[i] - g_tilde[i] + mu_tilde[i]);
195            }
196        }
197    }
198
199    let grad_norm = full_grad(f, &x.view(), samples, opts.fd_step)
200        .iter()
201        .map(|v| v * v)
202        .sum::<f64>()
203        .sqrt();
204
205    Ok(SvrgResult {
206        x,
207        grad_norm,
208        n_grad_evals: total_evals,
209        converged,
210    })
211}
212
213// ─── SARAH ───────────────────────────────────────────────────────────────────
214
/// Tuning parameters for [`sarah`].
#[derive(Debug, Clone)]
pub struct SarahOptions {
    /// Outer iterations; each one restarts the estimator from a full gradient.
    pub n_outer: usize,
    /// Length `m` of the recursive inner loop.
    pub inner_steps: usize,
    /// Step size η used for every update.
    pub step_size: f64,
    /// Stop as soon as an outer full-gradient norm drops below this.
    pub tol: f64,
    /// Perturbation `h` used by the forward-difference gradient approximation.
    pub fd_step: f64,
}

impl Default for SarahOptions {
    /// 50 outer iterations of 50 inner steps, η = 1e-3, tol = 1e-6, h = 1e-5.
    fn default() -> Self {
        SarahOptions {
            fd_step: 1e-5,
            tol: 1e-6,
            step_size: 1e-3,
            inner_steps: 50,
            n_outer: 50,
        }
    }
}
241
242/// Result from SARAH optimisation.
243#[derive(Debug, Clone)]
244pub struct SarahResult {
245    /// Approximate minimiser.
246    pub x: Array1<f64>,
247    /// Final full-gradient norm.
248    pub grad_norm: f64,
249    /// Total gradient evaluations (approximate).
250    pub n_grad_evals: usize,
251    /// Whether tolerance was met.
252    pub converged: bool,
253}
254
255/// StochAstic Recursive grAdient algoritHm (SARAH).
256///
257/// SARAH maintains a recursive gradient estimator:
258///
259///   v₀ = ∇f(x₀)   (full gradient)
260///   vₜ = ∇fᵢₜ(xₜ) - ∇fᵢₜ(xₜ₋₁) + vₜ₋₁   (recursive update)
261///   xₜ₊₁ = xₜ - η vₜ
262///
263/// This recursive estimator converges to the true gradient, enabling linear
264/// convergence on strongly-convex problems.
265///
266/// # Arguments
267///
268/// * `f`       – per-sample loss: (x, sample) → f64
269/// * `x0`      – starting point
270/// * `samples` – full dataset
271/// * `opts`    – SARAH options
272pub fn sarah<F>(
273    f: &mut F,
274    x0: &ArrayView1<f64>,
275    samples: &[Array1<f64>],
276    opts: &SarahOptions,
277) -> OptimizeResult<SarahResult>
278where
279    F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
280{
281    let n = x0.len();
282    if n == 0 {
283        return Err(OptimizeError::ValueError(
284            "x0 must be non-empty".to_string(),
285        ));
286    }
287    if samples.is_empty() {
288        return Err(OptimizeError::ValueError(
289            "samples must be non-empty".to_string(),
290        ));
291    }
292
293    let m = samples.len();
294    let mut x = x0.to_owned();
295    let mut converged = false;
296    let mut total_evals: usize = 0;
297    let mut rng: u64 = 11111111111111111;
298
299    for _ in 0..opts.n_outer {
300        // v₀ = full gradient at current x
301        let mut v = full_grad(f, &x.view(), samples, opts.fd_step);
302        total_evals += m * (n + 1);
303
304        let g_norm = v.iter().map(|vi| vi * vi).sum::<f64>().sqrt();
305        if g_norm < opts.tol {
306            converged = true;
307            return Ok(SarahResult {
308                x,
309                grad_norm: g_norm,
310                n_grad_evals: total_evals,
311                converged,
312            });
313        }
314
315        // x₀ of inner loop
316        for i in 0..n {
317            x[i] -= opts.step_size * v[i];
318        }
319
320        let mut x_prev = x.clone();
321
322        for _ in 0..opts.inner_steps {
323            rng = rng
324                .wrapping_mul(6364136223846793005)
325                .wrapping_add(1442695040888963407);
326            let idx = (rng >> 33) as usize % m;
327            let s = &samples[idx];
328
329            let g_curr = finite_diff_grad(f, &x.view(), &s.view(), opts.fd_step);
330            let g_prev = finite_diff_grad(f, &x_prev.view(), &s.view(), opts.fd_step);
331            total_evals += 2 * (n + 1);
332
333            // Recursive update: vₜ = ∇fᵢ(x) - ∇fᵢ(x_prev) + v_{t-1}
334            let v_new: Array1<f64> = g_curr
335                .iter()
336                .zip(g_prev.iter())
337                .zip(v.iter())
338                .map(|((&gc, &gp), &vp)| gc - gp + vp)
339                .collect();
340
341            x_prev = x.clone();
342            for i in 0..n {
343                x[i] -= opts.step_size * v_new[i];
344            }
345            v = v_new;
346        }
347    }
348
349    let g_norm = full_grad(f, &x.view(), samples, opts.fd_step)
350        .iter()
351        .map(|v| v * v)
352        .sum::<f64>()
353        .sqrt();
354
355    Ok(SarahResult {
356        x,
357        grad_norm: g_norm,
358        n_grad_evals: total_evals,
359        converged,
360    })
361}
362
363// ─── SPIDER ──────────────────────────────────────────────────────────────────
364
/// Tuning parameters for [`spider`].
#[derive(Debug, Clone)]
pub struct SpiderOptions {
    /// Outer iterations; each one recomputes the full gradient.
    pub n_outer: usize,
    /// Recursive estimator updates per outer iteration.
    pub inner_steps: usize,
    /// Step size η used for every update.
    pub step_size: f64,
    /// Stop once the gradient (or gradient estimator) norm drops below this.
    pub tol: f64,
    /// Perturbation `h` used by the forward-difference gradient approximation.
    pub fd_step: f64,
    /// Mini-batch size `b` for the inner gradient differences; [`spider`]
    /// clamps it to the range `1..=samples.len()`.
    pub mini_batch: usize,
}

impl Default for SpiderOptions {
    /// 30 outer iterations of 50 inner steps, η = 5e-4, tol = 1e-6,
    /// h = 1e-5, mini-batches of 4.
    fn default() -> Self {
        SpiderOptions {
            mini_batch: 4,
            fd_step: 1e-5,
            tol: 1e-6,
            step_size: 5e-4,
            inner_steps: 50,
            n_outer: 30,
        }
    }
}
394
395/// Result from SPIDER optimisation.
396#[derive(Debug, Clone)]
397pub struct SpiderResult {
398    /// Approximate minimiser.
399    pub x: Array1<f64>,
400    /// Norm of the last gradient estimator.
401    pub grad_norm: f64,
402    /// Total gradient evaluations (approximate).
403    pub n_grad_evals: usize,
404    /// Whether tolerance was met.
405    pub converged: bool,
406}
407
408/// SPIDER (Stochastic Path-Integrated Differential EstimatoR) optimizer.
409///
410/// SPIDER extends SARAH with mini-batch gradient differences, achieving
411/// near-optimal oracle complexity for non-convex stochastic optimisation.
412///
413/// The estimator update:
414///   vₜ = (1/b) Σᵢ∈Bₜ [∇fᵢ(xₜ) - ∇fᵢ(xₜ₋₁)] + vₜ₋₁
415///
416/// # Arguments
417///
418/// * `f`       – per-sample loss: (x, sample) → f64
419/// * `x0`      – starting point
420/// * `samples` – full dataset
421/// * `opts`    – SPIDER options
422pub fn spider<F>(
423    f: &mut F,
424    x0: &ArrayView1<f64>,
425    samples: &[Array1<f64>],
426    opts: &SpiderOptions,
427) -> OptimizeResult<SpiderResult>
428where
429    F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
430{
431    let n = x0.len();
432    if n == 0 {
433        return Err(OptimizeError::ValueError(
434            "x0 must be non-empty".to_string(),
435        ));
436    }
437    if samples.is_empty() {
438        return Err(OptimizeError::ValueError(
439            "samples must be non-empty".to_string(),
440        ));
441    }
442
443    let m = samples.len();
444    let b = opts.mini_batch.max(1).min(m);
445    let mut x = x0.to_owned();
446    let mut converged = false;
447    let mut total_evals: usize = 0;
448    let mut rng: u64 = 999999999999;
449
450    for _ in 0..opts.n_outer {
451        // Full gradient at start of outer epoch
452        let mut v = full_grad(f, &x.view(), samples, opts.fd_step);
453        total_evals += m * (n + 1);
454
455        let g_norm = v.iter().map(|vi| vi * vi).sum::<f64>().sqrt();
456        if g_norm < opts.tol {
457            converged = true;
458            return Ok(SpiderResult {
459                x,
460                grad_norm: g_norm,
461                n_grad_evals: total_evals,
462                converged,
463            });
464        }
465
466        // Descend with v₀
467        for i in 0..n {
468            x[i] -= opts.step_size * v[i];
469        }
470
471        let mut x_prev = x.clone();
472
473        for _ in 0..opts.inner_steps {
474            // Sample mini-batch B of size b
475            let mut batch_indices = Vec::with_capacity(b);
476            for _ in 0..b {
477                rng = rng
478                    .wrapping_mul(6364136223846793005)
479                    .wrapping_add(1442695040888963407);
480                batch_indices.push((rng >> 33) as usize % m);
481            }
482
483            // Mini-batch gradient difference
484            let mut diff = Array1::<f64>::zeros(n);
485            for &idx in &batch_indices {
486                let s = &samples[idx];
487                let g_curr = finite_diff_grad(f, &x.view(), &s.view(), opts.fd_step);
488                let g_prev = finite_diff_grad(f, &x_prev.view(), &s.view(), opts.fd_step);
489                total_evals += 2 * (n + 1);
490                for i in 0..n {
491                    diff[i] += (g_curr[i] - g_prev[i]) / b as f64;
492                }
493            }
494
495            // Recursive estimator update
496            let v_new: Array1<f64> = diff.iter().zip(v.iter()).map(|(&d, &vp)| d + vp).collect();
497            x_prev = x.clone();
498            for i in 0..n {
499                x[i] -= opts.step_size * v_new[i];
500            }
501            v = v_new;
502
503            let cur_norm = v.iter().map(|vi| vi * vi).sum::<f64>().sqrt();
504            if cur_norm < opts.tol {
505                converged = true;
506                return Ok(SpiderResult {
507                    x,
508                    grad_norm: cur_norm,
509                    n_grad_evals: total_evals,
510                    converged,
511                });
512            }
513        }
514    }
515
516    let g_norm = v_norm_approx(&full_grad(f, &x.view(), samples, opts.fd_step));
517
518    Ok(SpiderResult {
519        x,
520        grad_norm: g_norm,
521        n_grad_evals: total_evals,
522        converged,
523    })
524}
525
526#[inline]
527fn v_norm_approx(v: &Array1<f64>) -> f64 {
528    v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
529}
530
531// ─── Tests ───────────────────────────────────────────────────────────────────
532
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Eight 2-D samples for the per-sample loss f(x, ξ) = (x₀ − ξ₀)² + (x₁ − ξ₁)².
    /// Their mean, and hence the minimiser of the average loss, is (1, 2).
    fn make_samples() -> Vec<Array1<f64>> {
        let coords = [
            (0.9, 1.8),
            (1.1, 2.2),
            (1.0, 2.0),
            (0.8, 1.9),
            (1.2, 2.1),
            (1.0, 2.0),
            (0.95, 1.95),
            (1.05, 2.05),
        ];
        coords.iter().map(|&(a, b)| array![a, b]).collect()
    }

    fn sample_loss(x: &ArrayView1<f64>, s: &ArrayView1<f64>) -> f64 {
        let dx = x[0] - s[0];
        let dy = x[1] - s[1];
        dx.powi(2) + dy.powi(2)
    }

    /// Asserts each coordinate of `x` lies within `tol` of the minimiser (1, 2).
    fn assert_near_minimiser(label: &str, x: &Array1<f64>, tol: f64) {
        for (idx, want) in [1.0_f64, 2.0].iter().copied().enumerate() {
            let got = x[idx];
            assert!(
                (got - want).abs() < tol,
                "{}: expected x[{}]≈{:.1}, got {}",
                label,
                idx,
                want,
                got
            );
        }
    }

    #[test]
    fn test_svrg_quadratic() {
        let samples = make_samples();
        let start = array![0.0, 0.0];
        let opts = SvrgOptions {
            n_epochs: 100,
            inner_steps: 50,
            step_size: 0.1,
            tol: 1e-4,
            fd_step: 1e-5,
        };
        let res = svrg(&mut |x, s| sample_loss(x, s), &start.view(), &samples, &opts)
            .expect("failed to create res");
        assert_near_minimiser("SVRG", &res.x, 0.3);
    }

    #[test]
    fn test_sarah_quadratic() {
        let samples = make_samples();
        let start = array![0.0, 0.0];
        let opts = SarahOptions {
            n_outer: 80,
            inner_steps: 30,
            step_size: 0.05,
            tol: 1e-4,
            fd_step: 1e-5,
        };
        let res = sarah(&mut |x, s| sample_loss(x, s), &start.view(), &samples, &opts)
            .expect("failed to create res");
        assert_near_minimiser("SARAH", &res.x, 0.3);
    }

    #[test]
    fn test_spider_quadratic() {
        let samples = make_samples();
        let start = array![0.0, 0.0];
        let opts = SpiderOptions {
            n_outer: 80,
            inner_steps: 30,
            step_size: 0.05,
            tol: 1e-4,
            fd_step: 1e-5,
            mini_batch: 2,
        };
        let res = spider(&mut |x, s| sample_loss(x, s), &start.view(), &samples, &opts)
            .expect("failed to create res");
        assert_near_minimiser("SPIDER", &res.x, 0.4);
    }
}