scirs2_optimize/stochastic/
approximation.rs1use crate::error::{OptimizeError, OptimizeResult};
19use scirs2_core::ndarray::{Array1, ArrayView1};
20
21#[derive(Debug, Clone)]
25pub struct RobbinsMonroResult {
26 pub x: Array1<f64>,
28 pub residual: f64,
30 pub n_iter: usize,
32 pub converged: bool,
34}
35
/// Tuning parameters for [`robbins_monro`].
///
/// The gain sequence is `a_k = a / k^alpha` for the 1-based iteration index `k`.
#[derive(Debug, Clone)]
pub struct RobbinsMonroOptions {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Stopping tolerance on the Euclidean norm of the update step `a_k * M(x_k)`.
    pub tol: f64,
    /// Decay exponent of the gain sequence.
    pub alpha: f64,
    /// Scale of the gain sequence.
    pub a: f64,
}

impl Default for RobbinsMonroOptions {
    /// `max_iter = 10_000`, `tol = 1e-6`, `alpha = 1.0`, `a = 1.0`.
    fn default() -> Self {
        RobbinsMonroOptions {
            a: 1.0,
            alpha: 1.0,
            tol: 1e-6,
            max_iter: 10_000,
        }
    }
}
59
60pub fn robbins_monro<M>(
75 m: &mut M,
76 x0: &ArrayView1<f64>,
77 opts: &RobbinsMonroOptions,
78) -> OptimizeResult<RobbinsMonroResult>
79where
80 M: FnMut(&ArrayView1<f64>) -> Array1<f64>,
81{
82 let n = x0.len();
83 if n == 0 {
84 return Err(OptimizeError::ValueError(
85 "x0 must be non-empty".to_string(),
86 ));
87 }
88
89 let mut x = x0.to_owned();
90 let mut converged = false;
91 let mut residual = f64::INFINITY;
92
93 for k in 1..=opts.max_iter {
94 let mk = m(&x.view());
95 if mk.len() != n {
96 return Err(OptimizeError::ValueError(format!(
97 "M returned length {} but x has length {}",
98 mk.len(),
99 n
100 )));
101 }
102 let ak = opts.a / (k as f64).powf(opts.alpha);
103 let mut step_norm = 0.0_f64;
104 for i in 0..n {
105 let step = ak * mk[i];
106 x[i] -= step;
107 step_norm += step * step;
108 }
109 residual = step_norm.sqrt();
110 if residual < opts.tol {
111 converged = true;
112 residual = mk.iter().map(|v| v * v).sum::<f64>().sqrt();
113 return Ok(RobbinsMonroResult {
114 x,
115 residual,
116 n_iter: k,
117 converged,
118 });
119 }
120 }
121
122 let mk_final = m(&x.view());
124 residual = mk_final.iter().map(|v| v * v).sum::<f64>().sqrt();
125
126 Ok(RobbinsMonroResult {
127 x,
128 residual,
129 n_iter: opts.max_iter,
130 converged,
131 })
132}
133
134#[derive(Debug, Clone)]
138pub struct KieferWolfowitzResult {
139 pub x: Array1<f64>,
141 pub fun: f64,
143 pub n_iter: usize,
145 pub converged: bool,
147}
148
/// Tuning parameters for [`kiefer_wolfowitz`].
///
/// The gain sequences are `a_k = a / k^alpha` (step size) and `c_k = c / k^gamma`
/// (finite-difference half-width), for the 1-based iteration index `k`.
#[derive(Debug, Clone)]
pub struct KieferWolfowitzOptions {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Stopping tolerance on the Euclidean norm of the update step.
    pub tol: f64,
    /// Decay exponent of the step-size sequence `a_k`.
    pub alpha: f64,
    /// Decay exponent of the finite-difference sequence `c_k`.
    pub gamma: f64,
    /// Scale of the step-size sequence `a_k`.
    pub a: f64,
    /// Scale of the finite-difference sequence `c_k`.
    pub c: f64,
}

impl Default for KieferWolfowitzOptions {
    /// `max_iter = 10_000`, `tol = 1e-6`, `alpha = 0.602`, `gamma = 0.101`,
    /// `a = 0.1`, `c = 0.1`.
    fn default() -> Self {
        KieferWolfowitzOptions {
            c: 0.1,
            a: 0.1,
            gamma: 0.101,
            alpha: 0.602,
            tol: 1e-6,
            max_iter: 10_000,
        }
    }
}
178
179pub fn kiefer_wolfowitz<L>(
194 loss: &mut L,
195 x0: &ArrayView1<f64>,
196 opts: &KieferWolfowitzOptions,
197) -> OptimizeResult<KieferWolfowitzResult>
198where
199 L: FnMut(&ArrayView1<f64>) -> f64,
200{
201 let n = x0.len();
202 if n == 0 {
203 return Err(OptimizeError::ValueError(
204 "x0 must be non-empty".to_string(),
205 ));
206 }
207
208 let mut x = x0.to_owned();
209 let mut converged = false;
210
211 for k in 1..=opts.max_iter {
212 let ak = opts.a / (k as f64).powf(opts.alpha);
213 let ck = opts.c / (k as f64).powf(opts.gamma);
214
215 let mut grad = Array1::<f64>::zeros(n);
217 for i in 0..n {
218 let mut x_fwd = x.clone();
219 let mut x_bwd = x.clone();
220 x_fwd[i] += ck;
221 x_bwd[i] -= ck;
222 grad[i] = (loss(&x_fwd.view()) - loss(&x_bwd.view())) / (2.0 * ck);
223 }
224
225 let mut step_norm = 0.0_f64;
226 for i in 0..n {
227 let step = ak * grad[i];
228 x[i] -= step;
229 step_norm += step * step;
230 }
231
232 if step_norm.sqrt() < opts.tol {
233 converged = true;
234 let fun = loss(&x.view());
235 return Ok(KieferWolfowitzResult {
236 x,
237 fun,
238 n_iter: k,
239 converged,
240 });
241 }
242 }
243
244 let fun = loss(&x.view());
245 Ok(KieferWolfowitzResult {
246 x,
247 fun,
248 n_iter: opts.max_iter,
249 converged,
250 })
251}
252
/// Tuning parameters for [`spsa_step`] and [`spsa_minimize`].
///
/// For the 1-based iteration index `k`, the gain sequences are:
/// * `a_k = a / (big_a + k)^alpha` (step size).
/// * `c_k = c / k^gamma` (perturbation size).
#[derive(Debug, Clone)]
pub struct SpsaOptions {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Stopping tolerance on the Euclidean norm of the update step.
    pub tol: f64,
    /// Decay exponent of the step-size sequence `a_k`.
    pub alpha: f64,
    /// Decay exponent of the perturbation sequence `c_k`.
    pub gamma: f64,
    /// Scale of the step-size sequence `a_k`.
    pub a: f64,
    /// Offset added to `k` in the denominator of `a_k`.
    pub big_a: f64,
    /// Scale of the perturbation sequence `c_k`.
    pub c: f64,
}

impl Default for SpsaOptions {
    /// `max_iter = 5_000`, `tol = 1e-6`, `alpha = 0.602`, `gamma = 0.101`,
    /// `a = 0.1`, `big_a = 100.0`, `c = 0.1`.
    fn default() -> Self {
        SpsaOptions {
            c: 0.1,
            big_a: 100.0,
            a: 0.1,
            gamma: 0.101,
            alpha: 0.602,
            tol: 1e-6,
            max_iter: 5_000,
        }
    }
}
287
288#[derive(Debug, Clone)]
290pub struct SpsaResult {
291 pub x: Array1<f64>,
293 pub fun: f64,
295 pub n_iter: usize,
297 pub converged: bool,
299}
300
301pub fn spsa_step<F>(
318 f: &mut F,
319 x: &mut Array1<f64>,
320 k: usize,
321 opts: &SpsaOptions,
322 rng_state: &mut u64,
323) -> f64
324where
325 F: FnMut(&ArrayView1<f64>) -> f64,
326{
327 let n = x.len();
328 let ak = opts.a / (opts.big_a + k as f64).powf(opts.alpha);
329 let ck = opts.c / (k as f64).powf(opts.gamma);
330
331 let mut delta = Array1::<f64>::zeros(n);
333 for i in 0..n {
334 *rng_state = rng_state
335 .wrapping_mul(6364136223846793005)
336 .wrapping_add(1442695040888963407);
337 delta[i] = if (*rng_state >> 63) == 0 { 1.0 } else { -1.0 };
338 }
339
340 let x_fwd: Array1<f64> = x
342 .iter()
343 .zip(delta.iter())
344 .map(|(&xi, &di)| xi + ck * di)
345 .collect();
346 let x_bwd: Array1<f64> = x
347 .iter()
348 .zip(delta.iter())
349 .map(|(&xi, &di)| xi - ck * di)
350 .collect();
351 let f_fwd = f(&x_fwd.view());
352 let f_bwd = f(&x_bwd.view());
353
354 let diff = (f_fwd - f_bwd) / (2.0 * ck);
355
356 let mut step_sq = 0.0_f64;
358 for i in 0..n {
359 let gi = diff / delta[i]; let step = ak * gi;
361 x[i] -= step;
362 step_sq += step * step;
363 }
364 step_sq.sqrt()
365}
366
367pub fn spsa_minimize<F>(
378 f: &mut F,
379 x0: &ArrayView1<f64>,
380 opts: &SpsaOptions,
381) -> OptimizeResult<SpsaResult>
382where
383 F: FnMut(&ArrayView1<f64>) -> f64,
384{
385 if x0.is_empty() {
386 return Err(OptimizeError::ValueError(
387 "x0 must be non-empty".to_string(),
388 ));
389 }
390
391 let mut x = x0.to_owned();
392 let mut rng_state: u64 = 12345678901234567;
393 let mut converged = false;
394
395 for k in 1..=opts.max_iter {
396 let step_norm = spsa_step(f, &mut x, k, opts, &mut rng_state);
397 if step_norm < opts.tol {
398 converged = true;
399 let fun = f(&x.view());
400 return Ok(SpsaResult {
401 x,
402 fun,
403 n_iter: k,
404 converged,
405 });
406 }
407 }
408
409 let fun = f(&x.view());
410 Ok(SpsaResult {
411 x,
412 fun,
413 n_iter: opts.max_iter,
414 converged,
415 })
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_robbins_monro_linear() {
        // M(x) = x - 2 has its root at x* = 2.
        let mut oracle = |p: &ArrayView1<f64>| array![p[0] - 2.0];
        let start = array![0.0];
        let settings = RobbinsMonroOptions {
            a: 1.0,
            alpha: 1.0,
            max_iter: 50_000,
            tol: 1e-4,
        };
        let outcome =
            robbins_monro(&mut oracle, &start.view(), &settings).expect("failed to create res");
        let got = outcome.x[0];
        assert!((got - 2.0).abs() < 0.1, "expected x* ≈ 2.0, got {}", got);
    }

    #[test]
    fn test_kiefer_wolfowitz_quadratic() {
        // (x - 3)^2 is minimised at x* = 3.
        let mut objective = |p: &ArrayView1<f64>| (p[0] - 3.0).powi(2);
        let start = array![0.0];
        let settings = KieferWolfowitzOptions {
            tol: 1e-5,
            max_iter: 20_000,
            ..KieferWolfowitzOptions::default()
        };
        let outcome = kiefer_wolfowitz(&mut objective, &start.view(), &settings)
            .expect("failed to create res");
        let got = outcome.x[0];
        assert!((got - 3.0).abs() < 0.2, "expected x* ≈ 3.0, got {}", got);
    }

    #[test]
    fn test_spsa_quadratic() {
        // Separable quadratic with minimiser (1, 2).
        let mut objective =
            |p: &ArrayView1<f64>| (p[0] - 1.0).powi(2) + (p[1] - 2.0).powi(2);
        let start = array![0.0, 0.0];
        let settings = SpsaOptions {
            c: 0.2,
            big_a: 50.0,
            a: 0.5,
            tol: 1e-5,
            max_iter: 10_000,
            ..SpsaOptions::default()
        };
        let outcome =
            spsa_minimize(&mut objective, &start.view(), &settings).expect("failed to create res");
        let (first, second) = (outcome.x[0], outcome.x[1]);
        assert!(
            (first - 1.0).abs() < 0.3,
            "expected x[0] ≈ 1.0, got {}",
            first
        );
        assert!(
            (second - 2.0).abs() < 0.3,
            "expected x[1] ≈ 2.0, got {}",
            second
        );
    }
}