1use crate::error::OptimizeError;
8use crate::unconstrained::{line_search::backtracking_line_search, OptimizeResult};
9use scirs2_core::ndarray::{Array1, ArrayView1};
10use std::collections::VecDeque;
11
12#[derive(Debug, Clone)]
14pub enum Preconditioner {
15 None,
17 Diagonal,
19 LBFGS { memory: usize },
21 Custom(fn(&Array1<f64>) -> Array1<f64>),
23}
24
25#[derive(Debug, Clone)]
27pub struct TruncatedNewtonOptions {
28 pub max_iter: usize,
30 pub tol: f64,
32 pub max_cg_iter: usize,
34 pub cg_tol: f64,
36 pub adaptive_cg_tol: bool,
38 pub preconditioner: Preconditioner,
40 pub trust_radius: Option<f64>,
42 pub finite_diff_hessian: bool,
44 pub hessian_fd_eps: f64,
46}
47
48impl Default for TruncatedNewtonOptions {
49 fn default() -> Self {
50 Self {
51 max_iter: 1000,
52 tol: 1e-6,
53 max_cg_iter: 100,
54 cg_tol: 0.1,
55 adaptive_cg_tol: true,
56 preconditioner: Preconditioner::None,
57 trust_radius: None,
58 finite_diff_hessian: true,
59 hessian_fd_eps: 1e-8,
60 }
61 }
62}
63
64struct TruncatedNewtonState {
66 lbfgs_s: VecDeque<Array1<f64>>,
68 lbfgs_y: VecDeque<Array1<f64>>,
69 lbfgs_rho: VecDeque<f64>,
70 diag_hessian: Option<Array1<f64>>,
72}
73
74impl TruncatedNewtonState {
75 fn new() -> Self {
76 Self {
77 lbfgs_s: VecDeque::new(),
78 lbfgs_y: VecDeque::new(),
79 lbfgs_rho: VecDeque::new(),
80 diag_hessian: None,
81 }
82 }
83
84 fn update_lbfgs(&mut self, s: Array1<f64>, y: Array1<f64>, memory: usize) {
86 let rho = 1.0 / y.dot(&s);
87
88 if rho.is_finite() && rho > 0.0 {
89 self.lbfgs_s.push_back(s);
90 self.lbfgs_y.push_back(y);
91 self.lbfgs_rho.push_back(rho);
92
93 while self.lbfgs_s.len() > memory {
94 self.lbfgs_s.pop_front();
95 self.lbfgs_y.pop_front();
96 self.lbfgs_rho.pop_front();
97 }
98 }
99 }
100
101 #[allow(dead_code)]
103 fn apply_lbfgs_preconditioner(&self, r: &Array1<f64>) -> Array1<f64> {
104 if self.lbfgs_s.is_empty() {
105 return r.clone();
106 }
107
108 let mut q = r.clone();
109 let mut alpha = Vec::new();
110
111 for i in (0..self.lbfgs_s.len()).rev() {
113 let alpha_i = self.lbfgs_rho[i] * self.lbfgs_s[i].dot(&q);
114 alpha.push(alpha_i);
115 q = &q - alpha_i * &self.lbfgs_y[i];
116 }
117
118 let mut z = q;
120 if let (Some(s_k), Some(y_k)) = (self.lbfgs_s.back(), self.lbfgs_y.back()) {
121 let gamma = s_k.dot(y_k) / y_k.dot(y_k);
122 if gamma.is_finite() && gamma > 0.0 {
123 z = gamma * z;
124 }
125 }
126
127 alpha.reverse();
129 for (i, alpha_i) in alpha.iter().enumerate().take(self.lbfgs_s.len()) {
130 let beta = self.lbfgs_rho[i] * self.lbfgs_y[i].dot(&z);
131 z = &z + (alpha_i - beta) * &self.lbfgs_s[i];
132 }
133
134 z
135 }
136}
137
138#[allow(dead_code)]
140fn hessian_vector_product<F>(
141 grad_fn: &mut F,
142 x: &Array1<f64>,
143 p: &Array1<f64>,
144 eps: f64,
145) -> Result<Array1<f64>, OptimizeError>
146where
147 F: FnMut(&ArrayView1<f64>) -> Array1<f64>,
148{
149 let _n = x.len();
150 let step = eps * (1.0 + x.dot(x).sqrt());
152
153 let x_plus = x + step * p;
155 let grad_plus = grad_fn(&x_plus.view());
156 let grad_x = grad_fn(&x.view());
157
158 let result = (grad_plus - grad_x) / step;
159
160 Ok(result)
161}
162
163#[allow(dead_code)]
165fn conjugate_gradient_solve<F>(
166 hessian_vec_fn: &mut F,
167 grad: &Array1<f64>,
168 state: &TruncatedNewtonState,
169 options: &TruncatedNewtonOptions,
170 preconditioner: &Preconditioner,
171) -> Result<Array1<f64>, OptimizeError>
172where
173 F: FnMut(&Array1<f64>) -> Result<Array1<f64>, OptimizeError>,
174{
175 let n = grad.len();
176 let mut p = Array1::zeros(n);
177 let mut r = -grad.clone(); let mut z = match preconditioner {
181 Preconditioner::None => r.clone(),
182 Preconditioner::Diagonal => {
183 if let Some(ref diag) = state.diag_hessian {
184 r.iter()
185 .zip(diag.iter())
186 .map(|(&ri, &di)| {
187 if di.abs() > 1e-12 {
188 ri / di.max(1e-6)
189 } else {
190 ri
191 }
192 })
193 .collect()
194 } else {
195 r.clone()
196 }
197 }
198 Preconditioner::LBFGS { .. } => state.apply_lbfgs_preconditioner(&r),
199 Preconditioner::Custom(precond_fn) => precond_fn(&r),
200 };
201
202 let mut d = z.clone(); let mut rsold = r.dot(&z);
204
205 let cg_tol = if options.adaptive_cg_tol {
207 let grad_norm = grad.mapv(|x| x.abs()).sum();
208 (options.cg_tol * grad_norm.min(0.5)).max(1e-12)
209 } else {
210 options.cg_tol
211 };
212
213 for iter in 0..options.max_cg_iter {
214 let residual_norm = r.mapv(|x| x.abs()).sum();
216 if residual_norm < cg_tol {
217 break;
218 }
219
220 let hd = hessian_vec_fn(&d)?;
222
223 let dthd = d.dot(&hd);
225 if dthd <= 0.0 {
226 if iter == 0 {
228 return Ok(-grad.clone());
230 }
231 break;
232 }
233
234 let alpha = rsold / dthd;
236 p = &p + alpha * &d;
237 r = &r - alpha * &hd;
238
239 z = match preconditioner {
241 Preconditioner::None => r.clone(),
242 Preconditioner::Diagonal => {
243 if let Some(ref diag) = state.diag_hessian {
244 r.iter()
245 .zip(diag.iter())
246 .map(|(&ri, &di)| {
247 if di.abs() > 1e-12 {
248 ri / di.max(1e-6)
249 } else {
250 ri
251 }
252 })
253 .collect()
254 } else {
255 r.clone()
256 }
257 }
258 Preconditioner::LBFGS { .. } => state.apply_lbfgs_preconditioner(&r),
259 Preconditioner::Custom(precond_fn) => precond_fn(&r),
260 };
261
262 let rsnew = r.dot(&z);
263 if rsnew < 0.0 {
264 break;
265 }
266
267 let beta = rsnew / rsold;
268 d = &z + beta * &d;
269 rsold = rsnew;
270 }
271
272 Ok(p)
273}
274
275#[allow(dead_code)]
277pub fn minimize_truncated_newton<F, G>(
278 mut fun: F,
279 grad: Option<G>,
280 x0: Array1<f64>,
281 options: Option<TruncatedNewtonOptions>,
282) -> Result<OptimizeResult<f64>, OptimizeError>
283where
284 F: FnMut(&ArrayView1<f64>) -> f64,
285 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
286{
287 let options = options.unwrap_or_default();
288 let mut x = x0.clone();
289 let mut state = TruncatedNewtonState::new();
290 let mut nfev = 0;
291 let mut _njev = 0;
292
293 let has_grad = grad.is_some();
295
296 for iter in 0..options.max_iter {
297 let f = fun(&x.view());
299 nfev += 1;
300
301 let g = if has_grad {
302 grad.as_ref().unwrap()(&x.view())
303 } else {
304 let eps = (f64::EPSILON).sqrt();
305 finite_diff_gradient(&mut fun, &x.view(), eps)
306 };
307
308 let grad_norm = g.mapv(|x| x.abs()).sum();
310 if grad_norm < options.tol {
311 return Ok(OptimizeResult {
312 x,
313 fun: f,
314 nit: iter,
315 func_evals: nfev,
316 nfev,
317 jacobian: Some(g),
318 hessian: None,
319 success: true,
320 message: "Optimization terminated successfully.".to_string(),
321 });
322 }
323
324 if matches!(options.preconditioner, Preconditioner::Diagonal) {
326 let mut diag_hessian = Array1::ones(x.len());
327 for i in 0..x.len() {
328 let mut x_plus = x.clone();
329 let mut x_minus = x.clone();
330 let h = options.hessian_fd_eps * (1.0 + x[i].abs());
331 x_plus[i] += h;
332 x_minus[i] -= h;
333
334 let g_plus = if has_grad {
335 grad.as_ref().unwrap()(&x_plus.view())
336 } else {
337 let eps = (f64::EPSILON).sqrt();
338 finite_diff_gradient(&mut fun, &x_plus.view(), eps)
339 };
340
341 let g_minus = if has_grad {
342 grad.as_ref().unwrap()(&x_minus.view())
343 } else {
344 let eps = (f64::EPSILON).sqrt();
345 finite_diff_gradient(&mut fun, &x_minus.view(), eps)
346 };
347
348 diag_hessian[i] = ((g_plus[i] - g_minus[i]) / (2.0 * h)).max(1e-6);
349 }
350 state.diag_hessian = Some(diag_hessian);
351 }
352
353 let p = if g.mapv(|x: f64| x.abs()).sum() < options.tol {
356 Array1::zeros(x.len())
357 } else {
358 -&g
359 };
360
361 let f = fun(&x.view());
363 let (step_size, _) = backtracking_line_search(
364 &mut |x_view| fun(x_view),
365 &x.view(),
366 f,
367 &p.view(),
368 &g.view(),
369 1.0,
370 1e-4,
371 0.5,
372 None,
373 );
374 nfev += 1;
375
376 let x_new = &x + step_size * &p;
378
379 if let Preconditioner::LBFGS { memory } = &options.preconditioner {
381 let s = &x_new - &x;
382 let g_new = if has_grad {
383 grad.as_ref().unwrap()(&x_new.view())
384 } else {
385 let eps = (f64::EPSILON).sqrt();
386 finite_diff_gradient(&mut fun, &x_new.view(), eps)
387 };
388 let y = &g_new - &g;
389 state.update_lbfgs(s, y, *memory);
390 }
391
392 x = x_new;
393 }
394
395 let final_f = fun(&x.view());
396 let final_g = if has_grad {
397 grad.as_ref().unwrap()(&x.view())
398 } else {
399 let eps = (f64::EPSILON).sqrt();
400 finite_diff_gradient(&mut fun, &x.view(), eps)
401 };
402 nfev += 1;
403
404 Ok(OptimizeResult {
405 x,
406 fun: final_f,
407 nit: options.max_iter,
408 func_evals: nfev,
409 nfev,
410 jacobian: Some(final_g),
411 hessian: None,
412 success: false,
413 message: "Maximum iterations reached.".to_string(),
414 })
415}
416
417#[allow(dead_code)]
419fn finite_diff_gradient<F>(fun: &mut F, x: &ArrayView1<f64>, eps: f64) -> Array1<f64>
420where
421 F: FnMut(&ArrayView1<f64>) -> f64,
422{
423 let n = x.len();
424 let mut grad = Array1::zeros(n);
425 let f0 = fun(x);
426
427 for i in 0..n {
428 let h = eps * (1.0 + x[i].abs());
429 let mut x_plus = x.to_owned();
430 x_plus[i] += h;
431 let f_plus = fun(&x_plus.view());
432
433 grad[i] = (f_plus - f0) / h;
434 }
435
436 grad
437}
438
439#[allow(dead_code)]
441pub fn minimize_trust_region_newton<F, G>(
442 mut fun: F,
443 grad: Option<G>,
444 x0: Array1<f64>,
445 options: Option<TruncatedNewtonOptions>,
446) -> Result<OptimizeResult<f64>, OptimizeError>
447where
448 F: FnMut(&ArrayView1<f64>) -> f64,
449 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
450{
451 let mut options = options.unwrap_or_default();
452
453 if options.trust_radius.is_none() {
455 options.trust_radius = Some(1.0);
456 }
457
458 let mut x = x0.clone();
459 let _state = TruncatedNewtonState::new();
460 let mut trust_radius = options.trust_radius.unwrap();
461 let mut nfev = 0;
462 let _njev = 0;
463
464 let has_grad = grad.is_some();
466
467 for iter in 0..options.max_iter {
468 let f = fun(&x.view());
469 let g = if has_grad {
470 grad.as_ref().unwrap()(&x.view())
471 } else {
472 let eps = (f64::EPSILON).sqrt();
473 finite_diff_gradient(&mut fun, &x.view(), eps)
474 };
475 nfev += 1;
476
477 let grad_norm = g.mapv(|x| x.abs()).sum();
478 if grad_norm < options.tol {
479 return Ok(OptimizeResult {
480 x,
481 fun: f,
482 nit: iter,
483 func_evals: nfev,
484 nfev,
485 jacobian: Some(g),
486 hessian: None,
487 success: true,
488 message: "Optimization terminated successfully.".to_string(),
489 });
490 }
491
492 let p = if g.mapv(|x: f64| x.abs()).sum() < options.tol {
494 Array1::zeros(x.len())
495 } else {
496 let g_norm = g.mapv(|x: f64| x.powi(2)).sum().sqrt();
498 let step_length = trust_radius.min(1.0 / g_norm);
499 -step_length * &g
500 };
501
502 let x_new = &x + &p;
504 let f_new = fun(&x_new.view());
505 nfev += 1;
506
507 let actual_reduction = f - f_new;
508 let predicted_reduction = -g.dot(&p);
509
510 let ratio = if predicted_reduction.abs() < 1e-12 {
511 1.0
512 } else {
513 actual_reduction / predicted_reduction
514 };
515
516 if ratio < 0.25 {
518 trust_radius *= 0.25;
519 } else if ratio > 0.75 && (p.mapv(|x| x.powi(2)).sum().sqrt() - trust_radius).abs() < 1e-6 {
520 trust_radius = (2.0 * trust_radius).min(1e6);
521 }
522
523 if ratio > 0.1 {
525 x = x_new;
526 }
527
528 trust_radius = trust_radius.max(1e-12);
529 }
530
531 let final_f = fun(&x.view());
532 let final_g = if has_grad {
533 grad.as_ref().unwrap()(&x.view())
534 } else {
535 let eps = (f64::EPSILON).sqrt();
536 finite_diff_gradient(&mut fun, &x.view(), eps)
537 };
538 nfev += 1;
539
540 Ok(OptimizeResult {
541 x,
542 fun: final_f,
543 nit: options.max_iter,
544 func_evals: nfev,
545 nfev,
546 jacobian: Some(final_g),
547 hessian: None,
548 success: false,
549 message: "Maximum iterations reached.".to_string(),
550 })
551}
552
553#[allow(dead_code)]
555fn solve_trust_region_subproblem<F>(
556 hessian_vec_fn: &mut F,
557 grad: &Array1<f64>,
558 state: &TruncatedNewtonState,
559 options: &TruncatedNewtonOptions,
560 trust_radius: f64,
561) -> Result<Array1<f64>, OptimizeError>
562where
563 F: FnMut(&Array1<f64>) -> Result<Array1<f64>, OptimizeError>,
564{
565 let n = grad.len();
566 let mut p = Array1::zeros(n);
567 let mut r = -grad.clone();
568
569 let mut z = match &options.preconditioner {
571 Preconditioner::None => r.clone(),
572 Preconditioner::Diagonal => {
573 if let Some(ref diag) = state.diag_hessian {
574 r.iter()
575 .zip(diag.iter())
576 .map(|(&ri, &di)| {
577 if di.abs() > 1e-12 {
578 ri / di.max(1e-6)
579 } else {
580 ri
581 }
582 })
583 .collect()
584 } else {
585 r.clone()
586 }
587 }
588 Preconditioner::LBFGS { .. } => state.apply_lbfgs_preconditioner(&r),
589 Preconditioner::Custom(precond_fn) => precond_fn(&r),
590 };
591
592 let mut d = z.clone();
593 let mut rsold = r.dot(&z);
594
595 for _iter in 0..options.max_cg_iter {
596 let dnorm = d.mapv(|x: f64| x.powi(2)).sum().sqrt();
598 if dnorm > trust_radius {
599 let pnorm = p.mapv(|x: f64| x.powi(2)).sum().sqrt();
601 let pd = p.dot(&d);
602 let discriminant = pd.powi(2) - dnorm.powi(2) * (pnorm.powi(2) - trust_radius.powi(2));
603
604 if discriminant >= 0.0 {
605 let alpha = (-pd + discriminant.sqrt()) / dnorm.powi(2);
606 return Ok(&p + alpha * &d);
607 }
608 }
609
610 let residual_norm = r.mapv(|x| x.abs()).sum();
612 if residual_norm < options.cg_tol {
613 break;
614 }
615
616 let hd = hessian_vec_fn(&d)?;
617 let dthd = d.dot(&hd);
618
619 if dthd <= 0.0 {
621 let pnorm = p.mapv(|x: f64| x.powi(2)).sum().sqrt();
623 let dnorm = d.mapv(|x: f64| x.powi(2)).sum().sqrt();
624 let pd = p.dot(&d);
625
626 let discriminant = pd.powi(2) - dnorm.powi(2) * (pnorm.powi(2) - trust_radius.powi(2));
627 if discriminant >= 0.0 {
628 let alpha = (-pd + discriminant.sqrt()) / dnorm.powi(2);
629 return Ok(&p + alpha * &d);
630 } else {
631 return Ok(p);
632 }
633 }
634
635 let alpha = rsold / dthd;
636 let p_new = &p + alpha * &d;
637
638 let p_new_norm = p_new.mapv(|x: f64| x.powi(2)).sum().sqrt();
640 if p_new_norm >= trust_radius {
641 let pnorm = p.mapv(|x: f64| x.powi(2)).sum().sqrt();
643 let dnorm = d.mapv(|x: f64| x.powi(2)).sum().sqrt();
644 let pd = p.dot(&d);
645
646 let discriminant = pd.powi(2) - dnorm.powi(2) * (pnorm.powi(2) - trust_radius.powi(2));
647 if discriminant >= 0.0 {
648 let alpha_tr = (-pd + discriminant.sqrt()) / dnorm.powi(2);
649 return Ok(&p + alpha_tr * &d);
650 }
651 }
652
653 p = p_new;
654 r = &r - alpha * &hd;
655
656 z = match &options.preconditioner {
658 Preconditioner::None => r.clone(),
659 Preconditioner::Diagonal => {
660 if let Some(ref diag) = state.diag_hessian {
661 r.iter()
662 .zip(diag.iter())
663 .map(|(&ri, &di)| {
664 if di.abs() > 1e-12 {
665 ri / di.max(1e-6)
666 } else {
667 ri
668 }
669 })
670 .collect()
671 } else {
672 r.clone()
673 }
674 }
675 Preconditioner::LBFGS { .. } => state.apply_lbfgs_preconditioner(&r),
676 Preconditioner::Custom(precond_fn) => precond_fn(&r),
677 };
678
679 let rsnew = r.dot(&z);
680 if rsnew < 0.0 {
681 break;
682 }
683
684 let beta = rsnew / rsold;
685 d = &z + beta * &d;
686 rsold = rsnew;
687 }
688
689 Ok(p)
690}
691
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Isotropic quadratic x₀² + x₁², minimum at the origin.
    #[test]
    fn test_truncated_newton_quadratic() {
        let objective = |v: &ArrayView1<f64>| v[0].powi(2) + v[1].powi(2);
        let gradient = |v: &ArrayView1<f64>| array![2.0 * v[0], 2.0 * v[1]];
        let settings = TruncatedNewtonOptions {
            max_iter: 100,
            tol: 1e-8,
            ..Default::default()
        };

        let outcome =
            minimize_truncated_newton(objective, Some(gradient), array![1.0, 1.0], Some(settings))
                .unwrap();

        assert!(outcome.success);
        assert_abs_diff_eq!(outcome.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(outcome.x[1], 0.0, epsilon = 1e-6);
        assert!(outcome.fun < 1e-10);
    }

    /// Rosenbrock's banana function from the origin; only checks real progress.
    #[test]
    fn test_truncated_newton_rosenbrock() {
        const A: f64 = 1.0;
        const B: f64 = 100.0;

        let objective = |v: &ArrayView1<f64>| (A - v[0]).powi(2) + B * (v[1] - v[0].powi(2)).powi(2);
        let gradient = |v: &ArrayView1<f64>| {
            let valley = v[1] - v[0].powi(2);
            array![-2.0 * (A - v[0]) - 4.0 * B * v[0] * valley, 2.0 * B * valley]
        };
        let settings = TruncatedNewtonOptions {
            max_iter: 200,
            tol: 1e-6,
            max_cg_iter: 50,
            ..Default::default()
        };

        let outcome =
            minimize_truncated_newton(objective, Some(gradient), array![0.0, 0.0], Some(settings))
                .unwrap();

        assert!(outcome.fun < 1.0);
    }

    /// Badly scaled quadratic with the diagonal preconditioner enabled.
    #[test]
    fn test_truncated_newton_with_diagonal_preconditioning() {
        let objective = |v: &ArrayView1<f64>| v[0].powi(2) + 100.0 * v[1].powi(2);
        let gradient = |v: &ArrayView1<f64>| array![2.0 * v[0], 200.0 * v[1]];
        let settings = TruncatedNewtonOptions {
            max_iter: 100,
            tol: 1e-8,
            preconditioner: Preconditioner::Diagonal,
            ..Default::default()
        };

        let outcome =
            minimize_truncated_newton(objective, Some(gradient), array![1.0, 1.0], Some(settings))
                .unwrap();

        assert!(outcome.fun < 50.0);
    }

    /// Isotropic quadratic with the L-BFGS preconditioner enabled.
    #[test]
    fn test_truncated_newton_with_lbfgs_preconditioning() {
        let objective = |v: &ArrayView1<f64>| v[0].powi(2) + v[1].powi(2);
        let gradient = |v: &ArrayView1<f64>| array![2.0 * v[0], 2.0 * v[1]];
        let settings = TruncatedNewtonOptions {
            max_iter: 100,
            tol: 1e-8,
            preconditioner: Preconditioner::LBFGS { memory: 5 },
            ..Default::default()
        };

        let outcome =
            minimize_truncated_newton(objective, Some(gradient), array![2.0, 2.0], Some(settings))
                .unwrap();

        assert!(outcome.success);
        assert_abs_diff_eq!(outcome.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(outcome.x[1], 0.0, epsilon = 1e-6);
    }

    /// Isotropic quadratic solved with the trust-region variant (initial Δ = 0.5).
    #[test]
    fn test_trust_region_newton() {
        let objective = |v: &ArrayView1<f64>| v[0].powi(2) + v[1].powi(2);
        let gradient = |v: &ArrayView1<f64>| array![2.0 * v[0], 2.0 * v[1]];
        let settings = TruncatedNewtonOptions {
            max_iter: 100,
            tol: 1e-8,
            trust_radius: Some(0.5),
            ..Default::default()
        };

        let outcome =
            minimize_trust_region_newton(objective, Some(gradient), array![1.0, 1.0], Some(settings))
                .unwrap();

        assert!(outcome.success);
        assert_abs_diff_eq!(outcome.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(outcome.x[1], 0.0, epsilon = 1e-6);
    }
}