1use crate::error::OptimizeError;
8use crate::unconstrained::utils::clip_step;
9use crate::unconstrained::Bounds;
10use scirs2_core::ndarray::{Array1, ArrayView1};
11
/// Return value of [`zoom_search`]:
/// `((alpha, f(x + alpha * direction), gradient at x + alpha * direction),
///   function evaluations, gradient evaluations)`.
type ZoomSearchResult = ((f64, f64, Array1<f64>), usize, usize);
14
/// Tuning parameters for [`strong_wolfe_line_search`].
///
/// The search rejects parameter sets that do not satisfy `0 < c1 < c2 < 1`.
#[derive(Debug, Clone)]
pub struct StrongWolfeOptions {
    /// Sufficient-decrease (Armijo) constant: a trial step `alpha` fails when
    /// `f(x + alpha * d) > f0 + c1 * alpha * (g0 · d)`.
    pub c1: f64,
    /// Curvature constant: a trial step is accepted when the directional derivative
    /// satisfies `|g(x + alpha * d) · d| <= c2 * |g0 · d|`.
    pub c2: f64,
    /// First trial step length, before clipping to the bounds and to
    /// `[min_step, max_step]`.
    pub initial_step: f64,
    /// Upper limit on trial step lengths.
    pub max_step: f64,
    /// Lower limit applied to trial step lengths during bracketing; also the `alpha`
    /// reported when bracketing fails.
    pub min_step: f64,
    /// Iteration cap, applied separately to the bracketing and the zoom phase
    /// (each iteration evaluates the objective once).
    pub max_fev: usize,
    /// Stagnation threshold: bracketing stops when consecutive trial steps differ by
    /// less than this, and zoom stops when the bracket is narrower than this.
    pub tolerance: f64,
    /// If `true`, the zoom phase picks trial steps by safeguarded interpolation;
    /// otherwise it bisects the bracket.
    pub use_safeguarded_interpolation: bool,
    /// If `true`, bracketing grows the step by an adaptive factor (at most 3x) after
    /// the first trial; otherwise it always doubles the step.
    pub use_extrapolation: bool,
}
37
38impl Default for StrongWolfeOptions {
39 fn default() -> Self {
40 Self {
41 c1: 1e-4,
42 c2: 0.9,
43 initial_step: 1.0,
44 max_step: 1e10,
45 min_step: 1e-12,
46 max_fev: 100,
47 tolerance: 1e-10,
48 use_safeguarded_interpolation: true,
49 use_extrapolation: true,
50 }
51 }
52}
53
/// Outcome of [`strong_wolfe_line_search`].
#[derive(Debug, Clone)]
pub struct StrongWolfeResult {
    /// Step length chosen along the search direction (`options.min_step` when the
    /// bracketing phase fails).
    pub alpha: f64,
    /// Objective value reported for the step. When bracketing fails this is the
    /// starting value `f0`, not the value at `x + alpha * direction`.
    pub f_new: f64,
    /// Gradient reported for the step (the starting gradient `g0` when bracketing
    /// fails).
    pub g_new: Array1<f64>,
    /// Number of objective evaluations performed.
    pub nfev: usize,
    /// Number of gradient evaluations performed.
    pub ngev: usize,
    /// Whether the search reported success; `message` says which phase produced
    /// the step.
    pub success: bool,
    /// Human-readable description of how the search terminated.
    pub message: String,
}
72
73#[allow(clippy::too_many_arguments)]
75#[allow(dead_code)]
76pub fn strong_wolfe_line_search<F, G, S>(
77 fun: &mut F,
78 grad_fun: &mut G,
79 x: &ArrayView1<f64>,
80 f0: f64,
81 g0: &ArrayView1<f64>,
82 direction: &ArrayView1<f64>,
83 options: &StrongWolfeOptions,
84 bounds: Option<&Bounds>,
85) -> Result<StrongWolfeResult, OptimizeError>
86where
87 F: FnMut(&ArrayView1<f64>) -> S,
88 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
89 S: Into<f64>,
90{
91 let derphi0 = g0.dot(direction);
93 if derphi0 >= 0.0 {
94 return Err(OptimizeError::ValueError(
95 "Search direction must be a descent direction".to_string(),
96 ));
97 }
98
99 if options.c1 <= 0.0 || options.c1 >= options.c2 || options.c2 >= 1.0 {
100 return Err(OptimizeError::ValueError(
101 "Invalid Wolfe parameters: must have 0 < c1 < c2 < 1".to_string(),
102 ));
103 }
104
105 let mut alpha = options.initial_step;
106 let mut nfev = 0;
107 let mut ngev = 0;
108
109 if let Some(bounds) = bounds {
111 alpha = alpha.min(clip_step(x, direction, alpha, &bounds.lower, &bounds.upper));
112 }
113 alpha = alpha.min(options.max_step).max(options.min_step);
114
115 let (interval_result, fev1, gev1) = find_interval(
117 fun, grad_fun, x, f0, derphi0, direction, alpha, options, bounds,
118 )?;
119
120 nfev += fev1;
121 ngev += gev1;
122
123 match interval_result {
124 IntervalResult::Found(alpha, f_alpha, g_alpha) => Ok(StrongWolfeResult {
125 alpha,
126 f_new: f_alpha,
127 g_new: g_alpha,
128 nfev,
129 ngev,
130 success: true,
131 message: "Strong Wolfe conditions satisfied in interval search".to_string(),
132 }),
133 IntervalResult::Bracket(alpha_lo, alpha_hi, f_lo, f_hi, g_lo) => {
134 let (zoom_result, fev2, gev2) = zoom_search(
136 fun, grad_fun, x, f0, derphi0, direction, alpha_lo, alpha_hi, f_lo, f_hi, g_lo,
137 options, bounds,
138 )?;
139
140 nfev += fev2;
141 ngev += gev2;
142
143 Ok(StrongWolfeResult {
144 alpha: zoom_result.0,
145 f_new: zoom_result.1,
146 g_new: zoom_result.2,
147 nfev,
148 ngev,
149 success: true,
150 message: "Strong Wolfe conditions satisfied in zoom phase".to_string(),
151 })
152 }
153 IntervalResult::Failed => Ok(StrongWolfeResult {
154 alpha: options.min_step,
155 f_new: f0,
156 g_new: g0.to_owned(),
157 nfev,
158 ngev,
159 success: false,
160 message: "Failed to find acceptable interval".to_string(),
161 }),
162 }
163}
164
/// Outcome of the bracketing phase ([`find_interval`]).
#[derive(Debug)]
enum IntervalResult {
    /// A trial step that passed both strong Wolfe tests:
    /// `(alpha, objective value, gradient)` at `x + alpha * direction`.
    Found(f64, f64, Array1<f64>),
    /// An interval to be refined by the zoom phase:
    /// `(alpha_lo, alpha_hi, f_lo, f_hi, derphi_lo)`. `alpha_hi` may be smaller than
    /// `alpha_lo`.
    Bracket(f64, f64, f64, f64, f64),
    /// The iteration budget ran out, or the step stopped changing, without producing
    /// either of the above.
    Failed,
}
171
172#[allow(clippy::too_many_arguments)]
174#[allow(dead_code)]
175fn find_interval<F, G, S>(
176 fun: &mut F,
177 grad_fun: &mut G,
178 x: &ArrayView1<f64>,
179 f0: f64,
180 derphi0: f64,
181 direction: &ArrayView1<f64>,
182 mut alpha: f64,
183 options: &StrongWolfeOptions,
184 bounds: Option<&Bounds>,
185) -> Result<(IntervalResult, usize, usize), OptimizeError>
186where
187 F: FnMut(&ArrayView1<f64>) -> S,
188 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
189 S: Into<f64>,
190{
191 let mut nfev = 0;
192 let mut ngev = 0;
193 let mut alpha_prev = 0.0;
194 let mut f_prev = f0;
195 let mut derphi_prev = derphi0;
196
197 for i in 0..options.max_fev {
198 if let Some(bounds) = bounds {
200 alpha = alpha.min(clip_step(x, direction, alpha, &bounds.lower, &bounds.upper));
201 }
202 alpha = alpha.min(options.max_step).max(options.min_step);
203
204 let x_alpha = x + alpha * direction;
206 let f_alpha = fun(&x_alpha.view()).into();
207 nfev += 1;
208
209 if f_alpha > f0 + options.c1 * alpha * derphi0 || (f_alpha >= f_prev && i > 0) {
211 return Ok((
213 IntervalResult::Bracket(alpha_prev, alpha, f_prev, f_alpha, derphi_prev),
214 nfev,
215 ngev,
216 ));
217 }
218
219 let g_alpha = grad_fun(&x_alpha.view());
221 let derphi_alpha = g_alpha.dot(direction);
222 ngev += 1;
223
224 if derphi_alpha.abs() <= -options.c2 * derphi0 {
226 return Ok((IntervalResult::Found(alpha, f_alpha, g_alpha), nfev, ngev));
228 }
229
230 if derphi_alpha >= 0.0 {
232 return Ok((
233 IntervalResult::Bracket(alpha, alpha_prev, f_alpha, f_prev, derphi_alpha),
234 nfev,
235 ngev,
236 ));
237 }
238
239 alpha_prev = alpha;
241 f_prev = f_alpha;
242 derphi_prev = derphi_alpha;
243
244 if options.use_extrapolation {
246 alpha = if i == 0 {
247 alpha * 2.0
248 } else {
249 alpha * (1.0 + 2.0 * derphi_alpha.abs() / derphi0.abs()).min(3.0)
251 };
252 } else {
253 alpha *= 2.0;
254 }
255
256 if alpha > options.max_step {
258 alpha = options.max_step;
259 }
260
261 if (alpha - alpha_prev).abs() < options.tolerance {
263 break;
264 }
265 }
266
267 Ok((IntervalResult::Failed, nfev, ngev))
268}
269
270#[allow(clippy::too_many_arguments)]
272#[allow(dead_code)]
273fn zoom_search<F, G, S>(
274 fun: &mut F,
275 grad_fun: &mut G,
276 x: &ArrayView1<f64>,
277 f0: f64,
278 derphi0: f64,
279 direction: &ArrayView1<f64>,
280 mut alpha_lo: f64,
281 mut alpha_hi: f64,
282 mut f_lo: f64,
283 mut f_hi: f64,
284 mut derphi_lo: f64,
285 options: &StrongWolfeOptions,
286 bounds: Option<&Bounds>,
287) -> Result<ZoomSearchResult, OptimizeError>
288where
289 F: FnMut(&ArrayView1<f64>) -> S,
290 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
291 S: Into<f64>,
292{
293 let mut nfev = 0;
294 let mut ngev = 0;
295
296 for _ in 0..options.max_fev {
297 let alpha = if options.use_safeguarded_interpolation {
299 safeguarded_interpolation(alpha_lo, alpha_hi, f_lo, f_hi, derphi_lo, derphi0)
300 } else {
301 0.5 * (alpha_lo + alpha_hi)
302 };
303
304 let x_alpha = x + alpha * direction;
306 let f_alpha = fun(&x_alpha.view()).into();
307 nfev += 1;
308
309 if f_alpha > f0 + options.c1 * alpha * derphi0 || f_alpha >= f_lo {
311 alpha_hi = alpha;
313 f_hi = f_alpha;
314 } else {
315 let g_alpha = grad_fun(&x_alpha.view());
317 let derphi_alpha = g_alpha.dot(direction);
318 ngev += 1;
319
320 if derphi_alpha.abs() <= -options.c2 * derphi0 {
322 return Ok(((alpha, f_alpha, g_alpha), nfev, ngev));
324 }
325
326 if derphi_alpha * (alpha_hi - alpha_lo) >= 0.0 {
328 alpha_hi = alpha_lo;
329 f_hi = f_lo;
330 }
331
332 alpha_lo = alpha;
333 f_lo = f_alpha;
334 derphi_lo = derphi_alpha;
335 }
336
337 if (alpha_hi - alpha_lo).abs() < options.tolerance {
339 break;
340 }
341 }
342
343 let alpha = if f_lo < f_hi { alpha_lo } else { alpha_hi };
345 let x_alpha = x + alpha * direction;
346 let f_alpha = fun(&x_alpha.view()).into();
347 let g_alpha = grad_fun(&x_alpha.view());
348 nfev += 1;
349 ngev += 1;
350
351 Ok(((alpha, f_alpha, g_alpha), nfev, ngev))
352}
353
/// Picks the next trial step inside the bracket formed by `alpha_lo` and `alpha_hi`.
///
/// Fits the quadratic `q(t) = f_lo + derphi_lo * t + c * t^2`, with `t = alpha - alpha_lo`.
/// The fit matches the value and slope at `alpha_lo` and the value at `alpha_hi`. Its
/// minimiser `alpha_lo - derphi_lo / (2c)` is returned when two conditions hold:
/// * the model is convex (`c > 0`);
/// * the minimiser lies at least 10% of the bracket width away from both endpoints.
///
/// Otherwise the midpoint of the bracket is returned. The bracket may be given in either
/// orientation (`alpha_hi < alpha_lo` is allowed); non-finite inputs fall back to the
/// midpoint.
///
/// `_derphi0` is accepted for signature compatibility and is not used.
#[allow(dead_code)]
fn safeguarded_interpolation(
    alpha_lo: f64,
    alpha_hi: f64,
    f_lo: f64,
    f_hi: f64,
    derphi_lo: f64,
    _derphi0: f64,
) -> f64 {
    let bisection = 0.5 * (alpha_lo + alpha_hi);
    let delta = alpha_hi - alpha_lo;
    if delta == 0.0 {
        return bisection;
    }

    // Second-order coefficient of the interpolating quadratic.
    let curvature = (f_hi - f_lo - derphi_lo * delta) / (delta * delta);

    // `> 0.0` also rejects NaN; a concave model has no interior minimiser.
    if curvature > 0.0 {
        let alpha_q = alpha_lo - derphi_lo / (2.0 * curvature);

        // Safeguard window expressed independently of the bracket's orientation.
        let margin = 0.1 * delta.abs();
        let lower = alpha_lo.min(alpha_hi) + margin;
        let upper = alpha_lo.max(alpha_hi) - margin;
        if alpha_q >= lower && alpha_q <= upper {
            return alpha_q;
        }
    }

    bisection
}
401
402#[allow(dead_code)]
404pub fn create_strong_wolfe_options_for_method(method: &str) -> StrongWolfeOptions {
405 match method.to_lowercase().as_str() {
406 "bfgs" | "lbfgs" | "sr1" | "dfp" => StrongWolfeOptions {
407 c1: 1e-4,
408 c2: 0.9,
409 initial_step: 1.0,
410 max_step: 1e4,
411 min_step: 1e-12,
412 max_fev: 50,
413 tolerance: 1e-10,
414 use_safeguarded_interpolation: true,
415 use_extrapolation: true,
416 },
417 "cg" | "conjugate_gradient" => StrongWolfeOptions {
418 c1: 1e-4,
419 c2: 0.1, initial_step: 1.0,
421 max_step: 1e4,
422 min_step: 1e-12,
423 max_fev: 50,
424 tolerance: 1e-10,
425 use_safeguarded_interpolation: true,
426 use_extrapolation: true,
427 },
428 "newton" => StrongWolfeOptions {
429 c1: 1e-4,
430 c2: 0.5, initial_step: 1.0,
432 max_step: 1e6,
433 min_step: 1e-15,
434 max_fev: 100,
435 tolerance: 1e-12,
436 use_safeguarded_interpolation: true,
437 use_extrapolation: false, },
439 _ => StrongWolfeOptions::default(),
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Evaluates `f` and `g` at `start`, then runs an unbounded strong Wolfe search
    /// along `dir`. Returns the starting objective value together with the result.
    fn run_unbounded<F, G>(
        f: &mut F,
        g: &mut G,
        start: &Array1<f64>,
        dir: &Array1<f64>,
        opts: &StrongWolfeOptions,
    ) -> (f64, StrongWolfeResult)
    where
        F: FnMut(&ArrayView1<f64>) -> f64,
        G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
    {
        let f_start = f(&start.view());
        let g_start = g(&start.view());
        let res = strong_wolfe_line_search(
            f,
            g,
            &start.view(),
            f_start,
            &g_start.view(),
            &dir.view(),
            opts,
            None,
        )
        .unwrap();
        (f_start, res)
    }

    #[test]
    fn test_strong_wolfe_quadratic() {
        let mut objective = |p: &ArrayView1<f64>| -> f64 { p[0] * p[0] + p[1] * p[1] };
        let mut gradient =
            |p: &ArrayView1<f64>| -> Array1<f64> { Array1::from_vec(vec![2.0 * p[0], 2.0 * p[1]]) };

        let start = Array1::from_vec(vec![1.0, 1.0]);
        let dir = Array1::from_vec(vec![-1.0, -1.0]);
        let opts = StrongWolfeOptions::default();
        let (_, res) = run_unbounded(&mut objective, &mut gradient, &start, &dir, &opts);

        assert!(res.success);
        assert!(res.alpha > 0.0);
        // The unit step lands exactly on the minimiser along this ray.
        assert_abs_diff_eq!(res.alpha, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_strong_wolfe_rosenbrock() {
        const A: f64 = 1.0;
        const B: f64 = 100.0;

        let mut objective = |p: &ArrayView1<f64>| -> f64 {
            (A - p[0]).powi(2) + B * (p[1] - p[0].powi(2)).powi(2)
        };
        let mut gradient = |p: &ArrayView1<f64>| -> Array1<f64> {
            let residual = p[1] - p[0].powi(2);
            Array1::from_vec(vec![
                -2.0 * (A - p[0]) - 4.0 * B * p[0] * residual,
                2.0 * B * residual,
            ])
        };

        let start = Array1::from_vec(vec![0.0, 0.0]);
        // Steepest-descent direction at the start point.
        let dir = gradient(&start.view()).mapv(|v| -v);
        let opts = create_strong_wolfe_options_for_method("bfgs");
        let (f_start, res) = run_unbounded(&mut objective, &mut gradient, &start, &dir, &opts);

        assert!(res.success);
        assert!(res.alpha > 0.0);
        assert!(res.f_new < f_start);
    }

    #[test]
    fn test_safeguarded_interpolation() {
        let (lo, hi) = (0.0, 1.0);
        let trial = safeguarded_interpolation(lo, hi, 1.0, 0.5, -1.0, -1.0);

        // The trial must stay at least 10% of the bracket width away from both ends.
        let margin = 0.1 * (hi - lo);
        assert!(trial >= lo + margin);
        assert!(trial <= hi - margin);
    }

    #[test]
    fn test_method_specific_options() {
        let c2_of = |name: &str| create_strong_wolfe_options_for_method(name).c2;
        assert_eq!(c2_of("bfgs"), 0.9);
        assert_eq!(c2_of("cg"), 0.1);

        let newton = create_strong_wolfe_options_for_method("newton");
        assert_eq!(newton.c2, 0.5);
        assert!(!newton.use_extrapolation);
    }
}