// scirs2_optimize/second_order/slbfgs.rs

use super::lbfgsb::hv_product;
use super::types::{OptResult, SlbfgsConfig};
use crate::error::OptimizeError;

/// Minimal linear congruential generator used for deterministic,
/// dependency-free mini-batch sampling.
///
/// Uses the classic 32-bit LCG constants a = 1_664_525,
/// c = 1_013_904_223 with modulus 2^32.
pub struct Lcg {
    // Generator state; only the low 32 bits are ever non-zero after
    // the first step.
    state: u64,
}

impl Lcg {
    /// Creates a generator from a fixed seed, so sequences are
    /// reproducible across runs.
    pub fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the recurrence and returns the next 32-bit value.
    pub fn next_u32(&mut self) -> u32 {
        let advanced = self
            .state
            .wrapping_mul(1_664_525)
            .wrapping_add(1_013_904_223);
        // Reduce modulo 2^32 so the recurrence stays a 32-bit LCG.
        self.state = advanced & 0xFFFF_FFFF;
        self.state as u32
    }

    /// Returns a pseudo-random index in `0..n`.
    ///
    /// NOTE(review): panics when `n == 0` (remainder by zero), and has
    /// the usual modulo bias for `n` not dividing 2^32 — acceptable for
    /// batch sampling but worth knowing.
    pub fn next_usize(&mut self, n: usize) -> usize {
        self.next_u32() as usize % n
    }

    /// Fills `buf` with `min(k, n)` distinct indices drawn uniformly
    /// from `0..n` using a partial Fisher-Yates shuffle.
    pub fn sample_without_replacement(&mut self, n: usize, k: usize, buf: &mut Vec<usize>) {
        buf.clear();
        if n == 0 || k == 0 {
            return;
        }
        let draws = k.min(n);
        let mut pool: Vec<usize> = (0..n).collect();
        for i in 0..draws {
            // Pick uniformly from the not-yet-chosen tail, move it to
            // the front, then record it.
            let pick = i + self.next_usize(n - i);
            pool.swap(i, pick);
            buf.push(pool[i]);
        }
    }
}
67
/// SVRG variance-reduced gradient estimate:
/// `g_B(x_k) - g_B(x_snap) + g_full(x_snap)`, where the two batch
/// gradients share the same mini-batch so their sampling noise cancels
/// in expectation.
fn svrg_gradient(
    stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
    x_k: &[f64],
    x_snap: &[f64],
    g_snap: &[f64],
    batch: &[usize],
) -> Vec<f64> {
    let (_, batch_grad_curr) = stoch_f_and_g(x_k, batch);
    let (_, batch_grad_snap) = stoch_f_and_g(x_snap, batch);
    let mut corrected = Vec::with_capacity(batch_grad_curr.len());
    for ((&gc, &gs), &gf) in batch_grad_curr
        .iter()
        .zip(batch_grad_snap.iter())
        .zip(g_snap.iter())
    {
        corrected.push(gc - gs + gf);
    }
    corrected
}
92
/// Stochastic curvature vector `y = ∇f_B(x_new) - ∇f_B(x_old)` for the
/// L-BFGS `(s, y)` history, evaluated on one shared mini-batch so both
/// gradients see identical samples. Returns a zero vector when the
/// batch is empty (no curvature information available).
fn curvature_y(
    stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
    x_new: &[f64],
    x_old: &[f64],
    batch: &[usize],
) -> Vec<f64> {
    let dim = x_new.len();
    if batch.is_empty() {
        return vec![0.0; dim];
    }
    let (_, grad_new) = stoch_f_and_g(x_new, batch);
    let (_, grad_old) = stoch_f_and_g(x_old, batch);
    let mut y = Vec::with_capacity(grad_new.len());
    for (&gn, &go) in grad_new.iter().zip(grad_old.iter()) {
        y.push(gn - go);
    }
    y
}
119
/// Stochastic L-BFGS (S-L-BFGS) optimizer with optional SVRG-style
/// variance reduction; all behavior is driven by `config`.
pub struct SlbfgsOptimizer {
    /// Algorithm hyperparameters (history size `m`, learning rate,
    /// batch sizes, snapshot frequency, tolerance, RNG seed, ...).
    pub config: SlbfgsConfig,
}
127
128impl SlbfgsOptimizer {
129 pub fn new(config: SlbfgsConfig) -> Self {
131 Self { config }
132 }
133
134 pub fn default_config() -> Self {
136 Self {
137 config: SlbfgsConfig::default(),
138 }
139 }
140
141 pub fn minimize(
154 &self,
155 stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
156 full_grad_fn: &dyn Fn(&[f64]) -> (f64, Vec<f64>),
157 n_samples: usize,
158 x0: &[f64],
159 ) -> Result<OptResult, OptimizeError> {
160 let n = x0.len();
161 let cfg = &self.config;
162 let m = cfg.m;
163
164 if n_samples == 0 {
165 return Err(OptimizeError::ValueError(
166 "n_samples must be positive".to_string(),
167 ));
168 }
169
170 let mut x = x0.to_vec();
171 let mut rng = Lcg::new(cfg.seed);
172
173 let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
175 let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
176 let mut rho_hist: Vec<f64> = Vec::with_capacity(m);
177 let mut gamma = 1.0_f64;
178
179 let mut x_snap = x.clone();
181 let (mut f_snap, mut g_snap) = full_grad_fn(&x_snap);
182
183 let mut n_iter = 0usize;
184 let mut converged = false;
185 let mut batch_buf: Vec<usize> = Vec::with_capacity(cfg.batch_size);
186 let mut curv_batch_buf: Vec<usize> = Vec::with_capacity(cfg.curvature_batch_size);
187
188 let mut best_x = x.clone();
190 let mut best_f = f_snap;
191
192 for iter in 0..cfg.max_iter {
193 n_iter = iter;
194
195 if cfg.variance_reduction && iter % cfg.snapshot_freq == 0 {
197 let (fs, gs) = full_grad_fn(&x);
198 x_snap = x.clone();
199 f_snap = fs;
200 g_snap = gs;
201 }
202
203 let gn = g_snap.iter().map(|g| g * g).sum::<f64>().sqrt();
205 if gn < cfg.tol {
206 converged = true;
207 break;
208 }
209
210 rng.sample_without_replacement(n_samples, cfg.batch_size, &mut batch_buf);
212
213 let g_k = if cfg.variance_reduction {
215 svrg_gradient(stoch_f_and_g, &x, &x_snap, &g_snap, &batch_buf)
216 } else {
217 let (_, gk) = stoch_f_and_g(&x, &batch_buf);
218 gk
219 };
220
221 let hg = hv_product(&g_k, &s_hist, &y_hist, &rho_hist, gamma);
223 let d: Vec<f64> = hg.iter().map(|v| -v).collect();
224
225 let slope: f64 = g_k.iter().zip(d.iter()).map(|(gi, di)| gi * di).sum();
227 let d = if slope >= 0.0 {
228 g_k.iter().map(|gi| -gi).collect::<Vec<f64>>()
229 } else {
230 d
231 };
232
233 let x_new: Vec<f64> = x
235 .iter()
236 .zip(d.iter())
237 .map(|(xi, di)| xi + cfg.lr * di)
238 .collect();
239
240 rng.sample_without_replacement(
242 n_samples,
243 cfg.curvature_batch_size,
244 &mut curv_batch_buf,
245 );
246 let s_k: Vec<f64> = (0..n).map(|i| x_new[i] - x[i]).collect();
247 let y_k = curvature_y(stoch_f_and_g, &x_new, &x, &curv_batch_buf);
248
249 let sy: f64 = s_k.iter().zip(y_k.iter()).map(|(si, yi)| si * yi).sum();
251 if sy > 1e-14 * s_k.iter().map(|si| si * si).sum::<f64>().sqrt() {
252 if s_hist.len() == m {
253 s_hist.remove(0);
254 y_hist.remove(0);
255 rho_hist.remove(0);
256 }
257 let yy: f64 = y_k.iter().map(|yi| yi * yi).sum();
258 if yy > 1e-14 {
259 gamma = sy / yy;
260 }
261 rho_hist.push(1.0 / sy);
262 s_hist.push(s_k);
263 y_hist.push(y_k);
264 }
265
266 x = x_new;
267
268 let (f_curr, _) = full_grad_fn(&x);
270 if f_curr < best_f {
271 best_f = f_curr;
272 best_x = x.clone();
273 }
274 }
275
276 let (_, g_final) = full_grad_fn(&best_x);
277 let grad_norm = g_final.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
278
279 Ok(OptResult {
280 x: best_x,
281 f_val: best_f,
282 grad_norm,
283 n_iter,
284 converged,
285 })
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::second_order::types::SlbfgsConfig;

    /// Mini-batch quadratic: f(x) = (1/|B|) * Σ_{i∈B} (x_{i mod n} - 1)^2
    /// with minimizer at x = (1, ..., 1). An empty batch means the
    /// full (averaged) sum over all coordinates.
    fn stoch_quad(x: &[f64], batch: &[usize]) -> (f64, Vec<f64>) {
        let n = x.len();
        if batch.is_empty() {
            let f: f64 = x.iter().map(|xi| (xi - 1.0).powi(2)).sum::<f64>() / n as f64;
            let g: Vec<f64> = x.iter().map(|xi| 2.0 * (xi - 1.0) / n as f64).collect();
            return (f, g);
        }
        let scale = batch.len() as f64;
        let f: f64 = batch
            .iter()
            .map(|&i| (x[i % n] - 1.0).powi(2))
            .sum::<f64>()
            / scale;
        let mut g = vec![0.0_f64; n];
        for &i in batch {
            g[i % n] += 2.0 * (x[i % n] - 1.0) / scale;
        }
        (f, g)
    }

    /// Full-sample objective/gradient: `stoch_quad` over every index.
    fn full_quad(x: &[f64]) -> (f64, Vec<f64>) {
        let every: Vec<usize> = (0..x.len()).collect();
        stoch_quad(x, &every)
    }

    #[test]
    fn test_slbfgs_gradient_variance_reduction() {
        // At the optimum with x_k == x_snap, the batch terms cancel and
        // the SVRG estimate equals the (zero) full gradient.
        let x_opt = vec![1.0; 4];
        let snapshot = vec![1.0; 4];
        let (_, snap_grad) = full_quad(&snapshot);
        let batch: Vec<usize> = (0..4).collect();
        let corrected = svrg_gradient(&stoch_quad, &x_opt, &snapshot, &snap_grad, &batch);
        for component in &corrected {
            assert!(
                component.abs() < 1e-12,
                "Corrected gradient should be zero at optimum: got {}",
                component
            );
        }
    }

    #[test]
    fn test_slbfgs_curvature_condition() {
        // Moving downhill on a convex quadratic must yield s^T y > 0.
        let prev = vec![2.0, 3.0];
        let next = vec![1.5, 2.5];
        let all_batch: Vec<usize> = (0..2).collect();
        let y = curvature_y(&stoch_quad, &next, &prev, &all_batch);
        let s: Vec<f64> = next.iter().zip(prev.iter()).map(|(a, b)| a - b).collect();
        let sy: f64 = s.iter().zip(y.iter()).map(|(si, yi)| si * yi).sum();
        assert!(sy > 0.0, "Curvature condition y^T s > 0 violated: {}", sy);
    }

    #[test]
    fn test_slbfgs_stochastic_convergence() {
        let cfg = SlbfgsConfig {
            max_iter: 300,
            lr: 0.05,
            batch_size: 4,
            curvature_batch_size: 8,
            variance_reduction: true,
            tol: 1e-4,
            ..SlbfgsConfig::default()
        };
        let opt = SlbfgsOptimizer::new(cfg);

        let start = vec![0.0_f64; 4];
        let result = opt
            .minimize(&stoch_quad, &full_quad, 4, &start)
            .expect("S-L-BFGS failed");
        for xi in &result.x {
            assert!(
                (xi - 1.0).abs() < 0.2,
                "S-L-BFGS did not converge: x={:?}",
                result.x
            );
        }
    }

    #[test]
    fn test_second_order_config_default() {
        use crate::second_order::types::{LbfgsBConfig, SlbfgsConfig, Sr1Config};
        // Merely exercises that every second-order config has Default.
        let _c1 = LbfgsBConfig::default();
        let _c2 = Sr1Config::default();
        let _c3 = SlbfgsConfig::default();
    }

    #[test]
    fn test_slbfgs_batch_selection() {
        let mut rng = Lcg::new(12345);
        let mut batch = Vec::new();
        rng.sample_without_replacement(100, 10, &mut batch);
        assert_eq!(batch.len(), 10);
        // Drawn indices must be distinct ...
        let mut unique = batch.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), 10, "Duplicate indices in batch selection");
        // ... and lie inside 0..n.
        for &idx in &batch {
            assert!(idx < 100, "Index out of bounds: {}", idx);
        }
    }
}