Skip to main content

scirs2_optimize/stochastic/
approximation.rs

1//! Classical stochastic approximation algorithms.
2//!
//! This submodule implements the two classical stochastic approximation (SA)
//! methods from the 1950s together with the modern SPSA method:
5//!
6//! | Algorithm | Reference | Use case |
7//! |-----------|-----------|----------|
8//! | Robbins-Monro | Robbins & Monro (1951) | Root-finding under noise |
9//! | Kiefer-Wolfowitz | Kiefer & Wolfowitz (1952) | Gradient-free stochastic minimisation |
10//! | SPSA | Spall (1992) | High-dimensional gradient-free SA |
11//!
12//! # Notation
13//!
14//! - xₖ  : current iterate
15//! - aₖ  : gain sequence for update step (must satisfy Σ aₖ = ∞, Σ aₖ² < ∞)
16//! - cₖ  : gain sequence for finite-difference width (must → 0)
17
18use crate::error::{OptimizeError, OptimizeResult};
19use scirs2_core::ndarray::{Array1, ArrayView1};
20
21// ─── Robbins-Monro ───────────────────────────────────────────────────────────
22
23/// Result from Robbins-Monro root finding.
24#[derive(Debug, Clone)]
25pub struct RobbinsMonroResult {
26    /// Approximate root: θ* such that M(θ*) ≈ 0.
27    pub x: Array1<f64>,
28    /// Final residual ‖M(xₖ)‖.
29    pub residual: f64,
30    /// Number of iterations.
31    pub n_iter: usize,
32    /// Whether tolerance was met.
33    pub converged: bool,
34}
35
/// Tuning parameters for the Robbins-Monro algorithm.
///
/// The gain used at iteration k is aₖ = `a` / k^`alpha`.
#[derive(Debug, Clone)]
pub struct RobbinsMonroOptions {
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// The run stops once the Euclidean norm of an update step drops below this value.
    pub tol: f64,
    /// Decay exponent α of the gain sequence (α = 1.0 is the classical choice).
    pub alpha: f64,
    /// Multiplicative scale a of the gain sequence.
    pub a: f64,
}

impl Default for RobbinsMonroOptions {
    /// `max_iter = 10_000`, `tol = 1e-6`, `alpha = 1.0`, `a = 1.0`.
    fn default() -> Self {
        let max_iter = 10_000;
        let tol = 1e-6;
        let (alpha, a) = (1.0, 1.0);
        Self {
            max_iter,
            tol,
            alpha,
            a,
        }
    }
}
59
60/// Robbins-Monro stochastic root-finding algorithm.
61///
62/// Finds θ* such that M(θ*) = 0, where M is a noisy mapping (possibly
63/// the gradient of an expected loss).  The update rule is:
64///
65/// θₖ₊₁ = θₖ - aₖ · M(θₖ)
66///
67/// with gain aₖ = a / k^α.
68///
69/// # Arguments
70///
71/// * `m`    – noisy mapping M: θ → ℝⁿ (should satisfy E[M(θ)] ≈ ∇L(θ))
72/// * `x0`   – initial point
73/// * `opts` – algorithm options
74pub fn robbins_monro<M>(
75    m: &mut M,
76    x0: &ArrayView1<f64>,
77    opts: &RobbinsMonroOptions,
78) -> OptimizeResult<RobbinsMonroResult>
79where
80    M: FnMut(&ArrayView1<f64>) -> Array1<f64>,
81{
82    let n = x0.len();
83    if n == 0 {
84        return Err(OptimizeError::ValueError(
85            "x0 must be non-empty".to_string(),
86        ));
87    }
88
89    let mut x = x0.to_owned();
90    let mut converged = false;
91    let mut residual = f64::INFINITY;
92
93    for k in 1..=opts.max_iter {
94        let mk = m(&x.view());
95        if mk.len() != n {
96            return Err(OptimizeError::ValueError(format!(
97                "M returned length {} but x has length {}",
98                mk.len(),
99                n
100            )));
101        }
102        let ak = opts.a / (k as f64).powf(opts.alpha);
103        let mut step_norm = 0.0_f64;
104        for i in 0..n {
105            let step = ak * mk[i];
106            x[i] -= step;
107            step_norm += step * step;
108        }
109        residual = step_norm.sqrt();
110        if residual < opts.tol {
111            converged = true;
112            residual = mk.iter().map(|v| v * v).sum::<f64>().sqrt();
113            return Ok(RobbinsMonroResult {
114                x,
115                residual,
116                n_iter: k,
117                converged,
118            });
119        }
120    }
121
122    // Final residual
123    let mk_final = m(&x.view());
124    residual = mk_final.iter().map(|v| v * v).sum::<f64>().sqrt();
125
126    Ok(RobbinsMonroResult {
127        x,
128        residual,
129        n_iter: opts.max_iter,
130        converged,
131    })
132}
133
134// ─── Kiefer-Wolfowitz ────────────────────────────────────────────────────────
135
136/// Result from the Kiefer-Wolfowitz algorithm.
137#[derive(Debug, Clone)]
138pub struct KieferWolfowitzResult {
139    /// Approximate minimiser.
140    pub x: Array1<f64>,
141    /// Function value at x.
142    pub fun: f64,
143    /// Number of iterations.
144    pub n_iter: usize,
145    /// Whether tolerance was met.
146    pub converged: bool,
147}
148
/// Options for the Kiefer-Wolfowitz algorithm.
///
/// The gains follow the power laws aₖ = a / k^α (step size) and
/// cₖ = c / k^γ (finite-difference half-width). None of the fields are
/// validated by the solver.
///
/// Standard almost-sure convergence results for finite-difference stochastic
/// approximation (e.g. Spall 1992; Kushner & Yin) require aₖ → 0, cₖ → 0,
/// Σ aₖ = ∞ and Σ (aₖ / cₖ)² < ∞. For the power-law gains above this means
/// 0 < α ≤ 1, γ > 0 and α − γ > 1/2. The defaults (α = 0.602, γ = 0.101) are
/// Spall's commonly recommended practical values and satisfy these conditions.
#[derive(Debug, Clone)]
pub struct KieferWolfowitzOptions {
    /// Maximum iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the Euclidean norm of the update step.
    pub tol: f64,
    /// Exponent α in aₖ = a / k^α (see the conditions above).
    pub alpha: f64,
    /// Exponent γ in cₖ = c / k^γ (see the conditions above).
    pub gamma: f64,
    /// Scale constant a of the step-size sequence.
    pub a: f64,
    /// Scale constant c (initial finite-difference half-width). Must be
    /// non-zero, otherwise the central-difference quotient is 0 / 0.
    pub c: f64,
}

impl Default for KieferWolfowitzOptions {
    fn default() -> Self {
        Self {
            max_iter: 10_000,
            tol: 1e-6,
            alpha: 0.602,
            gamma: 0.101,
            a: 0.1,
            c: 0.1,
        }
    }
}
178
179/// Kiefer-Wolfowitz gradient-free stochastic approximation.
180///
181/// Minimises E[L(x, ξ)] using only noisy function evaluations (no gradients).
182/// The finite-difference gradient estimate in dimension i is:
183///
184/// ĝᵢ = (L(x + cₖ eᵢ) - L(x - cₖ eᵢ)) / (2 cₖ)
185///
186/// followed by the update x ← x - aₖ ĝ.
187///
188/// # Arguments
189///
190/// * `loss` – noisy loss function (x) → f64 (internally uses two evaluations per dimension per step)
191/// * `x0`   – initial point
192/// * `opts` – algorithm options
193pub fn kiefer_wolfowitz<L>(
194    loss: &mut L,
195    x0: &ArrayView1<f64>,
196    opts: &KieferWolfowitzOptions,
197) -> OptimizeResult<KieferWolfowitzResult>
198where
199    L: FnMut(&ArrayView1<f64>) -> f64,
200{
201    let n = x0.len();
202    if n == 0 {
203        return Err(OptimizeError::ValueError(
204            "x0 must be non-empty".to_string(),
205        ));
206    }
207
208    let mut x = x0.to_owned();
209    let mut converged = false;
210
211    for k in 1..=opts.max_iter {
212        let ak = opts.a / (k as f64).powf(opts.alpha);
213        let ck = opts.c / (k as f64).powf(opts.gamma);
214
215        // Finite-difference gradient
216        let mut grad = Array1::<f64>::zeros(n);
217        for i in 0..n {
218            let mut x_fwd = x.clone();
219            let mut x_bwd = x.clone();
220            x_fwd[i] += ck;
221            x_bwd[i] -= ck;
222            grad[i] = (loss(&x_fwd.view()) - loss(&x_bwd.view())) / (2.0 * ck);
223        }
224
225        let mut step_norm = 0.0_f64;
226        for i in 0..n {
227            let step = ak * grad[i];
228            x[i] -= step;
229            step_norm += step * step;
230        }
231
232        if step_norm.sqrt() < opts.tol {
233            converged = true;
234            let fun = loss(&x.view());
235            return Ok(KieferWolfowitzResult {
236                x,
237                fun,
238                n_iter: k,
239                converged,
240            });
241        }
242    }
243
244    let fun = loss(&x.view());
245    Ok(KieferWolfowitzResult {
246        x,
247        fun,
248        n_iter: opts.max_iter,
249        converged,
250    })
251}
252
253// ─── SPSA ────────────────────────────────────────────────────────────────────
254
/// Tuning parameters for the SPSA optimizer.
///
/// Gains: aₖ = a / (A + k)^α for the step size and cₖ = c / k^γ for the
/// perturbation size.
#[derive(Debug, Clone)]
pub struct SpsaOptions {
    /// Iteration budget.
    pub max_iter: usize,
    /// Tolerance on the Euclidean norm of an update step.
    pub tol: f64,
    /// Step-size decay exponent α in aₖ = a / (A + k)^α.
    pub alpha: f64,
    /// Perturbation decay exponent γ in cₖ = c / k^γ.
    pub gamma: f64,
    /// Step-size scale a.
    pub a: f64,
    /// Stability offset A in the step-size denominator (often around 10% of `max_iter`).
    pub big_a: f64,
    /// Perturbation (finite-difference) scale c.
    pub c: f64,
}

impl Default for SpsaOptions {
    fn default() -> Self {
        let (alpha, gamma) = (0.602, 0.101);
        let (a, big_a, c) = (0.1, 100.0, 0.1);
        Self {
            max_iter: 5_000,
            tol: 1e-6,
            alpha,
            gamma,
            a,
            big_a,
            c,
        }
    }
}
287
288/// Result from the SPSA algorithm.
289#[derive(Debug, Clone)]
290pub struct SpsaResult {
291    /// Approximate minimiser.
292    pub x: Array1<f64>,
293    /// Function value at x.
294    pub fun: f64,
295    /// Number of iterations.
296    pub n_iter: usize,
297    /// Whether tolerance was met.
298    pub converged: bool,
299}
300
301/// Compute one SPSA gradient-estimate step.
302///
303/// The simultaneous perturbation direction Δ is sampled from {±1}ⁿ (Rademacher).
304/// The gradient estimate is:
305///
306/// ĝ(x) = [f(x + cₖ Δ) - f(x - cₖ Δ)] / (2 cₖ) * (1/Δᵢ)  component-wise.
307///
308/// Returns the updated x after one SPSA step.
309///
310/// # Arguments
311///
312/// * `f`     – noisy objective function
313/// * `x`     – current point (modified in place)
314/// * `k`     – current iteration number (1-indexed)
315/// * `opts`  – SPSA options
316/// * `rng`   – mutable u64 state for the Rademacher perturbation (LCG)
317pub fn spsa_step<F>(
318    f: &mut F,
319    x: &mut Array1<f64>,
320    k: usize,
321    opts: &SpsaOptions,
322    rng_state: &mut u64,
323) -> f64
324where
325    F: FnMut(&ArrayView1<f64>) -> f64,
326{
327    let n = x.len();
328    let ak = opts.a / (opts.big_a + k as f64).powf(opts.alpha);
329    let ck = opts.c / (k as f64).powf(opts.gamma);
330
331    // Draw Rademacher perturbation Δ ∈ {-1, +1}ⁿ using LCG
332    let mut delta = Array1::<f64>::zeros(n);
333    for i in 0..n {
334        *rng_state = rng_state
335            .wrapping_mul(6364136223846793005)
336            .wrapping_add(1442695040888963407);
337        delta[i] = if (*rng_state >> 63) == 0 { 1.0 } else { -1.0 };
338    }
339
340    // Two-sided function evaluations
341    let x_fwd: Array1<f64> = x
342        .iter()
343        .zip(delta.iter())
344        .map(|(&xi, &di)| xi + ck * di)
345        .collect();
346    let x_bwd: Array1<f64> = x
347        .iter()
348        .zip(delta.iter())
349        .map(|(&xi, &di)| xi - ck * di)
350        .collect();
351    let f_fwd = f(&x_fwd.view());
352    let f_bwd = f(&x_bwd.view());
353
354    let diff = (f_fwd - f_bwd) / (2.0 * ck);
355
356    // Update: x ← x - aₖ * ĝ  where ĝᵢ = diff / Δᵢ
357    let mut step_sq = 0.0_f64;
358    for i in 0..n {
359        let gi = diff / delta[i]; // Δᵢ ∈ {±1} so 1/Δᵢ = Δᵢ
360        let step = ak * gi;
361        x[i] -= step;
362        step_sq += step * step;
363    }
364    step_sq.sqrt()
365}
366
367/// Simultaneous Perturbation Stochastic Approximation (SPSA) optimizer.
368///
369/// SPSA uses only two function evaluations per iteration (regardless of dimension n),
370/// making it very efficient for high-dimensional black-box minimisation.
371///
372/// # Arguments
373///
374/// * `f`    – noisy objective (minimised)
375/// * `x0`   – starting point
376/// * `opts` – SPSA options
377pub fn spsa_minimize<F>(
378    f: &mut F,
379    x0: &ArrayView1<f64>,
380    opts: &SpsaOptions,
381) -> OptimizeResult<SpsaResult>
382where
383    F: FnMut(&ArrayView1<f64>) -> f64,
384{
385    if x0.is_empty() {
386        return Err(OptimizeError::ValueError(
387            "x0 must be non-empty".to_string(),
388        ));
389    }
390
391    let mut x = x0.to_owned();
392    let mut rng_state: u64 = 12345678901234567;
393    let mut converged = false;
394
395    for k in 1..=opts.max_iter {
396        let step_norm = spsa_step(f, &mut x, k, opts, &mut rng_state);
397        if step_norm < opts.tol {
398            converged = true;
399            let fun = f(&x.view());
400            return Ok(SpsaResult {
401                x,
402                fun,
403                n_iter: k,
404                converged,
405            });
406        }
407    }
408
409    let fun = f(&x.view());
410    Ok(SpsaResult {
411        x,
412        fun,
413        n_iter: opts.max_iter,
414        converged,
415    })
416}
417
418// ─── Tests ───────────────────────────────────────────────────────────────────
419
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_robbins_monro_linear() {
        // The mapping M(x) = x - 2 vanishes at x = 2.
        let root = 2.0;
        let mut m = |x: &ArrayView1<f64>| array![x[0] - 2.0];
        let start = array![0.0];
        let opts = RobbinsMonroOptions {
            a: 1.0,
            alpha: 1.0,
            max_iter: 50_000,
            tol: 1e-4,
        };
        let res = robbins_monro(&mut m, &start.view(), &opts).expect("failed to create res");
        let err = (res.x[0] - root).abs();
        assert!(err < 0.1, "expected x* ≈ 2.0, got {}", res.x[0]);
    }

    #[test]
    fn test_kiefer_wolfowitz_quadratic() {
        // L(x) = (x - 3)² is minimised at x = 3.
        let target = 3.0;
        let mut loss = |x: &ArrayView1<f64>| (x[0] - 3.0).powi(2);
        let start = array![0.0];
        let opts = KieferWolfowitzOptions {
            tol: 1e-5,
            max_iter: 20_000,
            ..Default::default()
        };
        let res =
            kiefer_wolfowitz(&mut loss, &start.view(), &opts).expect("failed to create res");
        let err = (res.x[0] - target).abs();
        assert!(err < 0.2, "expected x* ≈ 3.0, got {}", res.x[0]);
    }

    #[test]
    fn test_spsa_quadratic() {
        // f(x) = (x₀ - 1)² + (x₁ - 2)² is minimised at (1, 2).
        let mut f = |x: &ArrayView1<f64>| (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2);
        let start = array![0.0, 0.0];
        let opts = SpsaOptions {
            c: 0.2,
            big_a: 50.0,
            a: 0.5,
            tol: 1e-5,
            max_iter: 10_000,
            ..Default::default()
        };
        let res = spsa_minimize(&mut f, &start.view(), &opts).expect("failed to create res");
        let (err0, err1) = ((res.x[0] - 1.0).abs(), (res.x[1] - 2.0).abs());
        assert!(err0 < 0.3, "expected x[0] ≈ 1.0, got {}", res.x[0]);
        assert!(err1 < 0.3, "expected x[1] ≈ 2.0, got {}", res.x[1]);
    }
}