1use crate::error::OptimizeError;
11use crate::unconstrained::utils::clip_step;
12use crate::unconstrained::Bounds;
13use scirs2_core::ndarray::{Array1, ArrayView1};
14use std::collections::VecDeque;
15
/// Line-search strategy chosen through the `method` field of `AdvancedLineSearchOptions`.
#[derive(Debug, Clone, Copy)]
pub enum LineSearchMethod {
    /// Step expansion followed by an interpolation-based zoom; accepts weak-Wolfe steps.
    HagerZhang,
    /// Halving backtracking measured against the largest recently stored objective value.
    NonMonotone,
    /// Bracketing plus zoom for a strong-Wolfe step; requires a gradient function.
    EnhancedStrongWolfe,
    /// Currently an alias: dispatched to the same routine as `EnhancedStrongWolfe`.
    MoreThuente,
    /// Picks `HagerZhang` or `EnhancedStrongWolfe` (and adjusts `c1`/`c2`) from the
    /// L1 norm of the initial gradient.
    Adaptive,
}
30
31#[derive(Debug, Clone)]
33pub struct AdvancedLineSearchOptions {
34 pub method: LineSearchMethod,
36 pub c1: f64,
38 pub c2: f64,
40 pub max_ls_iter: usize,
42 pub alpha_min: f64,
44 pub alpha_max: f64,
46 pub alpha_init: f64,
48 pub step_tol: f64,
50 pub nm_memory: usize,
52 pub enable_adaptation: bool,
54 pub interpolation: InterpolationStrategy,
56}
57
/// Model used to propose the next trial step inside a line search.
#[derive(Debug, Clone, Copy)]
pub enum InterpolationStrategy {
    /// No model: the bracket midpoint when shrinking, the doubled step when extrapolating.
    Linear,
    /// Quadratic model of `phi`.
    Quadratic,
    /// Cubic model of `phi`.
    Cubic,
    /// Currently handled exactly like `Linear` by the routines in this module.
    Adaptive,
}
70
71impl Default for AdvancedLineSearchOptions {
72 fn default() -> Self {
73 Self {
74 method: LineSearchMethod::HagerZhang,
75 c1: 1e-4,
76 c2: 0.9,
77 max_ls_iter: 20,
78 alpha_min: 1e-16,
79 alpha_max: 1e6,
80 alpha_init: 1.0,
81 step_tol: 1e-12,
82 nm_memory: 10,
83 enable_adaptation: true,
84 interpolation: InterpolationStrategy::Cubic,
85 }
86 }
87}
88
89#[derive(Debug, Clone)]
91pub struct LineSearchResult {
92 pub alpha: f64,
94 pub f_new: f64,
96 pub grad_new: Option<Array1<f64>>,
98 pub n_fev: usize,
100 pub n_gev: usize,
102 pub success: bool,
104 pub message: String,
106 pub stats: LineSearchStats,
108}
109
110#[derive(Debug, Clone)]
112pub struct LineSearchStats {
113 pub n_bracket: usize,
115 pub n_zoom: usize,
117 pub final_width: f64,
119 pub max_f_eval: f64,
121 pub interpolation_used: InterpolationStrategy,
123}
124
/// Sliding window of recently accepted objective values for the non-monotone line search.
///
/// At most `max_memory` values are kept; the oldest one is discarded first.
#[derive(Debug, Clone)]
pub struct NonMonotoneState {
    /// Accepted objective values, oldest at the front.
    f_history: VecDeque<f64>,
    /// Maximum number of values retained in `f_history`.
    max_memory: usize,
}

impl NonMonotoneState {
    /// Creates an empty history that retains at most `max_memory` values.
    fn new(max_memory: usize) -> Self {
        Self {
            f_history: VecDeque::new(),
            max_memory,
        }
    }

    /// Appends `f_new`, then drops the oldest values until at most `max_memory` remain
    /// (with `max_memory == 0` the history therefore stays empty).
    fn update(&mut self, f_new: f64) {
        self.f_history.push_back(f_new);
        while self.f_history.len() > self.max_memory {
            self.f_history.pop_front();
        }
    }

    /// Largest value in the history.
    ///
    /// Returns `f64::NEG_INFINITY` for an empty history, so callers must combine it with the
    /// current objective value before using it as an acceptance reference. NaN entries are
    /// skipped by `f64::max`.
    fn get_reference_value(&self) -> f64 {
        self.f_history
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max)
    }
}
154
155#[allow(dead_code)]
157pub fn advanced_line_search<F, G>(
158 fun: &mut F,
159 grad_fun: Option<&mut G>,
160 x: &ArrayView1<f64>,
161 f0: f64,
162 direction: &ArrayView1<f64>,
163 grad0: &ArrayView1<f64>,
164 options: &AdvancedLineSearchOptions,
165 bounds: Option<&Bounds>,
166 nm_state: Option<&mut NonMonotoneState>,
167) -> Result<LineSearchResult, OptimizeError>
168where
169 F: FnMut(&ArrayView1<f64>) -> f64,
170 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
171{
172 match options.method {
173 LineSearchMethod::HagerZhang => {
174 hager_zhang_line_search(fun, grad_fun, x, f0, direction, grad0, options, bounds)
175 }
176 LineSearchMethod::NonMonotone => non_monotone_line_search(
177 fun, grad_fun, x, f0, direction, grad0, options, bounds, nm_state,
178 ),
179 LineSearchMethod::EnhancedStrongWolfe => {
180 enhanced_strong_wolfe(fun, grad_fun, x, f0, direction, grad0, options, bounds)
181 }
182 LineSearchMethod::MoreThuente => {
183 more_thuente_line_search(fun, grad_fun, x, f0, direction, grad0, options, bounds)
184 }
185 LineSearchMethod::Adaptive => {
186 adaptive_line_search(fun, grad_fun, x, f0, direction, grad0, options, bounds)
187 }
188 }
189}
190
191#[allow(dead_code)]
194fn hager_zhang_line_search<F, G>(
195 fun: &mut F,
196 mut grad_fun: Option<&mut G>,
197 x: &ArrayView1<f64>,
198 f0: f64,
199 direction: &ArrayView1<f64>,
200 grad0: &ArrayView1<f64>,
201 options: &AdvancedLineSearchOptions,
202 bounds: Option<&Bounds>,
203) -> Result<LineSearchResult, OptimizeError>
204where
205 F: FnMut(&ArrayView1<f64>) -> f64,
206 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
207{
208 let mut stats = LineSearchStats {
209 n_bracket: 0,
210 n_zoom: 0,
211 final_width: 0.0,
212 max_f_eval: f0,
213 interpolation_used: options.interpolation,
214 };
215
216 let mut n_fev = 0;
217 let mut n_gev = 0;
218
219 let dphi0 = grad0.dot(direction);
220 if dphi0 >= 0.0 {
221 return Err(OptimizeError::ValueError(
222 "Search direction is not a descent direction".to_string(),
223 ));
224 }
225
226 let _epsilon = 1e-6;
228 let _theta = 0.5;
229 let _gamma = 0.66;
230 let sigma = 0.1;
231
232 let mut alpha = options.alpha_init;
233 if let Some(bounds) = bounds {
234 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
235 }
236
237 let mut alpha_lo = 0.0;
239 let mut alpha_hi = alpha;
240 let mut phi_lo = f0;
241 let mut dphi_lo = dphi0;
242
243 let x_new = x + alpha * direction;
245 let phi = fun(&x_new.view());
246 n_fev += 1;
247 stats.max_f_eval = stats.max_f_eval.max(phi);
248
249 #[allow(clippy::explicit_counter_loop)]
251 for i in 0..options.max_ls_iter {
252 if phi <= f0 + options.c1 * alpha * dphi0 {
254 if let Some(ref mut grad_fun) = grad_fun {
256 let grad_new = grad_fun(&x_new.view());
257 n_gev += 1;
258 let dphi = grad_new.dot(direction);
259
260 if dphi >= options.c2 * dphi0 {
262 return Ok(LineSearchResult {
263 alpha,
264 f_new: phi,
265 grad_new: Some(grad_new),
266 n_fev,
267 n_gev,
268 success: true,
269 message: format!("Hager-Zhang converged in {} iterations", i + 1),
270 stats,
271 });
272 }
273
274 alpha_lo = alpha;
276 phi_lo = phi;
277 dphi_lo = dphi;
278
279 if dphi >= 0.0 {
281 alpha_hi = alpha;
282 break;
283 }
284
285 let alpha_new = if i == 0 {
287 alpha * (1.0 + 4.0 * (phi - f0 - alpha * dphi0) / (alpha * dphi0))
289 } else {
290 alpha + (alpha - alpha_lo) * dphi / (dphi_lo - dphi)
292 };
293
294 alpha = alpha_new
295 .max(alpha + sigma * (alpha_hi - alpha))
296 .min(alpha_hi);
297
298 if let Some(bounds) = bounds {
299 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
300 }
301 } else {
302 alpha_lo = alpha;
304 phi_lo = phi;
305 alpha *= 2.0;
306 if let Some(bounds) = bounds {
307 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
308 }
309 }
310 } else {
311 alpha_hi = alpha;
313 break;
314 }
315
316 if alpha >= options.alpha_max {
317 let _alpha_max = options.alpha_max; break;
319 }
320
321 let x_new = x + alpha * direction;
323 let phi = fun(&x_new.view());
324 n_fev += 1;
325 stats.max_f_eval = stats.max_f_eval.max(phi);
326 }
327
328 stats.n_bracket = 1;
329
330 for zoom_iter in 0..options.max_ls_iter {
332 stats.n_zoom += 1;
333
334 if (alpha_hi - alpha_lo).abs() < options.step_tol {
335 break;
336 }
337
338 alpha = interpolate_hager_zhang(alpha_lo, alpha_hi, phi_lo, f0, dphi0, dphi_lo, options);
340
341 if let Some(bounds) = bounds {
342 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
343 }
344
345 let x_new = x + alpha * direction;
346 let phi = fun(&x_new.view());
347 n_fev += 1;
348 stats.max_f_eval = stats.max_f_eval.max(phi);
349
350 if phi <= f0 + options.c1 * alpha * dphi0 {
352 if let Some(ref mut grad_fun) = grad_fun {
353 let grad_new = grad_fun(&x_new.view());
354 n_gev += 1;
355 let dphi = grad_new.dot(direction);
356
357 if dphi >= options.c2 * dphi0 {
359 stats.final_width = (alpha_hi - alpha_lo).abs();
360 return Ok(LineSearchResult {
361 alpha,
362 f_new: phi,
363 grad_new: Some(grad_new),
364 n_fev,
365 n_gev,
366 success: true,
367 message: format!(
368 "Hager-Zhang zoom converged in {} iterations",
369 zoom_iter + 1
370 ),
371 stats,
372 });
373 }
374
375 if dphi >= 0.0 {
377 alpha_hi = alpha;
378 } else {
379 alpha_lo = alpha;
380 phi_lo = phi;
381 dphi_lo = dphi;
382 }
383 } else {
384 return Ok(LineSearchResult {
386 alpha,
387 f_new: phi,
388 grad_new: None,
389 n_fev,
390 n_gev,
391 success: true,
392 message: "Hager-Zhang converged (no gradient)".to_string(),
393 stats,
394 });
395 }
396 } else {
397 alpha_hi = alpha;
398 }
399 }
400
401 stats.final_width = (alpha_hi - alpha_lo).abs();
403 let x_final = x + alpha_lo * direction;
404 let f_final = fun(&x_final.view());
405 n_fev += 1;
406
407 Ok(LineSearchResult {
408 alpha: alpha_lo,
409 f_new: f_final,
410 grad_new: None,
411 n_fev,
412 n_gev,
413 success: false,
414 message: "Hager-Zhang reached maximum iterations".to_string(),
415 stats,
416 })
417}
418
419#[allow(dead_code)]
421fn interpolate_hager_zhang(
422 alpha_lo: f64,
423 alpha_hi: f64,
424 phi_lo: f64,
425 phi0: f64,
426 dphi0: f64,
427 dphi_lo: f64,
428 options: &AdvancedLineSearchOptions,
429) -> f64 {
430 match options.interpolation {
431 InterpolationStrategy::Cubic => {
432 let d1 = dphi_lo + dphi0 - 3.0 * (phi0 - phi_lo) / (alpha_lo - 0.0);
434 let d2_sign = if d1 * d1 - dphi_lo * dphi0 >= 0.0 {
435 1.0
436 } else {
437 -1.0
438 };
439 let d2 = d2_sign * (d1 * d1 - dphi_lo * dphi0).abs().sqrt();
440
441 let alpha_c =
442 alpha_lo - (alpha_lo - 0.0) * (dphi_lo + d2 - d1) / (dphi_lo - dphi0 + 2.0 * d2);
443
444 alpha_c
446 .max(alpha_lo + 0.01 * (alpha_hi - alpha_lo))
447 .min(alpha_hi - 0.01 * (alpha_hi - alpha_lo))
448 }
449 InterpolationStrategy::Quadratic => {
450 let alpha_q = alpha_lo
452 - 0.5 * dphi_lo * (alpha_lo * alpha_lo) / (phi_lo - phi0 - dphi0 * alpha_lo);
453 alpha_q
454 .max(alpha_lo + 0.01 * (alpha_hi - alpha_lo))
455 .min(alpha_hi - 0.01 * (alpha_hi - alpha_lo))
456 }
457 _ => {
458 0.5 * (alpha_lo + alpha_hi)
460 }
461 }
462}
463
464#[allow(dead_code)]
466fn non_monotone_line_search<F, G>(
467 fun: &mut F,
468 mut grad_fun: Option<&mut G>,
469 x: &ArrayView1<f64>,
470 f0: f64,
471 direction: &ArrayView1<f64>,
472 grad0: &ArrayView1<f64>,
473 options: &AdvancedLineSearchOptions,
474 bounds: Option<&Bounds>,
475 nm_state: Option<&mut NonMonotoneState>,
476) -> Result<LineSearchResult, OptimizeError>
477where
478 F: FnMut(&ArrayView1<f64>) -> f64,
479 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
480{
481 let mut stats = LineSearchStats {
482 n_bracket: 0,
483 n_zoom: 0,
484 final_width: 0.0,
485 max_f_eval: f0,
486 interpolation_used: options.interpolation,
487 };
488
489 let mut n_fev = 0;
490 let mut n_gev = 0;
491
492 let dphi0 = grad0.dot(direction);
493 if dphi0 >= 0.0 {
494 return Err(OptimizeError::ValueError(
495 "Search direction is not a descent direction".to_string(),
496 ));
497 }
498
499 let f_ref = if let Some(ref nm_state_ref) = nm_state {
501 nm_state_ref.get_reference_value()
502 } else {
503 f0
504 };
505
506 let mut alpha = options.alpha_init;
507 if let Some(bounds) = bounds {
508 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
509 }
510
511 #[allow(clippy::explicit_counter_loop)]
513 for i in 0..options.max_ls_iter {
514 let x_new = x + alpha * direction;
515 let phi = fun(&x_new.view());
516 n_fev += 1;
517 stats.max_f_eval = stats.max_f_eval.max(phi);
518
519 if phi <= f_ref + options.c1 * alpha * dphi0 {
521 if let Some(ref mut grad_fun) = grad_fun {
523 let grad_new = grad_fun(&x_new.view());
524 n_gev += 1;
525 let dphi = grad_new.dot(direction);
526
527 if dphi >= options.c2 * dphi0 {
528 if let Some(nm_state) = nm_state {
530 nm_state.update(phi);
531 }
532
533 return Ok(LineSearchResult {
534 alpha,
535 f_new: phi,
536 grad_new: Some(grad_new),
537 n_fev,
538 n_gev,
539 success: true,
540 message: format!("Non-monotone converged in {} iterations", i + 1),
541 stats,
542 });
543 }
544 } else {
545 if let Some(nm_state) = nm_state {
547 nm_state.update(phi);
548 }
549
550 return Ok(LineSearchResult {
551 alpha,
552 f_new: phi,
553 grad_new: None,
554 n_fev,
555 n_gev,
556 success: true,
557 message: format!(
558 "Non-monotone converged in {} iterations (no gradient)",
559 i + 1
560 ),
561 stats,
562 });
563 }
564 }
565
566 alpha *= 0.5;
568
569 if alpha < options.alpha_min {
570 break;
571 }
572 }
573
574 Err(OptimizeError::ComputationError(
575 "Non-monotone line search failed to find acceptable step".to_string(),
576 ))
577}
578
579#[allow(dead_code)]
581fn enhanced_strong_wolfe<F, G>(
582 fun: &mut F,
583 grad_fun: Option<&mut G>,
584 x: &ArrayView1<f64>,
585 f0: f64,
586 direction: &ArrayView1<f64>,
587 grad0: &ArrayView1<f64>,
588 options: &AdvancedLineSearchOptions,
589 bounds: Option<&Bounds>,
590) -> Result<LineSearchResult, OptimizeError>
591where
592 F: FnMut(&ArrayView1<f64>) -> f64,
593 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
594{
595 let mut stats = LineSearchStats {
596 n_bracket: 0,
597 n_zoom: 0,
598 final_width: 0.0,
599 max_f_eval: f0,
600 interpolation_used: options.interpolation,
601 };
602
603 let mut n_fev = 0;
604 let mut n_gev = 0;
605
606 let dphi0 = grad0.dot(direction);
607 if dphi0 >= 0.0 {
608 return Err(OptimizeError::ValueError(
609 "Search direction is not a descent direction".to_string(),
610 ));
611 }
612
613 if grad_fun.is_none() {
614 return Err(OptimizeError::ValueError(
615 "Enhanced Strong Wolfe requires gradient function".to_string(),
616 ));
617 }
618
619 let grad_fun = grad_fun.unwrap();
620
621 let mut alpha = options.alpha_init;
622 if let Some(bounds) = bounds {
623 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
624 }
625
626 let mut alpha_prev = 0.0;
627 let mut phi_prev = f0;
628 let alpha_max = options.alpha_max;
629
630 #[allow(clippy::explicit_counter_loop)]
632 for i in 0..options.max_ls_iter {
633 let x_new = x + alpha * direction;
634 let phi = fun(&x_new.view());
635 n_fev += 1;
636 stats.max_f_eval = stats.max_f_eval.max(phi);
637
638 if phi > f0 + options.c1 * alpha * dphi0 || (phi >= phi_prev && i > 0) {
640 stats.n_bracket = 1;
642 return enhanced_zoom(
643 fun, grad_fun, x, direction, alpha_prev, alpha, phi_prev, phi, f0, dphi0, options,
644 &mut stats, &mut n_fev, &mut n_gev,
645 );
646 }
647
648 let grad_new = grad_fun(&x_new.view());
649 n_gev += 1;
650 let dphi = grad_new.dot(direction);
651
652 if dphi.abs() <= -options.c2 * dphi0 {
654 stats.final_width = 0.0;
655 return Ok(LineSearchResult {
656 alpha,
657 f_new: phi,
658 grad_new: Some(grad_new),
659 n_fev,
660 n_gev,
661 success: true,
662 message: format!("Enhanced Strong Wolfe converged in {} iterations", i + 1),
663 stats,
664 });
665 }
666
667 if dphi >= 0.0 {
668 stats.n_bracket = 1;
670 return enhanced_zoom(
671 fun, grad_fun, x, direction, alpha, alpha_prev, phi, phi_prev, f0, dphi0, options,
672 &mut stats, &mut n_fev, &mut n_gev,
673 );
674 }
675
676 alpha_prev = alpha;
678 phi_prev = phi;
679
680 alpha = match options.interpolation {
681 InterpolationStrategy::Cubic => {
682 interpolate_cubic(alpha, alpha_max, phi, f0, dphi, dphi0)
683 }
684 InterpolationStrategy::Quadratic => interpolate_quadratic(alpha, phi, f0, dphi, dphi0),
685 _ => 2.0 * alpha,
686 };
687
688 alpha = alpha.min(alpha_max);
689
690 if let Some(bounds) = bounds {
691 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
692 }
693 }
694
695 Err(OptimizeError::ComputationError(
696 "Enhanced Strong Wolfe failed to converge".to_string(),
697 ))
698}
699
700#[allow(dead_code)]
702fn enhanced_zoom<F, G>(
703 fun: &mut F,
704 grad_fun: &mut G,
705 x: &ArrayView1<f64>,
706 direction: &ArrayView1<f64>,
707 mut alpha_lo: f64,
708 mut alpha_hi: f64,
709 mut phi_lo: f64,
710 mut phi_hi: f64,
711 phi0: f64,
712 dphi0: f64,
713 options: &AdvancedLineSearchOptions,
714 stats: &mut LineSearchStats,
715 n_fev: &mut usize,
716 n_gev: &mut usize,
717) -> Result<LineSearchResult, OptimizeError>
718where
719 F: FnMut(&ArrayView1<f64>) -> f64,
720 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
721{
722 for zoom_iter in 0..options.max_ls_iter {
723 stats.n_zoom += 1;
724
725 if (alpha_hi - alpha_lo).abs() < options.step_tol {
726 break;
727 }
728
729 let alpha = match options.interpolation {
731 InterpolationStrategy::Cubic => {
732 interpolate_cubic_zoom(alpha_lo, alpha_hi, phi_lo, phi_hi, phi0, dphi0)
733 }
734 InterpolationStrategy::Quadratic => {
735 interpolate_quadratic_zoom(alpha_lo, alpha_hi, phi_lo, phi_hi)
736 }
737 _ => 0.5 * (alpha_lo + alpha_hi),
738 };
739
740 let x_new = x + alpha * direction;
741 let phi = fun(&x_new.view());
742 *n_fev += 1;
743 stats.max_f_eval = stats.max_f_eval.max(phi);
744
745 if phi > phi0 + options.c1 * alpha * dphi0 || phi >= phi_lo {
746 alpha_hi = alpha;
747 phi_hi = phi;
748 } else {
749 let grad_new = grad_fun(&x_new.view());
750 *n_gev += 1;
751 let dphi = grad_new.dot(direction);
752
753 if dphi.abs() <= -options.c2 * dphi0 {
754 stats.final_width = (alpha_hi - alpha_lo).abs();
755 return Ok(LineSearchResult {
756 alpha,
757 f_new: phi,
758 grad_new: Some(grad_new),
759 n_fev: *n_fev,
760 n_gev: *n_gev,
761 success: true,
762 message: format!("Enhanced zoom converged in {} iterations", zoom_iter + 1),
763 stats: stats.clone(),
764 });
765 }
766
767 if dphi * (alpha_hi - alpha_lo) >= 0.0 {
768 alpha_hi = alpha_lo;
769 phi_hi = phi_lo;
770 }
771
772 alpha_lo = alpha;
773 phi_lo = phi;
774 }
775 }
776
777 Err(OptimizeError::ComputationError(
778 "Enhanced zoom failed to converge".to_string(),
779 ))
780}
781
782#[allow(dead_code)]
784fn more_thuente_line_search<F, G>(
785 fun: &mut F,
786 grad_fun: Option<&mut G>,
787 x: &ArrayView1<f64>,
788 f0: f64,
789 direction: &ArrayView1<f64>,
790 grad0: &ArrayView1<f64>,
791 options: &AdvancedLineSearchOptions,
792 bounds: Option<&Bounds>,
793) -> Result<LineSearchResult, OptimizeError>
794where
795 F: FnMut(&ArrayView1<f64>) -> f64,
796 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
797{
798 enhanced_strong_wolfe(fun, grad_fun, x, f0, direction, grad0, options, bounds)
801}
802
803#[allow(dead_code)]
805fn adaptive_line_search<F, G>(
806 fun: &mut F,
807 grad_fun: Option<&mut G>,
808 x: &ArrayView1<f64>,
809 f0: f64,
810 direction: &ArrayView1<f64>,
811 grad0: &ArrayView1<f64>,
812 options: &AdvancedLineSearchOptions,
813 bounds: Option<&Bounds>,
814) -> Result<LineSearchResult, OptimizeError>
815where
816 F: FnMut(&ArrayView1<f64>) -> f64,
817 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
818{
819 let grad_norm = grad0.mapv(|x| x.abs()).sum();
821 let _direction_norm = direction.mapv(|x| x.abs()).sum();
822
823 let mut adaptive_options = options.clone();
825
826 if grad_norm > 1e2 {
827 adaptive_options.c1 = 1e-3;
829 adaptive_options.c2 = 0.1;
830 adaptive_options.method = LineSearchMethod::HagerZhang;
831 } else if grad_norm < 1e-3 {
832 adaptive_options.c1 = 1e-5;
834 adaptive_options.c2 = 0.9;
835 adaptive_options.method = LineSearchMethod::EnhancedStrongWolfe;
836 } else {
837 adaptive_options.method = LineSearchMethod::HagerZhang;
839 }
840
841 match adaptive_options.method {
843 LineSearchMethod::HagerZhang => hager_zhang_line_search(
844 fun,
845 grad_fun,
846 x,
847 f0,
848 direction,
849 grad0,
850 &adaptive_options,
851 bounds,
852 ),
853 LineSearchMethod::EnhancedStrongWolfe => enhanced_strong_wolfe(
854 fun,
855 grad_fun,
856 x,
857 f0,
858 direction,
859 grad0,
860 &adaptive_options,
861 bounds,
862 ),
863 _ => {
864 hager_zhang_line_search(
866 fun,
867 grad_fun,
868 x,
869 f0,
870 direction,
871 grad0,
872 &adaptive_options,
873 bounds,
874 )
875 }
876 }
877}
878
/// Extrapolation step for the expansion phase of the strong-Wolfe search.
///
/// Fits the cubic through `(0, phi0)` with slope `dphi0` and through `(alpha, phi)` with slope
/// `dphi` (Nocedal & Wright, eq. 3.59), and returns its minimizer. The result is raised to at
/// least `1.1 * alpha` and then capped at `0.9 * alpha_max` (the cap is applied last). Falls
/// back to `2 * alpha` when the cubic has no real stationary point.
#[allow(dead_code)]
fn interpolate_cubic(
    alpha: f64,
    alpha_max: f64,
    phi: f64,
    phi0: f64,
    dphi: f64,
    dphi0: f64,
) -> f64 {
    // Eq. 3.59 with alpha_{i-1} = 0: the finite difference is (phi0 - phi) / (0 - alpha).
    let d1 = dphi + dphi0 + 3.0 * (phi0 - phi) / alpha;
    let d2_term = d1 * d1 - dphi * dphi0;
    if d2_term >= 0.0 {
        let d2 = d2_term.sqrt();
        let alpha_c = alpha * (1.0 - (dphi + d2 - d1) / (dphi - dphi0 + 2.0 * d2));
        // A NaN `alpha_c` is replaced by 1.1 * alpha through `f64::max`.
        alpha_c.max(1.1 * alpha).min(0.9 * alpha_max)
    } else {
        2.0 * alpha
    }
}
900
/// Extrapolation step from the quadratic through `(0, phi0)` with slope `dphi0` and through
/// `(alpha, phi)` (Nocedal & Wright, eq. 3.57). The result is raised to at least `1.1 * alpha`.
///
/// When that quadratic is not convex (zero, negative or NaN curvature), it has no finite
/// minimizer, and the step is doubled instead of jumping to infinity or creeping by 10%.
/// `dphi` is accepted for signature parity with `interpolate_cubic` but is not part of the
/// quadratic model.
#[allow(dead_code)]
fn interpolate_quadratic(alpha: f64, phi: f64, phi0: f64, dphi: f64, dphi0: f64) -> f64 {
    let _ = dphi;
    // Deviation of phi(alpha) from the tangent line at 0; positive means convex.
    let excess = phi - phi0 - dphi0 * alpha;
    if excess > 0.0 {
        let alpha_q = -dphi0 * alpha * alpha / (2.0 * excess);
        alpha_q.max(1.1 * alpha)
    } else {
        2.0 * alpha
    }
}
906
/// Trial step for the strong-Wolfe zoom when cubic interpolation is requested.
///
/// Fits the cubic `a t^3 + b t^2 + dphi0 t + phi0` through `phi(alpha_lo) = phi_lo` and
/// `phi(alpha_hi) = phi_hi` (Nocedal & Wright, eq. 3.58), and returns its local minimizer.
/// When one end of the bracket is `0`, the cubic is not determined. In that case the minimizer
/// of the quadratic through `phi0`, `dphi0` and the other end (eq. 3.57) is used instead.
///
/// The ends may be given in either order. The result is kept at least 1% of the bracket width
/// away from both ends. The midpoint is returned when the model has no usable minimizer.
#[allow(dead_code)]
fn interpolate_cubic_zoom(
    alpha_lo: f64,
    alpha_hi: f64,
    phi_lo: f64,
    phi_hi: f64,
    phi0: f64,
    dphi0: f64,
) -> f64 {
    let midpoint = 0.5 * (alpha_lo + alpha_hi);
    let left = alpha_lo.min(alpha_hi);
    let right = alpha_lo.max(alpha_hi);
    let width = right - left;
    if width.is_nan() || width <= 0.0 {
        return midpoint;
    }

    // Deviation of phi from its tangent line at 0, at each end of the bracket.
    let r_lo = phi_lo - phi0 - dphi0 * alpha_lo;
    let r_hi = phi_hi - phi0 - dphi0 * alpha_hi;

    let trial = if alpha_lo == 0.0 || alpha_hi == 0.0 {
        // Only three conditions are available: use the quadratic of eq. 3.57.
        let (a1, r1) = if alpha_lo == 0.0 {
            (alpha_hi, r_hi)
        } else {
            (alpha_lo, r_lo)
        };
        if r1 > 0.0 {
            -dphi0 * a1 * a1 / (2.0 * r1)
        } else {
            return midpoint;
        }
    } else {
        let denom = alpha_lo * alpha_lo * alpha_hi * alpha_hi * (alpha_hi - alpha_lo);
        let a = (alpha_lo * alpha_lo * r_hi - alpha_hi * alpha_hi * r_lo) / denom;
        let b = (alpha_hi.powi(3) * r_lo - alpha_lo.powi(3) * r_hi) / denom;
        let disc = b * b - 3.0 * a * dphi0;
        if disc < 0.0 {
            return midpoint;
        }
        // Rationalised form of (-b + sqrt(disc)) / (3a). It stays accurate as a -> 0, where it
        // reduces to the quadratic minimizer -dphi0 / (2b).
        let s = b + disc.sqrt();
        if s <= 0.0 {
            // No minimizer at a positive step.
            return midpoint;
        }
        -dphi0 / s
    };

    if trial.is_finite() {
        trial.max(left + 0.01 * width).min(right - 0.01 * width)
    } else {
        midpoint
    }
}
931
/// Trial step for the strong-Wolfe zoom when quadratic interpolation is requested.
///
/// Two function values alone do not determine a quadratic; a slope or a third value would be
/// needed. This therefore bisects the bracket, and the result is the same whichever order the
/// ends are given in. `phi_lo` and `phi_hi` are accepted for signature parity with
/// `interpolate_cubic_zoom` but do not influence the result.
#[allow(dead_code)]
fn interpolate_quadratic_zoom(alpha_lo: f64, alpha_hi: f64, phi_lo: f64, phi_hi: f64) -> f64 {
    let _ = (phi_lo, phi_hi);
    0.5 * (alpha_lo + alpha_hi)
}
944
945#[allow(dead_code)]
947pub fn create_non_monotone_state(_memory_size: usize) -> NonMonotoneState {
948 NonMonotoneState::new(_memory_size)
949}
950
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// f(x, y) = x² + y² from (1, 1) along (-1, -1): the unit step lands on the minimizer,
    /// so the search must succeed with a positive step and a lower objective value.
    #[test]
    fn test_hager_zhang_line_search() {
        let mut objective = |p: &ArrayView1<f64>| -> f64 { p[0] * p[0] + p[1] * p[1] };
        let mut gradient =
            |p: &ArrayView1<f64>| -> Array1<f64> { Array1::from_vec(vec![2.0 * p[0], 2.0 * p[1]]) };

        let start = Array1::from_vec(vec![1.0, 1.0]);
        let descent = Array1::from_vec(vec![-1.0, -1.0]);
        let f_start = objective(&start.view());
        let g_start = gradient(&start.view());
        let settings = AdvancedLineSearchOptions::default();

        let outcome = hager_zhang_line_search(
            &mut objective,
            Some(&mut gradient),
            &start.view(),
            f_start,
            &descent.view(),
            &g_start.view(),
            &settings,
            None,
        )
        .unwrap();

        assert!(outcome.success);
        assert!(outcome.alpha > 0.0);
        assert!(outcome.f_new < f_start);
    }

    /// The reference value is the maximum over a 3-value sliding window.
    #[test]
    fn test_non_monotone_state() {
        let mut window = NonMonotoneState::new(3);

        // (value pushed, expected reference afterwards — `None` means "not checked").
        let script: [(f64, Option<f64>); 5] = [
            (10.0, Some(10.0)),
            (5.0, Some(10.0)),
            (15.0, Some(15.0)),
            (8.0, None),
            (12.0, Some(15.0)),
        ];

        for &(pushed, expected) in script.iter() {
            window.update(pushed);
            if let Some(reference) = expected {
                assert_abs_diff_eq!(window.get_reference_value(), reference, epsilon = 1e-10);
            }
        }
    }

    /// Both extrapolation helpers must propose a step beyond the current one.
    #[test]
    fn test_interpolation_methods() {
        let cubic_step = interpolate_cubic(1.0, 10.0, 5.0, 10.0, -2.0, -5.0);
        assert!(cubic_step > 1.0);

        let quadratic_step = interpolate_quadratic(1.0, 5.0, 10.0, -2.0, -5.0);
        assert!(quadratic_step > 1.0);
    }
}