1use crate::error::OptimizeError;
18use crate::unconstrained::Bounds;
19use scirs2_core::ndarray::{Array1, ArrayView1};
20
/// Outcome reported by every line search in this module.
#[derive(Debug, Clone)]
pub struct LineSearchResult {
    /// Step length the search ended on.
    pub alpha: f64,
    /// Objective value reported together with `alpha`.
    pub f_val: f64,
    /// Gradient at the returned point, if the search produced one.
    /// The searches in this file always leave it as `None`.
    pub grad: Option<Array1<f64>>,
    /// Directional derivative at `alpha`, when the search evaluated it there.
    pub derphi: Option<f64>,
    /// Number of objective evaluations counted by the search.
    pub n_fev: usize,
    /// Number of derivative evaluations counted by the search.
    pub n_gev: usize,
    /// `true` when the search's acceptance test was met.
    pub success: bool,
    /// Human-readable description of how the search terminated.
    pub message: String,
}
43
/// Tuning parameters for the [`StrongWolfe`] line search.
#[derive(Debug, Clone)]
pub struct StrongWolfeConfig {
    /// Sufficient-decrease coefficient: a step is rejected when
    /// `phi(alpha) > phi0 + c1 * alpha * dphi0`.
    pub c1: f64,
    /// Curvature coefficient: a step is accepted when
    /// `|dphi(alpha)| <= -c2 * dphi0`.
    pub c2: f64,
    /// First trial step (capped by `alpha_max`).
    pub alpha_init: f64,
    /// Largest step the search is allowed to try.
    pub alpha_max: f64,
    /// Iteration budget of the bracketing loop and of the zoom loop.
    pub max_fev: usize,
    /// Zoom stops once the bracket is narrower than this width.
    pub alpha_min: f64,
}

impl Default for StrongWolfeConfig {
    /// Returns the stock parameter set: `c1 = 1e-4`, `c2 = 0.9`, unit initial
    /// step, `alpha_max = 1e10`, 100 iterations and `alpha_min = 1e-14`.
    fn default() -> Self {
        StrongWolfeConfig {
            alpha_init: 1.0,
            alpha_min: 1e-14,
            alpha_max: 1e10,
            max_fev: 100,
            c1: 1e-4,
            c2: 0.9,
        }
    }
}
75
76pub struct StrongWolfe {
84 pub config: StrongWolfeConfig,
86}
87
88impl StrongWolfe {
89 pub fn new(config: StrongWolfeConfig) -> Self {
91 Self { config }
92 }
93
94 pub fn default_search() -> Self {
96 Self {
97 config: StrongWolfeConfig::default(),
98 }
99 }
100
101 pub fn search<Phi, DPhi>(
109 &self,
110 mut phi: Phi,
111 mut dphi: DPhi,
112 phi0: f64,
113 dphi0: f64,
114 ) -> Result<LineSearchResult, OptimizeError>
115 where
116 Phi: FnMut(f64) -> f64,
117 DPhi: FnMut(f64) -> f64,
118 {
119 let cfg = &self.config;
120 if dphi0 >= 0.0 {
121 return Err(OptimizeError::ValueError(
122 "Initial directional derivative must be negative (descent direction required)"
123 .to_string(),
124 ));
125 }
126
127 let mut n_fev = 1usize; let mut n_gev = 1usize; let mut alpha_prev = 0.0f64;
131 let mut alpha = cfg.alpha_init.min(cfg.alpha_max);
132 let mut phi_prev = phi0;
133 let mut dphi_prev = dphi0;
135 let _ = dphi_prev;
136
137 for _iter in 0..cfg.max_fev {
138 let phi_a = phi(alpha);
139 n_fev += 1;
140
141 if phi_a > phi0 + cfg.c1 * alpha * dphi0 || (phi_a >= phi_prev && _iter > 0) {
143 let (result_alpha, result_phi, nf, ng) = self.zoom(
144 &mut phi, &mut dphi, alpha_prev, phi_prev, alpha, phi_a, phi0, dphi0,
145 );
146 n_fev += nf;
147 n_gev += ng;
148 return Ok(LineSearchResult {
149 alpha: result_alpha,
150 f_val: result_phi,
151 grad: None,
152 derphi: None,
153 n_fev,
154 n_gev,
155 success: true,
156 message: "Strong Wolfe conditions satisfied (zoom from upper bracket)"
157 .to_string(),
158 });
159 }
160
161 let dphi_a = dphi(alpha);
162 n_gev += 1;
163
164 if dphi_a.abs() <= -cfg.c2 * dphi0 {
166 return Ok(LineSearchResult {
167 alpha,
168 f_val: phi_a,
169 grad: None,
170 derphi: Some(dphi_a),
171 n_fev,
172 n_gev,
173 success: true,
174 message: "Strong Wolfe conditions satisfied".to_string(),
175 });
176 }
177
178 if dphi_a >= 0.0 {
179 let (result_alpha, result_phi, nf, ng) = self.zoom(
180 &mut phi, &mut dphi, alpha, phi_a, alpha_prev, phi_prev, phi0, dphi0,
181 );
182 n_fev += nf;
183 n_gev += ng;
184 return Ok(LineSearchResult {
185 alpha: result_alpha,
186 f_val: result_phi,
187 grad: None,
188 derphi: None,
189 n_fev,
190 n_gev,
191 success: true,
192 message: "Strong Wolfe conditions satisfied (zoom from positive derivative)"
193 .to_string(),
194 });
195 }
196
197 alpha_prev = alpha;
199 phi_prev = phi_a;
200 dphi_prev = dphi_a;
201 let alpha_new = (alpha + cfg.alpha_max) * 0.5;
202 alpha = cubic_min_bracket(alpha_prev, phi_prev, dphi_a, alpha_new, phi_a)
203 .unwrap_or(alpha_new)
204 .clamp(alpha * 1.1, cfg.alpha_max);
205 }
206
207 let f_alpha = phi(alpha);
209 n_fev += 1;
210 Ok(LineSearchResult {
211 alpha,
212 f_val: f_alpha,
213 grad: None,
214 derphi: None,
215 n_fev,
216 n_gev,
217 success: false,
218 message: "Strong Wolfe search did not converge within max evaluations".to_string(),
219 })
220 }
221
222 fn zoom<Phi, DPhi>(
224 &self,
225 phi: &mut Phi,
226 dphi: &mut DPhi,
227 alpha_lo: f64,
228 phi_lo: f64,
229 alpha_hi: f64,
230 phi_hi: f64,
231 phi0: f64,
232 dphi0: f64,
233 ) -> (f64, f64, usize, usize)
234 where
235 Phi: FnMut(f64) -> f64,
236 DPhi: FnMut(f64) -> f64,
237 {
238 let cfg = &self.config;
239 let mut n_fev = 0usize;
240 let mut n_gev = 0usize;
241
242 let mut a_lo = alpha_lo;
243 let mut f_lo = phi_lo;
244 let mut a_hi = alpha_hi;
245 let mut f_hi = phi_hi;
246
247 for _ in 0..cfg.max_fev {
248 let alpha_j = cubic_min_bracket(a_lo, f_lo, dphi(a_lo), a_hi, f_hi)
250 .unwrap_or((a_lo + a_hi) * 0.5)
251 .clamp(a_lo.min(a_hi) + 1e-10, a_lo.max(a_hi) - 1e-10);
252 n_gev += 1; let phi_j = phi(alpha_j);
255 n_fev += 1;
256
257 if phi_j > phi0 + cfg.c1 * alpha_j * dphi0 || phi_j >= f_lo {
258 a_hi = alpha_j;
259 f_hi = phi_j;
260 } else {
261 let dphi_j = dphi(alpha_j);
262 n_gev += 1;
263
264 if dphi_j.abs() <= -cfg.c2 * dphi0 {
265 return (alpha_j, phi_j, n_fev, n_gev);
266 }
267
268 if dphi_j * (a_hi - a_lo) >= 0.0 {
269 a_hi = a_lo;
270 f_hi = f_lo;
271 }
272 a_lo = alpha_j;
273 f_lo = phi_j;
274 }
275
276 if (a_hi - a_lo).abs() < cfg.alpha_min {
277 break;
278 }
279 }
280
281 (a_lo, f_lo, n_fev, n_gev)
282 }
283}
284
/// Estimates the minimizer of the interpolant through `(a, fa)` with slope
/// `dfa` and `(b, fb)`.
///
/// One slope and two values determine a parabola
/// `q(x) = fa + dfa * (x - a) + c * (x - a)^2`, so despite the historical name
/// the interpolant is quadratic. Its minimizer is
/// `a - dfa * (b - a)^2 / (2 * (fb - fa - dfa * (b - a)))`.
///
/// Returns `None` when no finite minimizer exists: a zero-width bracket,
/// non-finite inputs, or a parabola that does not open upward. Callers are
/// expected to fall back to bisection in that case. The result is not
/// restricted to the interval between `a` and `b`.
fn cubic_min_bracket(a: f64, fa: f64, dfa: f64, b: f64, fb: f64) -> Option<f64> {
    let ab = b - a;

    // Amount by which fb exceeds the tangent-line prediction at b; it equals
    // c * ab^2 and must be positive for the parabola to have a minimum.
    let excess = fb - fa - dfa * ab;
    if ab == 0.0 || !excess.is_finite() || excess <= 0.0 {
        return None;
    }

    // Evaluated as ((dfa * ab) / excess) * ab to postpone squaring `ab`.
    let alpha = a - 0.5 * dfa * ab / excess * ab;
    if alpha.is_finite() {
        Some(alpha)
    } else {
        None
    }
}
307
/// Tuning parameters for the [`HagerZhang`] line search.
#[derive(Debug, Clone)]
pub struct HagerZhangConfig {
    /// Sufficient-decrease coefficient of the Wolfe test; also enters the
    /// approximate Wolfe slope bound through `2 * delta - 1`.
    pub delta: f64,
    /// Curvature coefficient of the Wolfe tests.
    pub sigma: f64,
    /// Relative slack: values up to `phi0 + epsilon * max(|phi0|, 1)` count as
    /// "not worse than the start".
    pub epsilon: f64,
    /// Position of the next trial inside the bracket (`0.5` is the midpoint).
    pub theta: f64,
    /// Not read by the search code in this file.
    pub gamma: f64,
    /// Maximum number of bracket-shrinking iterations.
    pub max_fev: usize,
    /// Initial trial step; also the initial right end of the bracket.
    pub alpha_init: f64,
}

impl Default for HagerZhangConfig {
    /// Returns the stock parameter set: `delta = 0.1`, `sigma = 0.9`,
    /// `epsilon = 1e-6`, `theta = 0.5`, `gamma = 0.66`, 50 iterations and a
    /// unit initial step.
    fn default() -> Self {
        HagerZhangConfig {
            alpha_init: 1.0,
            max_fev: 50,
            gamma: 0.66,
            theta: 0.5,
            epsilon: 1e-6,
            sigma: 0.9,
            delta: 0.1,
        }
    }
}
342
343pub struct HagerZhang {
348 pub config: HagerZhangConfig,
350}
351
352impl HagerZhang {
353 pub fn new(config: HagerZhangConfig) -> Self {
355 Self { config }
356 }
357
358 pub fn default_search() -> Self {
360 Self {
361 config: HagerZhangConfig::default(),
362 }
363 }
364
365 pub fn search<Phi, DPhi>(
371 &self,
372 mut phi: Phi,
373 mut dphi: DPhi,
374 phi0: f64,
375 dphi0: f64,
376 ) -> Result<LineSearchResult, OptimizeError>
377 where
378 Phi: FnMut(f64) -> f64,
379 DPhi: FnMut(f64) -> f64,
380 {
381 if dphi0 >= 0.0 {
382 return Err(OptimizeError::ValueError(
383 "Initial directional derivative must be negative".to_string(),
384 ));
385 }
386
387 let cfg = &self.config;
388 let mut n_fev = 1usize;
389 let mut n_gev = 1usize;
390 let c = cfg.epsilon * phi0.abs().max(1.0);
391
392 let wolfe1 = |pa: f64, a: f64| pa <= phi0 + cfg.delta * a * dphi0;
397 let approx_wolfe1 =
398 |pa: f64, a: f64| pa <= phi0 + c && cfg.delta * dphi0 >= (phi0 - pa) / a.max(1e-300);
399 let wolfe2 = |da: f64| da.abs() <= cfg.sigma * dphi0.abs();
400 let approx_wolfe2 =
401 |da: f64| (2.0 * cfg.delta - 1.0) * dphi0 <= da && da <= cfg.sigma * dphi0;
402 let _ = approx_wolfe1;
403
404 let mut a = 0.0f64;
406 let mut b = cfg.alpha_init;
407 let mut fa = phi0;
408 let mut fb = phi(b);
409 n_fev += 1;
410 let mut db = dphi(b);
411 n_gev += 1;
412
413 if wolfe1(fb, b) && wolfe2(db) {
415 return Ok(LineSearchResult {
416 alpha: b,
417 f_val: fb,
418 grad: None,
419 derphi: Some(db),
420 n_fev,
421 n_gev,
422 success: true,
423 message: "Hager-Zhang: initial step satisfies Wolfe conditions".to_string(),
424 });
425 }
426
427 for _ in 0..cfg.max_fev {
429 let mid = a + cfg.theta * (b - a);
430 let fm = phi(mid);
431 n_fev += 1;
432 let dm = dphi(mid);
433 n_gev += 1;
434
435 if wolfe1(fm, mid) && wolfe2(dm) {
436 return Ok(LineSearchResult {
437 alpha: mid,
438 f_val: fm,
439 grad: None,
440 derphi: Some(dm),
441 n_fev,
442 n_gev,
443 success: true,
444 message: "Hager-Zhang: bisection converged".to_string(),
445 });
446 }
447
448 if approx_wolfe2(dm) && fm <= phi0 + c {
450 return Ok(LineSearchResult {
451 alpha: mid,
452 f_val: fm,
453 grad: None,
454 derphi: Some(dm),
455 n_fev,
456 n_gev,
457 success: true,
458 message: "Hager-Zhang: approximate Wolfe satisfied".to_string(),
459 });
460 }
461
462 if dm < 0.0 && fm <= phi0 + c {
464 a = mid;
465 fa = fm;
466 } else {
467 b = mid;
468 fb = fm;
469 db = dm;
470 }
471
472 if (b - a).abs() < 1e-14 {
474 break;
475 }
476 let _ = fa;
477 let _ = fb;
478 let _ = db;
479 }
480
481 let alpha_best = a + cfg.theta * (b - a);
483 let f_best = phi(alpha_best);
484 n_fev += 1;
485
486 Ok(LineSearchResult {
487 alpha: alpha_best,
488 f_val: f_best,
489 grad: None,
490 derphi: None,
491 n_fev,
492 n_gev,
493 success: false,
494 message: "Hager-Zhang: max evaluations reached".to_string(),
495 })
496 }
497}
498
/// Line search that accepts the first trial step meeting the
/// sufficient-decrease test and otherwise shortens the step by safeguarded
/// interpolation.
///
/// `Debug` and `Clone` are derived so the type matches the other public types
/// in this module.
#[derive(Debug, Clone)]
pub struct SafeguardedPowell {
    /// Sufficient-decrease coefficient in
    /// `phi(alpha) <= phi0 + c1 * alpha * dphi0`.
    pub c1: f64,
    /// Maximum number of trial steps.
    pub max_fev: usize,
    /// Tolerance handed to the interpolation helper and used to stop once the
    /// step can no longer shrink meaningfully.
    pub bracket_tol: f64,
    /// First step length tried.
    pub alpha_init: f64,
}

impl Default for SafeguardedPowell {
    /// Returns the stock parameter set: `c1 = 1e-4`, 50 trials,
    /// `bracket_tol = 1e-14` and a unit initial step.
    fn default() -> Self {
        Self {
            c1: 1e-4,
            max_fev: 50,
            bracket_tol: 1e-14,
            alpha_init: 1.0,
        }
    }
}
527
528impl SafeguardedPowell {
529 pub fn new(c1: f64, max_fev: usize, alpha_init: f64) -> Self {
531 Self {
532 c1,
533 max_fev,
534 bracket_tol: 1e-14,
535 alpha_init,
536 }
537 }
538
539 pub fn search<Phi, DPhi>(
541 &self,
542 mut phi: Phi,
543 mut dphi: DPhi,
544 phi0: f64,
545 dphi0: f64,
546 ) -> Result<LineSearchResult, OptimizeError>
547 where
548 Phi: FnMut(f64) -> f64,
549 DPhi: FnMut(f64) -> f64,
550 {
551 if dphi0 >= 0.0 {
552 return Err(OptimizeError::ValueError(
553 "Initial directional derivative must be negative".to_string(),
554 ));
555 }
556
557 let mut n_fev = 1usize;
558 let mut n_gev = 1usize;
559
560 let mut alpha = self.alpha_init;
561 let mut alpha_lo = 0.0f64;
562 let mut f_lo = phi0;
563 let mut d_lo = dphi0;
564
565 for _ in 0..self.max_fev {
566 let fa = phi(alpha);
567 n_fev += 1;
568
569 if fa <= phi0 + self.c1 * alpha * dphi0 {
570 let da = dphi(alpha);
572 n_gev += 1;
573 return Ok(LineSearchResult {
574 alpha,
575 f_val: fa,
576 grad: None,
577 derphi: Some(da),
578 n_fev,
579 n_gev,
580 success: true,
581 message: "Safeguarded Powell: sufficient decrease satisfied".to_string(),
582 });
583 }
584
585 let alpha_new =
588 cubic_interpolate_safeguarded(alpha_lo, f_lo, d_lo, alpha, fa, self.bracket_tol);
589
590 if fa < f_lo {
592 alpha_lo = alpha;
594 f_lo = fa;
595 d_lo = dphi(alpha_new.min(alpha));
596 n_gev += 1;
597 }
598
599 if (alpha_new - alpha_lo).abs() < self.bracket_tol {
600 break;
601 }
602 alpha = alpha_new;
603 }
604
605 Ok(LineSearchResult {
607 alpha,
608 f_val: phi(alpha),
609 grad: None,
610 derphi: None,
611 n_fev: n_fev + 1,
612 n_gev,
613 success: false,
614 message: "Safeguarded Powell: max evaluations reached".to_string(),
615 })
616 }
617}
618
/// Interpolated minimizer between `a` and `b`, kept at least `|tol|` away
/// from both ends.
///
/// The interpolant is the parabola through `(a, fa)` with slope `da` and
/// `(b, fb)`; one slope and two values determine no more than that, despite
/// the historical name. The bracket ends may be given in either order.
///
/// The midpoint of the bracket is returned when the interval is too narrow to
/// honour the margin, when any intermediate quantity is not finite, or when
/// the parabola does not open upward.
fn cubic_interpolate_safeguarded(a: f64, fa: f64, da: f64, b: f64, fb: f64, tol: f64) -> f64 {
    let midpoint = (a + b) * 0.5;
    let margin = tol.abs();

    // Order the ends so the clamp bounds below can never be inverted.
    let (low, high) = if a <= b { (a, b) } else { (b, a) };
    let lower = low + margin;
    let upper = high - margin;
    let has_room = lower <= upper;
    if !has_room {
        return midpoint;
    }

    let ab = b - a;
    // Amount by which fb exceeds the tangent-line prediction at b; positive
    // exactly when the parabola has a minimum.
    let excess = fb - fa - da * ab;
    if !excess.is_finite() || excess <= 0.0 {
        return midpoint;
    }

    let trial = a - 0.5 * da * ab / excess * ab;
    if trial.is_nan() {
        return midpoint;
    }
    trial.clamp(lower, upper)
}
641
/// Backtracking line search based on the Armijo sufficient-decrease test.
#[derive(Debug, Clone)]
pub struct BacktrackingArmijo {
    /// Sufficient-decrease coefficient in `f_new <= f0 + c1 * alpha * slope`.
    pub c1: f64,
    /// Factor the step is multiplied by after a rejected trial.
    pub rho: f64,
    /// First step length tried.
    pub alpha_init: f64,
    /// Backtracking stops once the step falls below this value.
    pub alpha_min: f64,
    /// Maximum number of trial steps.
    pub max_steps: usize,
    /// Optional bounds; when set, trial points are projected onto them before
    /// the objective is evaluated.
    pub bounds: Option<Bounds>,
}
664
665impl Default for BacktrackingArmijo {
666 fn default() -> Self {
667 Self {
668 c1: 1e-4,
669 rho: 0.5,
670 alpha_init: 1.0,
671 alpha_min: 1e-14,
672 max_steps: 60,
673 bounds: None,
674 }
675 }
676}
677
678impl BacktrackingArmijo {
679 pub fn new(c1: f64, rho: f64, alpha_init: f64, bounds: Option<Bounds>) -> Self {
681 Self {
682 c1,
683 rho,
684 alpha_init,
685 alpha_min: 1e-14,
686 max_steps: 60,
687 bounds,
688 }
689 }
690
691 pub fn search<F>(
702 &self,
703 fun: &mut F,
704 x: &ArrayView1<f64>,
705 d: &ArrayView1<f64>,
706 f0: f64,
707 slope: f64,
708 ) -> LineSearchResult
709 where
710 F: FnMut(&ArrayView1<f64>) -> f64,
711 {
712 let mut alpha = self.alpha_init;
713 let n = x.len();
714 let mut n_fev = 0usize;
715
716 if slope >= 0.0 {
718 return LineSearchResult {
719 alpha: 1e-14,
720 f_val: f0,
721 grad: None,
722 derphi: None,
723 n_fev: 0,
724 n_gev: 0,
725 success: false,
726 message: "Backtracking: non-descent direction".to_string(),
727 };
728 }
729
730 for _ in 0..self.max_steps {
731 let mut x_new = Array1::zeros(n);
732 for i in 0..n {
733 x_new[i] = x[i] + alpha * d[i];
734 }
735
736 if let Some(ref b) = self.bounds {
738 if let Some(s) = x_new.as_slice_mut() {
739 b.project(s);
740 }
741 }
742
743 n_fev += 1;
744 let f_new = fun(&x_new.view());
745
746 if f_new <= f0 + self.c1 * alpha * slope {
747 return LineSearchResult {
748 alpha,
749 f_val: f_new,
750 grad: None,
751 derphi: None,
752 n_fev,
753 n_gev: 0,
754 success: true,
755 message: "Armijo condition satisfied".to_string(),
756 };
757 }
758
759 alpha *= self.rho;
760 if alpha < self.alpha_min {
761 return LineSearchResult {
762 alpha: self.alpha_min,
763 f_val: f_new,
764 grad: None,
765 derphi: None,
766 n_fev,
767 n_gev: 0,
768 success: false,
769 message: "Backtracking: alpha below minimum".to_string(),
770 };
771 }
772 }
773
774 let mut x_last = Array1::zeros(n);
775 for i in 0..n {
776 x_last[i] = x[i] + alpha * d[i];
777 }
778 n_fev += 1;
779 let f_last = fun(&x_last.view());
780
781 LineSearchResult {
782 alpha,
783 f_val: f_last,
784 grad: None,
785 derphi: None,
786 n_fev,
787 n_gev: 0,
788 success: false,
789 message: "Backtracking: max steps reached".to_string(),
790 }
791 }
792
793 pub fn search_scalar<Phi>(&self, mut phi: Phi, phi0: f64, dphi0: f64) -> LineSearchResult
795 where
796 Phi: FnMut(f64) -> f64,
797 {
798 let mut alpha = self.alpha_init;
799 let mut n_fev = 0usize;
800
801 if dphi0 >= 0.0 {
802 return LineSearchResult {
803 alpha: 1e-14,
804 f_val: phi0,
805 grad: None,
806 derphi: None,
807 n_fev: 0,
808 n_gev: 0,
809 success: false,
810 message: "Backtracking: non-descent direction".to_string(),
811 };
812 }
813
814 for _ in 0..self.max_steps {
815 n_fev += 1;
816 let fa = phi(alpha);
817 if fa <= phi0 + self.c1 * alpha * dphi0 {
818 return LineSearchResult {
819 alpha,
820 f_val: fa,
821 grad: None,
822 derphi: None,
823 n_fev,
824 n_gev: 0,
825 success: true,
826 message: "Armijo condition satisfied".to_string(),
827 };
828 }
829 alpha *= self.rho;
830 if alpha < self.alpha_min {
831 break;
832 }
833 }
834
835 LineSearchResult {
836 alpha,
837 f_val: phi(alpha),
838 grad: None,
839 derphi: None,
840 n_fev: n_fev + 1,
841 n_gev: 0,
842 success: false,
843 message: "Backtracking: max steps reached".to_string(),
844 }
845 }
846}
847
#[cfg(test)]
mod tests {
    use super::*;

    /// Model problem `phi(alpha) = (1 - alpha)^2`, minimised at `alpha = 1`.
    fn phi_quadratic(alpha: f64) -> f64 {
        (1.0 - alpha).powi(2)
    }

    /// Derivative of [`phi_quadratic`].
    fn dphi_quadratic(alpha: f64) -> f64 {
        -2.0 * (1.0 - alpha)
    }

    /// Strong Wolfe search succeeds on the quadratic and lowers the objective.
    #[test]
    fn test_strong_wolfe_quadratic() {
        let searcher = StrongWolfe::default_search();
        let outcome = searcher
            .search(phi_quadratic, dphi_quadratic, 1.0, -2.0)
            .expect("StrongWolfe failed");
        assert!(outcome.success);
        assert!(outcome.alpha > 0.0);
        assert!(outcome.alpha <= 2.0);
        assert!(outcome.f_val < 1.0);
    }

    /// Hager-Zhang search returns a positive step that does not increase the objective.
    #[test]
    fn test_hager_zhang_quadratic() {
        let searcher = HagerZhang::default_search();
        let outcome = searcher
            .search(phi_quadratic, dphi_quadratic, 1.0, -2.0)
            .expect("HagerZhang failed");
        assert!(outcome.alpha > 0.0);
        assert!(outcome.f_val <= 1.0);
    }

    /// Armijo backtracking succeeds on the quadratic.
    #[test]
    fn test_backtracking_armijo_quadratic() {
        let searcher = BacktrackingArmijo::default();
        let outcome = searcher.search_scalar(phi_quadratic, 1.0, -2.0);
        assert!(outcome.success);
        assert!(outcome.f_val <= 1.0);
    }

    /// Safeguarded Powell search returns a positive step that lowers the objective.
    #[test]
    fn test_safeguarded_powell_quadratic() {
        let searcher = SafeguardedPowell::default();
        let outcome = searcher
            .search(phi_quadratic, dphi_quadratic, 1.0, -2.0)
            .expect("Powell failed");
        assert!(outcome.alpha > 0.0);
        assert!(outcome.f_val < 1.0);
    }

    /// A positive initial slope is reported as a failed backtracking search.
    #[test]
    fn test_backtracking_armijo_bad_direction() {
        let searcher = BacktrackingArmijo::default();
        let outcome = searcher.search_scalar(phi_quadratic, 1.0, 1.0);
        assert!(!outcome.success);
    }

    /// A positive initial slope makes the strong Wolfe search return an error.
    #[test]
    fn test_strong_wolfe_bad_direction() {
        let searcher = StrongWolfe::default_search();
        let attempt = searcher.search(phi_quadratic, dphi_quadratic, 1.0, 1.0);
        assert!(attempt.is_err());
    }
}