// scirs2_optimize/second_order/slbfgs.rs
//! Stochastic L-BFGS (S-L-BFGS) with SVRG-style variance reduction.
//!
//! Implements the algorithm of Moritz, Nishihara & Jordan (2016) combining:
//! - Mini-batch stochastic gradients
//! - L-BFGS curvature estimates from larger batches
//! - SVRG variance reduction to stabilize curvature pairs
//!
//! ## References
//!
//! - Moritz, P., Nishihara, R., Jordan, M.I. (2016).
//!   "A linearly-convergent stochastic L-BFGS algorithm." AISTATS.
//! - Johnson, R., Zhang, T. (2013).
//!   "Accelerating stochastic gradient descent using predictive variance reduction."
//!   NIPS 26.

use super::lbfgsb::hv_product;
use super::types::{OptResult, SlbfgsConfig};
use crate::error::OptimizeError;
// ─── LCG pseudo-random number generator ──────────────────────────────────────

/// Linear congruential generator for mini-batch selection.
///
/// Uses parameters from Numerical Recipes: a=1664525, c=1013904223, m=2^32.
/// Deterministic and seedable, so batch selection is reproducible across runs.
pub struct Lcg {
    /// Full 64-bit state; only the low 32 bits are significant (masked on step).
    state: u64,
}

impl Lcg {
    /// Construct a generator seeded with `seed` (stored as-is; masking happens
    /// on the first step).
    pub fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }

    /// Step the recurrence state ← (a·state + c) mod 2^32 and return it.
    pub fn next_u32(&mut self) -> u32 {
        const A: u64 = 1_664_525;
        const C: u64 = 1_013_904_223;
        const MASK: u64 = 0xFFFF_FFFF;
        self.state = self.state.wrapping_mul(A).wrapping_add(C) & MASK;
        self.state as u32
    }

    /// Draw an index in [0, n) via modulo reduction.
    ///
    /// Panics if `n == 0` (remainder by zero); callers only request indices
    /// from non-empty ranges.
    pub fn next_usize(&mut self, n: usize) -> usize {
        let raw = self.next_u32() as usize;
        raw % n
    }

    /// Fill `buf` with `min(k, n)` distinct indices from [0, n), in random order.
    ///
    /// Implemented as a partial Fisher-Yates shuffle over a scratch index
    /// array: only the first `min(k, n)` positions are settled, at the cost of
    /// O(n) pool construction plus O(k) swaps.
    pub fn sample_without_replacement(&mut self, n: usize, k: usize, buf: &mut Vec<usize>) {
        buf.clear();
        if n == 0 || k == 0 {
            return;
        }
        let take = k.min(n);
        let mut pool: Vec<usize> = (0..n).collect();
        for slot in 0..take {
            let pick = slot + self.next_usize(n - slot);
            pool.swap(slot, pick);
            buf.push(pool[slot]);
        }
    }
}
// ─── SVRG variance reduction ─────────────────────────────────────────────────

/// Compute the SVRG-corrected stochastic gradient.
///
/// g_corrected = ∇f_{i}(x_k) − ∇f_{i}(x̃) + ∇F(x̃)
///
/// where x̃ is the snapshot point and ∇F(x̃) is the full (snapshot) gradient.
/// The correction is an unbiased estimate of ∇F(x_k); its variance shrinks as
/// x_k approaches x̃, which is what stabilizes the downstream curvature pairs.
fn svrg_gradient(
    stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
    x_k: &[f64],
    x_snap: &[f64],
    g_snap: &[f64], // ∇F(x̃) — full gradient at snapshot
    batch: &[usize],
) -> Vec<f64> {
    // Evaluate the same mini-batch at the current point and at the snapshot.
    let (_, grad_here) = stoch_f_and_g(x_k, batch);
    let (_, grad_at_snap) = stoch_f_and_g(x_snap, batch);
    let mut corrected = Vec::with_capacity(grad_here.len());
    // zip truncates to the shortest input, matching the original semantics.
    for ((&gh, &gs), &gf) in grad_here.iter().zip(grad_at_snap.iter()).zip(g_snap.iter()) {
        corrected.push(gh - gs + gf);
    }
    corrected
}
// ─── Curvature pair computation ───────────────────────────────────────────────

/// Compute the curvature difference vector y_k on a (large) mini-batch.
///
/// y_k = ∇f_B(x_{k+1}) − ∇f_B(x_k)
///
/// where ∇f_B is the batch gradient returned by the stochastic oracle. Using a
/// larger batch than the step-direction estimate keeps noise out of the
/// quasi-Newton pairs, which is crucial for curvature stability.
fn curvature_y(
    stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
    x_new: &[f64],
    x_old: &[f64],
    batch: &[usize],
) -> Vec<f64> {
    let dim = x_new.len();
    // An empty batch carries no curvature information: return the zero vector.
    if batch.is_empty() {
        return vec![0.0; dim];
    }
    let (_, grad_new) = stoch_f_and_g(x_new, batch);
    let (_, grad_old) = stoch_f_and_g(x_old, batch);
    let mut y = Vec::with_capacity(grad_new.len());
    for (&gn, &go) in grad_new.iter().zip(grad_old.iter()) {
        y.push(gn - go);
    }
    y
}
// ─── S-L-BFGS optimizer ───────────────────────────────────────────────────────

/// Stochastic L-BFGS optimizer with optional SVRG variance reduction.
///
/// Holds only configuration; all optimization state is local to `minimize`,
/// so a single instance can be reused across independent runs.
pub struct SlbfgsOptimizer {
    /// Algorithm configuration.
    pub config: SlbfgsConfig,
}
128impl SlbfgsOptimizer {
129    /// Create with given configuration.
130    pub fn new(config: SlbfgsConfig) -> Self {
131        Self { config }
132    }
133
134    /// Create with default configuration.
135    pub fn default_config() -> Self {
136        Self {
137            config: SlbfgsConfig::default(),
138        }
139    }
140
141    /// Minimize a stochastic objective function using S-L-BFGS.
142    ///
143    /// # Arguments
144    /// * `stoch_f_and_g` — stochastic oracle: given `x` and a batch of indices
145    ///   (subset of `0..n_samples`), returns (f, ∇f) on that mini-batch.
146    /// * `full_grad_fn` — deterministic oracle for full gradient at snapshot:
147    ///   `(x) → (f, ∇f)` over the full dataset.
148    /// * `n_samples` — total number of data points.
149    /// * `x0` — initial point.
150    ///
151    /// # Returns
152    /// An `OptResult` describing the found minimizer and convergence status.
153    pub fn minimize(
154        &self,
155        stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
156        full_grad_fn: &dyn Fn(&[f64]) -> (f64, Vec<f64>),
157        n_samples: usize,
158        x0: &[f64],
159    ) -> Result<OptResult, OptimizeError> {
160        let n = x0.len();
161        let cfg = &self.config;
162        let m = cfg.m;
163
164        if n_samples == 0 {
165            return Err(OptimizeError::ValueError(
166                "n_samples must be positive".to_string(),
167            ));
168        }
169
170        let mut x = x0.to_vec();
171        let mut rng = Lcg::new(cfg.seed);
172
173        // L-BFGS circular buffer
174        let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
175        let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
176        let mut rho_hist: Vec<f64> = Vec::with_capacity(m);
177        let mut gamma = 1.0_f64;
178
179        // Snapshot for SVRG
180        let mut x_snap = x.clone();
181        let (mut f_snap, mut g_snap) = full_grad_fn(&x_snap);
182
183        let mut n_iter = 0usize;
184        let mut converged = false;
185        let mut batch_buf: Vec<usize> = Vec::with_capacity(cfg.batch_size);
186        let mut curv_batch_buf: Vec<usize> = Vec::with_capacity(cfg.curvature_batch_size);
187
188        // Track best solution seen
189        let mut best_x = x.clone();
190        let mut best_f = f_snap;
191
192        for iter in 0..cfg.max_iter {
193            n_iter = iter;
194
195            // Re-snapshot for SVRG
196            if cfg.variance_reduction && iter % cfg.snapshot_freq == 0 {
197                let (fs, gs) = full_grad_fn(&x);
198                x_snap = x.clone();
199                f_snap = fs;
200                g_snap = gs;
201            }
202
203            // Convergence check (on snapshot gradient)
204            let gn = g_snap.iter().map(|g| g * g).sum::<f64>().sqrt();
205            if gn < cfg.tol {
206                converged = true;
207                break;
208            }
209
210            // Sample mini-batch for gradient
211            rng.sample_without_replacement(n_samples, cfg.batch_size, &mut batch_buf);
212
213            // Compute stochastic gradient (SVRG-corrected or plain)
214            let g_k = if cfg.variance_reduction {
215                svrg_gradient(stoch_f_and_g, &x, &x_snap, &g_snap, &batch_buf)
216            } else {
217                let (_, gk) = stoch_f_and_g(&x, &batch_buf);
218                gk
219            };
220
221            // Compute L-BFGS search direction
222            let hg = hv_product(&g_k, &s_hist, &y_hist, &rho_hist, gamma);
223            let d: Vec<f64> = hg.iter().map(|v| -v).collect();
224
225            // Check descent
226            let slope: f64 = g_k.iter().zip(d.iter()).map(|(gi, di)| gi * di).sum();
227            let d = if slope >= 0.0 {
228                g_k.iter().map(|gi| -gi).collect::<Vec<f64>>()
229            } else {
230                d
231            };
232
233            // Step
234            let x_new: Vec<f64> = x
235                .iter()
236                .zip(d.iter())
237                .map(|(xi, di)| xi + cfg.lr * di)
238                .collect();
239
240            // Compute curvature pair with a larger batch
241            rng.sample_without_replacement(
242                n_samples,
243                cfg.curvature_batch_size,
244                &mut curv_batch_buf,
245            );
246            let s_k: Vec<f64> = (0..n).map(|i| x_new[i] - x[i]).collect();
247            let y_k = curvature_y(stoch_f_and_g, &x_new, &x, &curv_batch_buf);
248
249            // Curvature condition: y^T s > 0
250            let sy: f64 = s_k.iter().zip(y_k.iter()).map(|(si, yi)| si * yi).sum();
251            if sy > 1e-14 * s_k.iter().map(|si| si * si).sum::<f64>().sqrt() {
252                if s_hist.len() == m {
253                    s_hist.remove(0);
254                    y_hist.remove(0);
255                    rho_hist.remove(0);
256                }
257                let yy: f64 = y_k.iter().map(|yi| yi * yi).sum();
258                if yy > 1e-14 {
259                    gamma = sy / yy;
260                }
261                rho_hist.push(1.0 / sy);
262                s_hist.push(s_k);
263                y_hist.push(y_k);
264            }
265
266            x = x_new;
267
268            // Update best
269            let (f_curr, _) = full_grad_fn(&x);
270            if f_curr < best_f {
271                best_f = f_curr;
272                best_x = x.clone();
273            }
274        }
275
276        let (_, g_final) = full_grad_fn(&best_x);
277        let grad_norm = g_final.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
278
279        Ok(OptResult {
280            x: best_x,
281            f_val: best_f,
282            grad_norm,
283            n_iter,
284            converged,
285        })
286    }
287}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::second_order::types::SlbfgsConfig;

    /// Simple separable stochastic problem: f(x) = sum_i (x_i - 1)^2 / n.
    /// Sample i contributes (x_i - 1)^2; an empty batch means "full dataset".
    fn stoch_quad(x: &[f64], batch: &[usize]) -> (f64, Vec<f64>) {
        let n = x.len();
        if batch.is_empty() {
            let f = x.iter().map(|xi| (xi - 1.0).powi(2)).sum::<f64>() / n as f64;
            let g = x.iter().map(|xi| 2.0 * (xi - 1.0) / n as f64).collect();
            return (f, g);
        }
        let bs = batch.len() as f64;
        let f = batch.iter().map(|&i| (x[i % n] - 1.0).powi(2)).sum::<f64>() / bs;
        let mut g = vec![0.0_f64; n];
        for &idx in batch {
            let j = idx % n;
            g[j] += 2.0 * (x[j] - 1.0) / bs;
        }
        (f, g)
    }

    /// Full-dataset oracle built on `stoch_quad` with the batch 0..n.
    fn full_quad(x: &[f64]) -> (f64, Vec<f64>) {
        let everything: Vec<usize> = (0..x.len()).collect();
        stoch_quad(x, &everything)
    }

    #[test]
    fn test_slbfgs_gradient_variance_reduction() {
        // At the optimum with snapshot x̃ = x*, the SVRG-corrected gradient
        // collapses to the (zero) full gradient.
        let x_star = vec![1.0; 4];
        let x_snap = vec![1.0; 4];
        let (_, g_snap) = full_quad(&x_snap);
        let batch: Vec<usize> = (0..4).collect();
        let g_corr = svrg_gradient(&stoch_quad, &x_star, &x_snap, &g_snap, &batch);
        for component in &g_corr {
            assert!(
                component.abs() < 1e-12,
                "Corrected gradient should be zero at optimum: got {}",
                component
            );
        }
    }

    #[test]
    fn test_slbfgs_curvature_condition() {
        // A strongly convex quadratic must satisfy y^T s > 0 for any step.
        let x_old = vec![2.0, 3.0];
        let x_new = vec![1.5, 2.5];
        let all_batch: Vec<usize> = (0..2).collect();
        let y = curvature_y(&stoch_quad, &x_new, &x_old, &all_batch);
        let sy: f64 = x_new
            .iter()
            .zip(x_old.iter())
            .map(|(xn, xo)| xn - xo)
            .zip(y.iter())
            .map(|(si, yi)| si * yi)
            .sum();
        assert!(sy > 0.0, "Curvature condition y^T s > 0 violated: {}", sy);
    }

    #[test]
    fn test_slbfgs_stochastic_convergence() {
        let cfg = SlbfgsConfig {
            max_iter: 300,
            lr: 0.05,
            batch_size: 4,
            curvature_batch_size: 8,
            variance_reduction: true,
            tol: 1e-4,
            ..SlbfgsConfig::default()
        };
        let opt = SlbfgsOptimizer::new(cfg);

        let x0 = vec![0.0_f64; 4];
        let result = opt
            .minimize(&stoch_quad, &full_quad, 4, &x0)
            .expect("S-L-BFGS failed");
        for xi in &result.x {
            assert!(
                (xi - 1.0).abs() < 0.2,
                "S-L-BFGS did not converge: x={:?}",
                result.x
            );
        }
    }

    #[test]
    fn test_second_order_config_default() {
        use crate::second_order::types::{LbfgsBConfig, SlbfgsConfig, Sr1Config};
        // Defaults for every second-order config should construct without error.
        let _lbfgsb = LbfgsBConfig::default();
        let _sr1 = Sr1Config::default();
        let _slbfgs = SlbfgsConfig::default();
    }

    #[test]
    fn test_slbfgs_batch_selection() {
        let mut rng = Lcg::new(12345);
        let mut batch = Vec::new();
        rng.sample_without_replacement(100, 10, &mut batch);
        // Exactly k indices are drawn.
        assert_eq!(batch.len(), 10);
        // All indices are distinct.
        let mut deduped = batch.clone();
        deduped.sort_unstable();
        deduped.dedup();
        assert_eq!(deduped.len(), 10, "Duplicate indices in batch selection");
        // All indices fall within [0, 100).
        for &idx in &batch {
            assert!(idx < 100, "Index out of bounds: {}", idx);
        }
    }
}
401}