// scirs2_optimize/second_order/sr1.rs
use super::types::{OptResult, Sr1Config};
use crate::error::OptimizeError;
/// Dense matrix–vector product for a row-major `n × n` matrix.
///
/// `a` stores the matrix row by row (`a[i * n + j]` is entry (i, j));
/// the result is `A·x` as a freshly allocated vector.
fn mat_vec(a: &[f64], x: &[f64], n: usize) -> Vec<f64> {
    let mut out = vec![0.0; n];
    for (i, oi) in out.iter_mut().enumerate() {
        let row = &a[i * n..i * n + n];
        *oi = row.iter().zip(&x[..n]).map(|(aij, xj)| aij * xj).sum();
    }
    out
}
27
/// Matrix–vector product for a dense symmetric matrix stored in full
/// row-major form. Symmetry is assumed by callers but not exploited here;
/// this computes the plain row-by-row product `A·x`.
fn sym_mat_vec(a: &[f64], x: &[f64], n: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        let mut acc = 0.0;
        for j in 0..n {
            acc += a[i * n + j] * x[j];
        }
        out.push(acc);
    }
    out
}
32
/// Adds the scaled outer product `scale · u · vᵀ` to the dense row-major
/// `n × n` matrix `a` in place: `a[i][j] += scale * u[i] * v[j]`.
///
/// Takes `&mut [f64]` rather than `&mut Vec<f64>` (clippy `ptr_arg`): the
/// function never resizes the buffer, and existing callers passing
/// `&mut Vec<f64>` still compile unchanged via deref coercion.
fn add_outer(a: &mut [f64], u: &[f64], v: &[f64], n: usize, scale: f64) {
    for (row, ui) in a.chunks_mut(n).zip(u.iter()).take(n) {
        // Hoist the row-invariant factor out of the inner loop.
        let su = scale * ui;
        for (aij, vj) in row.iter_mut().zip(v.iter()) {
            *aij += su * vj;
        }
    }
}
41
/// Symmetric rank-one (SR1) update of the dense Hessian approximation `b`:
///
///     B ← B + (y − B s)(y − B s)ᵀ / ((y − B s)ᵀ s)
///
/// The update is skipped — returning `false` and leaving `b` untouched —
/// when the residual `r = y − B s` is numerically zero or the denominator
/// fails the standard safeguard `|rᵀs| ≥ skip_tol · ‖s‖ · ‖r‖`.
/// Returns `true` when `b` was updated.
pub fn sr1_update_dense(b: &mut Vec<f64>, s: &[f64], y: &[f64], n: usize, skip_tol: f64) -> bool {
    // Residual r = y − B s, computed row by row against the stored matrix.
    let r: Vec<f64> = (0..n)
        .map(|i| {
            let bs_i: f64 = (0..n).map(|j| b[i * n + j] * s[j]).sum();
            y[i] - bs_i
        })
        .collect();

    let rs: f64 = r.iter().zip(s.iter()).map(|(ri, si)| ri * si).sum();
    let r_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
    let s_norm = s.iter().map(|si| si * si).sum::<f64>().sqrt();

    // Classic SR1 skip rule: avoid dividing by a tiny denominator.
    if r_norm < 1e-14 || rs.abs() < skip_tol * s_norm * r_norm {
        return false;
    }

    // Apply the rank-one correction B += r rᵀ / rs.
    let inv_rs = 1.0 / rs;
    for i in 0..n {
        for j in 0..n {
            b[i * n + j] += inv_rs * r[i] * r[j];
        }
    }
    true
}
69
70pub fn lsr1_hv_product(
87 g: &[f64],
88 s_hist: &[Vec<f64>],
89 y_hist: &[Vec<f64>],
90 gamma: f64,
91) -> Vec<f64> {
92 let n = g.len();
93 let m = s_hist.len();
94
95 let mut r: Vec<f64> = g.iter().map(|gi| gamma * gi).collect();
97
98 if m == 0 {
99 return r;
100 }
101
102 let psi: Vec<Vec<f64>> = (0..m)
104 .map(|i| {
105 (0..n)
106 .map(|j| s_hist[i][j] - gamma * y_hist[i][j])
107 .collect::<Vec<f64>>()
108 })
109 .collect();
110
111 let mut big_m = vec![0.0_f64; m * m];
113 for i in 0..m {
114 for j in 0..m {
115 big_m[i * m + j] = psi[i]
116 .iter()
117 .zip(y_hist[j].iter())
118 .map(|(pi, yj)| pi * yj)
119 .sum();
120 }
121 }
122
123 let psi_g: Vec<f64> = (0..m)
125 .map(|i| psi[i].iter().zip(g.iter()).map(|(pi, gi)| pi * gi).sum())
126 .collect();
127
128 let v = match gaussian_solve(&big_m, &psi_g, m) {
130 Some(x) => x,
131 None => return r, };
133
134 for i in 0..m {
136 for j in 0..n {
137 r[j] += psi[i][j] * v[i];
138 }
139 }
140
141 r
142}
143
/// Solves the dense `m × m` linear system `A x = b` by Gaussian elimination
/// with partial (row) pivoting on an augmented matrix `[A | b]`.
///
/// Returns `None` when a pivot is below `1e-14` in magnitude (the system is
/// singular to working precision); otherwise `Some(x)`.
fn gaussian_solve(a: &[f64], b: &[f64], m: usize) -> Option<Vec<f64>> {
    let w = m + 1; // row width of the augmented matrix
    let mut aug = vec![0.0_f64; m * w];
    for row in 0..m {
        aug[row * w..row * w + m].copy_from_slice(&a[row * m..row * m + m]);
        aug[row * w + m] = b[row];
    }

    // Forward elimination with partial pivoting.
    for col in 0..m {
        // Find the row with the largest magnitude entry in this column.
        let mut best = col;
        let mut best_abs = aug[col * w + col].abs();
        for row in col + 1..m {
            let cand = aug[row * w + col].abs();
            if cand > best_abs {
                best_abs = cand;
                best = row;
            }
        }
        if best_abs < 1e-14 {
            return None; // no usable pivot — singular
        }
        if best != col {
            for j in 0..w {
                aug.swap(col * w + j, best * w + j);
            }
        }
        // Eliminate the column below the pivot.
        let pivot = aug[col * w + col];
        for row in col + 1..m {
            let factor = aug[row * w + col] / pivot;
            for j in col..w {
                let sub = aug[col * w + j] * factor;
                aug[row * w + j] -= sub;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0_f64; m];
    for i in (0..m).rev() {
        let diag = aug[i * w + i];
        if diag.abs() < 1e-14 {
            return None;
        }
        let tail: f64 = (i + 1..m).map(|j| aug[i * w + j] * x[j]).sum();
        x[i] = (aug[i * w + m] - tail) / diag;
    }
    Some(x)
}
198
199pub fn trust_region_step(b: &[f64], g: &[f64], delta: f64, n: usize) -> Vec<f64> {
208 let b_vec = b.to_vec();
210 if let Some(d) = solve_linear_system(&b_vec, g, n) {
211 let d_neg: Vec<f64> = d.iter().map(|di| -di).collect();
212 let dnorm = d_neg.iter().map(|di| di * di).sum::<f64>().sqrt();
213 if dnorm <= delta {
214 return d_neg; }
216 }
217
218 let mut lam_lo = 0.0_f64;
220 let mut lam_hi = {
221 let g_norm = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
223 g_norm / delta
224 + b.iter()
225 .enumerate()
226 .filter(|(idx, _)| idx % (n + 1) == 0)
227 .map(|(_, v)| v.abs())
228 .fold(0.0_f64, f64::max)
229 };
230 lam_hi = lam_hi.max(1.0);
231
232 for _ in 0..50 {
233 let lam_mid = 0.5 * (lam_lo + lam_hi);
234 let mut b_reg = b_vec.clone();
235 for i in 0..n {
236 b_reg[i * n + i] += lam_mid;
237 }
238 if let Some(d) = solve_linear_system(&b_reg, g, n) {
239 let d_neg: Vec<f64> = d.iter().map(|di| -di).collect();
240 let dnorm = d_neg.iter().map(|di| di * di).sum::<f64>().sqrt();
241 if dnorm <= delta {
242 lam_hi = lam_mid;
243 } else {
244 lam_lo = lam_mid;
245 }
246 if (lam_hi - lam_lo).abs() < 1e-12 * (1.0 + lam_mid) {
247 return d_neg;
248 }
249 } else {
250 lam_lo = lam_mid;
251 }
252 }
253
254 let g_norm = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
256 if g_norm < 1e-14 {
257 return vec![0.0; n];
258 }
259 g.iter().map(|gi| -gi * delta / g_norm).collect()
260}
261
/// Solves the dense `n × n` system `A x = b`, returning `None` when the
/// system is singular. Thin wrapper around `gaussian_solve` so the solver
/// backend can be swapped without touching call sites.
fn solve_linear_system(a: &[f64], b: &[f64], n: usize) -> Option<Vec<f64>> {
    gaussian_solve(a, b, n)
}
266
/// SR1 quasi-Newton optimizer driven by the settings in [`Sr1Config`].
pub struct Sr1Optimizer {
    // Algorithm parameters; `minimize` reads m, tol, max_iter, eta,
    // skip_tol, delta_init and delta_max from it.
    pub config: Sr1Config,
}
274
impl Sr1Optimizer {
    /// Creates an optimizer with the given configuration.
    pub fn new(config: Sr1Config) -> Self {
        Self { config }
    }

    /// Creates an optimizer using `Sr1Config::default()`.
    pub fn default_config() -> Self {
        Self {
            config: Sr1Config::default(),
        }
    }

    /// Minimizes `f_and_g` — a closure returning `(f(x), ∇f(x))` — starting
    /// from `x0`, using a limited-memory SR1 model inside a simple
    /// trust-region-style loop.
    ///
    /// Convergence is declared when `‖∇f‖ < cfg.tol`. The loop also stops
    /// when the radius collapses below `1e-12` or `cfg.max_iter` iterations
    /// are exhausted. Returns an [`OptResult`] with the final iterate,
    /// objective value, gradient norm, iteration count and convergence flag.
    pub fn minimize<F>(&self, f_and_g: &F, x0: &[f64]) -> Result<OptResult, OptimizeError>
    where
        F: Fn(&[f64]) -> (f64, Vec<f64>),
    {
        let n = x0.len();
        let cfg = &self.config;
        // Memory length: number of (s, y) pairs retained.
        let m = cfg.m;

        let mut x = x0.to_vec();
        let (mut f_val, mut g) = f_and_g(&x);

        // Limited-memory curvature history (most recent pair last).
        let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
        let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(m);

        // Scaling of the initial matrix in the L-SR1 product.
        let mut gamma = 1.0_f64;

        // Trust-region radius.
        let mut delta = cfg.delta_init;

        let mut n_iter = 0usize;
        let mut converged = false;

        for iter in 0..cfg.max_iter {
            n_iter = iter;
            let g_norm = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
            if g_norm < cfg.tol {
                converged = true;
                break;
            }

            // Quasi-Newton direction candidate: hg ≈ H·g from the compact
            // L-SR1 product with initial matrix gamma·I.
            let hg = lsr1_hv_product(&g, &s_hist, &y_hist, gamma);

            // Negate and truncate the step to the trust-region radius.
            let hg_norm = hg.iter().map(|v| v * v).sum::<f64>().sqrt();
            let d: Vec<f64> = if hg_norm > delta {
                hg.iter().map(|v| -v * delta / hg_norm).collect()
            } else {
                hg.iter().map(|v| -v).collect()
            };

            // SR1 models can be indefinite, so the model direction may not
            // descend; fall back to scaled steepest descent in that case.
            let slope: f64 = g.iter().zip(d.iter()).map(|(gi, di)| gi * di).sum();
            let d = if slope >= 0.0 {
                let gn = g_norm.max(1e-14);
                let sc = delta / gn;
                g.iter().map(|gi| -gi * sc).collect::<Vec<f64>>()
            } else {
                d
            };

            let x_new: Vec<f64> = x.iter().zip(d.iter()).map(|(xi, di)| xi + di).collect();
            let (f_new, g_new) = f_and_g(&x_new);

            // Agreement ratio between actual and predicted reduction.
            // NOTE(review): the prediction uses only the linear term -gᵀd
            // and omits the quadratic model term ½ dᵀB d — confirm this
            // simplification is intended.
            let actual_red = f_val - f_new;
            let gd: f64 = g.iter().zip(d.iter()).map(|(gi, di)| gi * di).sum();
            let predicted_red = -gd;
            let rho = if predicted_red.abs() < 1e-14 {
                0.0
            } else {
                actual_red / predicted_red
            };

            if rho > cfg.eta {
                // Step accepted: record the curvature pair (s, y).
                let s: Vec<f64> = d.clone();
                let y: Vec<f64> = (0..n).map(|i| g_new[i] - g[i]).collect();

                // NOTE(review): B·s is approximated by calling the H·v
                // product with the reciprocal scaling 1/gamma — verify this
                // matches the intended compact SR1 formulation.
                let bs = lsr1_hv_product(&s, &s_hist, &y_hist, 1.0 / gamma);
                let r: Vec<f64> = (0..n).map(|i| y[i] - bs[i]).collect();
                let rs: f64 = r.iter().zip(s.iter()).map(|(ri, si)| ri * si).sum();
                let r_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
                let s_norm = s.iter().map(|si| si * si).sum::<f64>().sqrt();

                // Standard SR1 skip rule: only store well-conditioned pairs.
                if rs.abs() >= cfg.skip_tol * s_norm * r_norm {
                    // Rescale the initial matrix with gamma = sᵀy / yᵀy.
                    let sy: f64 = s.iter().zip(y.iter()).map(|(si, yi)| si * yi).sum();
                    let yy: f64 = y.iter().map(|yi| yi * yi).sum::<f64>();
                    if yy > 1e-14 {
                        gamma = sy / yy;
                    }
                    // Keep at most m pairs; drop the oldest first.
                    if s_hist.len() == m {
                        s_hist.remove(0);
                        y_hist.remove(0);
                    }
                    s_hist.push(s);
                    y_hist.push(y);
                }

                x = x_new;
                f_val = f_new;
                g = g_new;
            }

            // Trust-region radius update: shrink on poor model agreement,
            // expand when agreement is good and the step hit the boundary.
            if rho < 0.25 {
                delta *= 0.25;
            } else if rho > 0.75
                && (d.iter().map(|di| di * di).sum::<f64>().sqrt() - delta).abs() < 1e-10
            {
                delta = (2.0 * delta).min(cfg.delta_max);
            }

            // Radius collapsed — no further progress is possible.
            if delta < 1e-12 {
                break;
            }
        }

        let grad_norm = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
        Ok(OptResult {
            x,
            f_val,
            grad_norm,
            n_iter,
            converged,
        })
    }
}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::second_order::types::Sr1Config;

    /// Separable convex quadratic f(x) = Σ 0.5·(i+1)·x_i² with gradient
    /// g_i = (i+1)·x_i; the unique minimizer is x = 0.
    fn quadratic(x: &[f64]) -> (f64, Vec<f64>) {
        let f: f64 = x
            .iter()
            .enumerate()
            .map(|(i, xi)| 0.5 * (i as f64 + 1.0) * xi * xi)
            .sum();
        let g: Vec<f64> = x
            .iter()
            .enumerate()
            .map(|(i, xi)| (i as f64 + 1.0) * xi)
            .collect();
        (f, g)
    }

    // With B = I, s = e1, y = 2·e1: r = y − Bs = e1, rᵀs = 1, so the SR1
    // update adds e1·e1ᵀ and B[0,0] becomes 2.
    #[test]
    fn test_sr1_update_formula() {
        let n = 3;
        let mut b = vec![0.0_f64; n * n];
        for i in 0..n {
            b[i * n + i] = 1.0;
        }
        let s = vec![1.0, 0.0, 0.0];
        let y = vec![2.0, 0.0, 0.0];
        let updated = sr1_update_dense(&mut b, &s, &y, n, 1e-8);
        assert!(updated, "SR1 update should proceed");
        assert!(
            (b[0] - 2.0).abs() < 1e-10,
            "B[0,0] should be 2, got {}",
            b[0]
        );
    }

    // With B = I and y = Bs the residual is zero, so the denominator
    // safeguard must skip the update.
    #[test]
    fn test_sr1_skip_bad_curvature() {
        let n = 2;
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let s = vec![1.0, 0.0];
        let y = vec![1.0, 0.0];
        let updated = sr1_update_dense(&mut b, &s, &y, n, 1e-8);
        assert!(!updated, "SR1 update should be skipped (zero denominator)");
    }

    // The Newton step for this B and g exceeds δ = 0.5, so the returned
    // step must be shortened to satisfy the radius constraint.
    #[test]
    fn test_sr1_trust_region() {
        let n = 2;
        let b = vec![2.0, 0.0, 0.0, 3.0];
        let g = vec![1.0, 1.0];
        let delta = 0.5_f64;
        let d = trust_region_step(&b, &g, delta, n);
        let d_norm = d.iter().map(|di| di * di).sum::<f64>().sqrt();
        assert!(
            d_norm <= delta + 1e-9,
            "Trust region violated: ‖d‖={} > δ={}",
            d_norm,
            delta
        );
    }

    // End-to-end: the optimizer should drive the convex quadratic to its
    // minimizer at the origin.
    #[test]
    fn test_sr1_quadratic() {
        let opt = Sr1Optimizer::default_config();
        let x0 = vec![3.0, -2.0, 1.0];
        let result = opt.minimize(&quadratic, &x0).expect("SR1 minimize failed");
        for xi in &result.x {
            assert!(xi.abs() < 0.01, "Expected x≈0, got {}", xi);
        }
    }

    // After absorbing curvature pairs from a positive-definite target,
    // B should stay positive along the sampled directions.
    #[test]
    fn test_sr1_positive_definite_approx() {
        let n = 2;
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let pairs = vec![
            (vec![1.0, 0.0], vec![2.0, 0.0]),
            (vec![0.0, 1.0], vec![0.0, 3.0]),
        ];
        for (s, y) in &pairs {
            sr1_update_dense(&mut b, s, y, n, 1e-8);
        }
        for (s, _y) in &pairs {
            // Check the quadratic form sᵀBs for each sampled direction.
            let bs = mat_vec(&b, s, n);
            let sts: f64 = s.iter().zip(bs.iter()).map(|(si, bsi)| si * bsi).sum();
            assert!(sts > 0.0, "B should be positive in s direction");
        }
    }

    // The rank-one correction r·rᵀ is symmetric, so a symmetric B must
    // remain symmetric after the update.
    #[test]
    fn test_sr1_symmetric_update() {
        let n = 3;
        let mut b = vec![2.0, 1.0, 0.0, 1.0, 3.0, 0.5, 0.0, 0.5, 4.0];
        let s = vec![0.5, -0.3, 0.1];
        let y = vec![1.5, 0.2, 0.4];
        sr1_update_dense(&mut b, &s, &y, n, 1e-8);
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (b[i * n + j] - b[j * n + i]).abs() < 1e-10,
                    "B not symmetric at ({},{}) vs ({},{}) : {} vs {}",
                    i,
                    j,
                    j,
                    i,
                    b[i * n + j],
                    b[j * n + i]
                );
            }
        }
    }
}