1use crate::error::{OptimizeError, OptimizeResult};
29use scirs2_core::ndarray::{Array1, ArrayView1};
30
31#[inline]
35fn finite_diff_grad<F>(
36 f: &mut F,
37 x: &ArrayView1<f64>,
38 sample: &ArrayView1<f64>,
39 h: f64,
40) -> Array1<f64>
41where
42 F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
43{
44 let n = x.len();
45 let f0 = f(x, sample);
46 let mut grad = Array1::<f64>::zeros(n);
47 let mut x_fwd = x.to_owned();
48 for i in 0..n {
49 x_fwd[i] += h;
50 grad[i] = (f(&x_fwd.view(), sample) - f0) / h;
51 x_fwd[i] = x[i];
52 }
53 grad
54}
55
56fn full_grad<F>(f: &mut F, x: &ArrayView1<f64>, samples: &[Array1<f64>], h: f64) -> Array1<f64>
58where
59 F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
60{
61 let n = x.len();
62 if samples.is_empty() {
63 return Array1::zeros(n);
64 }
65 let mut avg = Array1::<f64>::zeros(n);
66 for s in samples {
67 let g = finite_diff_grad(f, x, &s.view(), h);
68 for i in 0..n {
69 avg[i] += g[i];
70 }
71 }
72 let inv_m = 1.0 / samples.len() as f64;
73 avg.mapv_inplace(|v| v * inv_m);
74 avg
75}
76
/// Tuning parameters for [`svrg`].
#[derive(Debug, Clone)]
pub struct SvrgOptions {
    /// Number of outer epochs; each one starts with a full (all-sample) gradient.
    pub n_epochs: usize,
    /// Number of variance-reduced stochastic steps taken within each epoch.
    pub inner_steps: usize,
    /// Fixed step length applied to every update.
    pub step_size: f64,
    /// Stopping threshold on the Euclidean norm of the full gradient.
    pub tol: f64,
    /// Step `h` used by every forward-difference gradient estimate.
    pub fd_step: f64,
}

impl Default for SvrgOptions {
    /// 50 epochs of 100 inner steps, step size `1e-3`, tolerance `1e-6`, FD step `1e-5`.
    fn default() -> Self {
        SvrgOptions {
            n_epochs: 50,
            inner_steps: 100,
            step_size: 1e-3,
            tol: 1e-6,
            fd_step: 1e-5,
        }
    }
}
105
106#[derive(Debug, Clone)]
108pub struct SvrgResult {
109 pub x: Array1<f64>,
111 pub grad_norm: f64,
113 pub n_grad_evals: usize,
115 pub converged: bool,
117}
118
119pub fn svrg<F>(
135 f: &mut F,
136 x0: &ArrayView1<f64>,
137 samples: &[Array1<f64>],
138 opts: &SvrgOptions,
139) -> OptimizeResult<SvrgResult>
140where
141 F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
142{
143 let n = x0.len();
144 if n == 0 {
145 return Err(OptimizeError::ValueError(
146 "x0 must be non-empty".to_string(),
147 ));
148 }
149 if samples.is_empty() {
150 return Err(OptimizeError::ValueError(
151 "samples must be non-empty".to_string(),
152 ));
153 }
154
155 let m = samples.len();
156 let mut x = x0.to_owned();
157 let mut converged = false;
158 let mut total_evals: usize = 0;
159 let mut rng: u64 = 987654321;
161
162 for _ in 0..opts.n_epochs {
163 let x_tilde = x.clone();
165 let mu_tilde = full_grad(f, &x_tilde.view(), samples, opts.fd_step);
166 total_evals += m * (n + 1);
167
168 let grad_norm = mu_tilde.iter().map(|v| v * v).sum::<f64>().sqrt();
169 if grad_norm < opts.tol {
170 converged = true;
171 return Ok(SvrgResult {
172 x,
173 grad_norm,
174 n_grad_evals: total_evals,
175 converged,
176 });
177 }
178
179 for _ in 0..opts.inner_steps {
181 rng = rng
183 .wrapping_mul(6364136223846793005)
184 .wrapping_add(1442695040888963407);
185 let idx = (rng >> 33) as usize % m;
186 let s = &samples[idx];
187
188 let g_x = finite_diff_grad(f, &x.view(), &s.view(), opts.fd_step);
189 let g_tilde = finite_diff_grad(f, &x_tilde.view(), &s.view(), opts.fd_step);
190 total_evals += 2 * (n + 1);
191
192 for i in 0..n {
194 x[i] -= opts.step_size * (g_x[i] - g_tilde[i] + mu_tilde[i]);
195 }
196 }
197 }
198
199 let grad_norm = full_grad(f, &x.view(), samples, opts.fd_step)
200 .iter()
201 .map(|v| v * v)
202 .sum::<f64>()
203 .sqrt();
204
205 Ok(SvrgResult {
206 x,
207 grad_norm,
208 n_grad_evals: total_evals,
209 converged,
210 })
211}
212
/// Tuning parameters for [`sarah`].
#[derive(Debug, Clone)]
pub struct SarahOptions {
    /// Number of outer iterations; each one restarts from a full (all-sample) gradient.
    pub n_outer: usize,
    /// Number of recursive stochastic steps taken within each outer iteration.
    pub inner_steps: usize,
    /// Fixed step length applied to every update.
    pub step_size: f64,
    /// Stopping threshold on the Euclidean norm of the full gradient.
    pub tol: f64,
    /// Step `h` used by every forward-difference gradient estimate.
    pub fd_step: f64,
}

impl Default for SarahOptions {
    /// 50 outer iterations of 50 inner steps, step size `1e-3`, tolerance `1e-6`,
    /// FD step `1e-5`.
    fn default() -> Self {
        SarahOptions {
            n_outer: 50,
            inner_steps: 50,
            step_size: 1e-3,
            tol: 1e-6,
            fd_step: 1e-5,
        }
    }
}
241
242#[derive(Debug, Clone)]
244pub struct SarahResult {
245 pub x: Array1<f64>,
247 pub grad_norm: f64,
249 pub n_grad_evals: usize,
251 pub converged: bool,
253}
254
255pub fn sarah<F>(
273 f: &mut F,
274 x0: &ArrayView1<f64>,
275 samples: &[Array1<f64>],
276 opts: &SarahOptions,
277) -> OptimizeResult<SarahResult>
278where
279 F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
280{
281 let n = x0.len();
282 if n == 0 {
283 return Err(OptimizeError::ValueError(
284 "x0 must be non-empty".to_string(),
285 ));
286 }
287 if samples.is_empty() {
288 return Err(OptimizeError::ValueError(
289 "samples must be non-empty".to_string(),
290 ));
291 }
292
293 let m = samples.len();
294 let mut x = x0.to_owned();
295 let mut converged = false;
296 let mut total_evals: usize = 0;
297 let mut rng: u64 = 11111111111111111;
298
299 for _ in 0..opts.n_outer {
300 let mut v = full_grad(f, &x.view(), samples, opts.fd_step);
302 total_evals += m * (n + 1);
303
304 let g_norm = v.iter().map(|vi| vi * vi).sum::<f64>().sqrt();
305 if g_norm < opts.tol {
306 converged = true;
307 return Ok(SarahResult {
308 x,
309 grad_norm: g_norm,
310 n_grad_evals: total_evals,
311 converged,
312 });
313 }
314
315 for i in 0..n {
317 x[i] -= opts.step_size * v[i];
318 }
319
320 let mut x_prev = x.clone();
321
322 for _ in 0..opts.inner_steps {
323 rng = rng
324 .wrapping_mul(6364136223846793005)
325 .wrapping_add(1442695040888963407);
326 let idx = (rng >> 33) as usize % m;
327 let s = &samples[idx];
328
329 let g_curr = finite_diff_grad(f, &x.view(), &s.view(), opts.fd_step);
330 let g_prev = finite_diff_grad(f, &x_prev.view(), &s.view(), opts.fd_step);
331 total_evals += 2 * (n + 1);
332
333 let v_new: Array1<f64> = g_curr
335 .iter()
336 .zip(g_prev.iter())
337 .zip(v.iter())
338 .map(|((&gc, &gp), &vp)| gc - gp + vp)
339 .collect();
340
341 x_prev = x.clone();
342 for i in 0..n {
343 x[i] -= opts.step_size * v_new[i];
344 }
345 v = v_new;
346 }
347 }
348
349 let g_norm = full_grad(f, &x.view(), samples, opts.fd_step)
350 .iter()
351 .map(|v| v * v)
352 .sum::<f64>()
353 .sqrt();
354
355 Ok(SarahResult {
356 x,
357 grad_norm: g_norm,
358 n_grad_evals: total_evals,
359 converged,
360 })
361}
362
/// Tuning parameters for [`spider`].
#[derive(Debug, Clone)]
pub struct SpiderOptions {
    /// Number of outer iterations; each one restarts from a full (all-sample) gradient.
    pub n_outer: usize,
    /// Number of mini-batch recursive steps taken within each outer iteration.
    pub inner_steps: usize,
    /// Fixed step length applied to every update.
    pub step_size: f64,
    /// Stopping threshold on the gradient (or gradient-estimator) norm.
    pub tol: f64,
    /// Step `h` used by every forward-difference gradient estimate.
    pub fd_step: f64,
    /// Samples drawn (with replacement) per inner step; `spider` clamps this
    /// into `1..=samples.len()`.
    pub mini_batch: usize,
}

impl Default for SpiderOptions {
    /// 30 outer iterations of 50 inner steps, step size `5e-4`, tolerance `1e-6`,
    /// FD step `1e-5`, mini-batch of 4.
    fn default() -> Self {
        SpiderOptions {
            n_outer: 30,
            inner_steps: 50,
            step_size: 5e-4,
            tol: 1e-6,
            fd_step: 1e-5,
            mini_batch: 4,
        }
    }
}
394
395#[derive(Debug, Clone)]
397pub struct SpiderResult {
398 pub x: Array1<f64>,
400 pub grad_norm: f64,
402 pub n_grad_evals: usize,
404 pub converged: bool,
406}
407
408pub fn spider<F>(
423 f: &mut F,
424 x0: &ArrayView1<f64>,
425 samples: &[Array1<f64>],
426 opts: &SpiderOptions,
427) -> OptimizeResult<SpiderResult>
428where
429 F: FnMut(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
430{
431 let n = x0.len();
432 if n == 0 {
433 return Err(OptimizeError::ValueError(
434 "x0 must be non-empty".to_string(),
435 ));
436 }
437 if samples.is_empty() {
438 return Err(OptimizeError::ValueError(
439 "samples must be non-empty".to_string(),
440 ));
441 }
442
443 let m = samples.len();
444 let b = opts.mini_batch.max(1).min(m);
445 let mut x = x0.to_owned();
446 let mut converged = false;
447 let mut total_evals: usize = 0;
448 let mut rng: u64 = 999999999999;
449
450 for _ in 0..opts.n_outer {
451 let mut v = full_grad(f, &x.view(), samples, opts.fd_step);
453 total_evals += m * (n + 1);
454
455 let g_norm = v.iter().map(|vi| vi * vi).sum::<f64>().sqrt();
456 if g_norm < opts.tol {
457 converged = true;
458 return Ok(SpiderResult {
459 x,
460 grad_norm: g_norm,
461 n_grad_evals: total_evals,
462 converged,
463 });
464 }
465
466 for i in 0..n {
468 x[i] -= opts.step_size * v[i];
469 }
470
471 let mut x_prev = x.clone();
472
473 for _ in 0..opts.inner_steps {
474 let mut batch_indices = Vec::with_capacity(b);
476 for _ in 0..b {
477 rng = rng
478 .wrapping_mul(6364136223846793005)
479 .wrapping_add(1442695040888963407);
480 batch_indices.push((rng >> 33) as usize % m);
481 }
482
483 let mut diff = Array1::<f64>::zeros(n);
485 for &idx in &batch_indices {
486 let s = &samples[idx];
487 let g_curr = finite_diff_grad(f, &x.view(), &s.view(), opts.fd_step);
488 let g_prev = finite_diff_grad(f, &x_prev.view(), &s.view(), opts.fd_step);
489 total_evals += 2 * (n + 1);
490 for i in 0..n {
491 diff[i] += (g_curr[i] - g_prev[i]) / b as f64;
492 }
493 }
494
495 let v_new: Array1<f64> = diff.iter().zip(v.iter()).map(|(&d, &vp)| d + vp).collect();
497 x_prev = x.clone();
498 for i in 0..n {
499 x[i] -= opts.step_size * v_new[i];
500 }
501 v = v_new;
502
503 let cur_norm = v.iter().map(|vi| vi * vi).sum::<f64>().sqrt();
504 if cur_norm < opts.tol {
505 converged = true;
506 return Ok(SpiderResult {
507 x,
508 grad_norm: cur_norm,
509 n_grad_evals: total_evals,
510 converged,
511 });
512 }
513 }
514 }
515
516 let g_norm = v_norm_approx(&full_grad(f, &x.view(), samples, opts.fd_step));
517
518 Ok(SpiderResult {
519 x,
520 grad_norm: g_norm,
521 n_grad_evals: total_evals,
522 converged,
523 })
524}
525
526#[inline]
527fn v_norm_approx(v: &Array1<f64>) -> f64 {
528 v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Eight 2-D samples whose coordinate-wise mean is exactly (1.0, 2.0), the
    /// minimiser of the sample-averaged `sample_loss`.
    fn make_samples() -> Vec<Array1<f64>> {
        vec![
            array![0.9, 1.8],
            array![1.1, 2.2],
            array![1.0, 2.0],
            array![0.8, 1.9],
            array![1.2, 2.1],
            array![1.0, 2.0],
            array![0.95, 1.95],
            array![1.05, 2.05],
        ]
    }

    /// Squared Euclidean distance between the first two coordinates of `x` and `s`.
    fn sample_loss(x: &ArrayView1<f64>, s: &ArrayView1<f64>) -> f64 {
        (x[0] - s[0]).powi(2) + (x[1] - s[1]).powi(2)
    }

    /// Asserts that `x` is within `tol` of the known minimiser (1.0, 2.0) in every coordinate.
    fn assert_near_minimiser(label: &str, x: &Array1<f64>, tol: f64) {
        for (i, &expected) in [1.0_f64, 2.0].iter().enumerate() {
            assert!(
                (x[i] - expected).abs() < tol,
                "{}: expected x[{}]≈{:?}, got {}",
                label,
                i,
                expected,
                x[i]
            );
        }
    }

    #[test]
    fn test_svrg_quadratic() {
        let samples = make_samples();
        let x0 = array![0.0, 0.0];
        let opts = SvrgOptions {
            n_epochs: 100,
            inner_steps: 50,
            step_size: 0.1,
            tol: 1e-4,
            fd_step: 1e-5,
        };
        let res = svrg(&mut |x, s| sample_loss(x, s), &x0.view(), &samples, &opts)
            .expect("failed to create res");
        assert_near_minimiser("SVRG", &res.x, 0.3);
    }

    #[test]
    fn test_sarah_quadratic() {
        let samples = make_samples();
        let x0 = array![0.0, 0.0];
        let opts = SarahOptions {
            n_outer: 80,
            inner_steps: 30,
            step_size: 0.05,
            tol: 1e-4,
            fd_step: 1e-5,
        };
        let res = sarah(&mut |x, s| sample_loss(x, s), &x0.view(), &samples, &opts)
            .expect("failed to create res");
        assert_near_minimiser("SARAH", &res.x, 0.3);
    }

    #[test]
    fn test_spider_quadratic() {
        let samples = make_samples();
        let x0 = array![0.0, 0.0];
        let opts = SpiderOptions {
            n_outer: 80,
            inner_steps: 30,
            step_size: 0.05,
            tol: 1e-4,
            fd_step: 1e-5,
            mini_batch: 2,
        };
        let res = spider(&mut |x, s| sample_loss(x, s), &x0.view(), &samples, &opts)
            .expect("failed to create res");
        assert_near_minimiser("SPIDER", &res.x, 0.4);
    }

    #[test]
    fn test_empty_inputs_are_rejected() {
        let samples = make_samples();
        let no_samples: Vec<Array1<f64>> = Vec::new();
        let x0 = array![0.0, 0.0];
        let empty_x0: Array1<f64> = Array1::zeros(0);

        // Empty starting point.
        assert!(svrg(
            &mut |x, s| sample_loss(x, s),
            &empty_x0.view(),
            &samples,
            &SvrgOptions::default()
        )
        .is_err());
        assert!(sarah(
            &mut |x, s| sample_loss(x, s),
            &empty_x0.view(),
            &samples,
            &SarahOptions::default()
        )
        .is_err());
        assert!(spider(
            &mut |x, s| sample_loss(x, s),
            &empty_x0.view(),
            &samples,
            &SpiderOptions::default()
        )
        .is_err());

        // Empty sample set.
        assert!(svrg(
            &mut |x, s| sample_loss(x, s),
            &x0.view(),
            &no_samples,
            &SvrgOptions::default()
        )
        .is_err());
        assert!(sarah(
            &mut |x, s| sample_loss(x, s),
            &x0.view(),
            &no_samples,
            &SarahOptions::default()
        )
        .is_err());
        assert!(spider(
            &mut |x, s| sample_loss(x, s),
            &x0.view(),
            &no_samples,
            &SpiderOptions::default()
        )
        .is_err());
    }

    #[test]
    fn test_svrg_is_deterministic() {
        // Sampling uses a fixed-seed generator, so identical calls must agree exactly.
        let samples = make_samples();
        let x0 = array![0.0, 0.0];
        let opts = SvrgOptions {
            n_epochs: 3,
            inner_steps: 5,
            step_size: 0.1,
            tol: 0.0,
            fd_step: 1e-5,
        };
        let a = svrg(&mut |x, s| sample_loss(x, s), &x0.view(), &samples, &opts)
            .expect("first SVRG run failed");
        let b = svrg(&mut |x, s| sample_loss(x, s), &x0.view(), &samples, &opts)
            .expect("second SVRG run failed");
        assert_eq!(a.x, b.x);
        assert_eq!(a.n_grad_evals, b.n_grad_evals);
    }
}