scirs2_optimize/unconstrained/
strong_wolfe.rs

//! Enhanced Strong Wolfe conditions line search implementation
//!
//! This module provides robust line-search routines that return a step length
//! satisfying the Strong Wolfe conditions (sufficient decrease plus a bound on the
//! magnitude of the directional derivative). Quasi-Newton and nonlinear
//! conjugate-gradient methods rely on these conditions for their convergence guarantees.
6
7use crate::error::OptimizeError;
8use crate::unconstrained::utils::clip_step;
9use crate::unconstrained::Bounds;
10use scirs2_core::ndarray::{Array1, ArrayView1};
11
12/// Type alias for zoom search result to reduce type complexity
13type ZoomSearchResult = ((f64, f64, Array1<f64>), usize, usize);
14
/// Tuning parameters for [`strong_wolfe_line_search`].
///
/// Use `StrongWolfeOptions::default()` for general-purpose settings or
/// [`create_strong_wolfe_options_for_method`] for per-optimizer presets.
#[derive(Clone, Debug)]
pub struct StrongWolfeOptions {
    /// Sufficient-decrease (Armijo) constant; must satisfy `0 < c1 < c2`. Typical: `1e-4`.
    pub c1: f64,
    /// Curvature constant; must satisfy `c1 < c2 < 1`. Typically `0.9` for
    /// (quasi-)Newton methods and `0.1` for conjugate gradient.
    pub c2: f64,
    /// First trial step, before clipping to the bounds and to `[min_step, max_step]`.
    pub initial_step: f64,
    /// Upper limit on trial steps.
    pub max_step: f64,
    /// Lower limit applied to trial steps while bracketing; also reported as `alpha`
    /// when bracketing fails.
    pub min_step: f64,
    /// Cap on trial steps, applied separately to the bracketing phase and to the zoom
    /// phase (each trial step costs one function evaluation).
    pub max_fev: usize,
    /// Step changes / bracket widths below this value end the current phase.
    pub tolerance: f64,
    /// When `true`, the zoom phase interpolates inside the bracket (falling back to
    /// bisection); when `false`, it always bisects.
    pub use_safeguarded_interpolation: bool,
    /// When `true`, bracketing grows the step by a derivative-based factor of at most 3;
    /// when `false`, it always doubles the step.
    pub use_extrapolation: bool,
}
37
38impl Default for StrongWolfeOptions {
39    fn default() -> Self {
40        Self {
41            c1: 1e-4,
42            c2: 0.9,
43            initial_step: 1.0,
44            max_step: 1e10,
45            min_step: 1e-12,
46            max_fev: 100,
47            tolerance: 1e-10,
48            use_safeguarded_interpolation: true,
49            use_extrapolation: true,
50        }
51    }
52}
53
54/// Result of Strong Wolfe line search
55#[derive(Debug, Clone)]
56pub struct StrongWolfeResult {
57    /// Step size found
58    pub alpha: f64,
59    /// Function value at the step
60    pub f_new: f64,
61    /// Gradient at the step
62    pub g_new: Array1<f64>,
63    /// Number of function evaluations used
64    pub nfev: usize,
65    /// Number of gradient evaluations used
66    pub ngev: usize,
67    /// Whether the search was successful
68    pub success: bool,
69    /// Reason for termination
70    pub message: String,
71}
72
73/// Enhanced Strong Wolfe line search with robust implementation
74#[allow(clippy::too_many_arguments)]
75#[allow(dead_code)]
76pub fn strong_wolfe_line_search<F, G, S>(
77    fun: &mut F,
78    grad_fun: &mut G,
79    x: &ArrayView1<f64>,
80    f0: f64,
81    g0: &ArrayView1<f64>,
82    direction: &ArrayView1<f64>,
83    options: &StrongWolfeOptions,
84    bounds: Option<&Bounds>,
85) -> Result<StrongWolfeResult, OptimizeError>
86where
87    F: FnMut(&ArrayView1<f64>) -> S,
88    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
89    S: Into<f64>,
90{
91    // Validate inputs
92    let derphi0 = g0.dot(direction);
93    if derphi0 >= 0.0 {
94        return Err(OptimizeError::ValueError(
95            "Search direction must be a descent direction".to_string(),
96        ));
97    }
98
99    if options.c1 <= 0.0 || options.c1 >= options.c2 || options.c2 >= 1.0 {
100        return Err(OptimizeError::ValueError(
101            "Invalid Wolfe parameters: must have 0 < c1 < c2 < 1".to_string(),
102        ));
103    }
104
105    let mut alpha = options.initial_step;
106    let mut nfev = 0;
107    let mut ngev = 0;
108
109    // Apply bounds constraints to initial step
110    if let Some(bounds) = bounds {
111        alpha = alpha.min(clip_step(x, direction, alpha, &bounds.lower, &bounds.upper));
112    }
113    alpha = alpha.min(options.max_step).max(options.min_step);
114
115    // Phase 1: Find an interval containing acceptable points
116    let (interval_result, fev1, gev1) = find_interval(
117        fun, grad_fun, x, f0, derphi0, direction, alpha, options, bounds,
118    )?;
119
120    nfev += fev1;
121    ngev += gev1;
122
123    match interval_result {
124        IntervalResult::Found(alpha, f_alpha, g_alpha) => Ok(StrongWolfeResult {
125            alpha,
126            f_new: f_alpha,
127            g_new: g_alpha,
128            nfev,
129            ngev,
130            success: true,
131            message: "Strong Wolfe conditions satisfied in interval search".to_string(),
132        }),
133        IntervalResult::Bracket(alpha_lo, alpha_hi, f_lo, f_hi, g_lo) => {
134            // Phase 2: Zoom to find exact step
135            let (zoom_result, fev2, gev2) = zoom_search(
136                fun, grad_fun, x, f0, derphi0, direction, alpha_lo, alpha_hi, f_lo, f_hi, g_lo,
137                options, bounds,
138            )?;
139
140            nfev += fev2;
141            ngev += gev2;
142
143            Ok(StrongWolfeResult {
144                alpha: zoom_result.0,
145                f_new: zoom_result.1,
146                g_new: zoom_result.2,
147                nfev,
148                ngev,
149                success: true,
150                message: "Strong Wolfe conditions satisfied in zoom phase".to_string(),
151            })
152        }
153        IntervalResult::Failed => Ok(StrongWolfeResult {
154            alpha: options.min_step,
155            f_new: f0,
156            g_new: g0.to_owned(),
157            nfev,
158            ngev,
159            success: false,
160            message: "Failed to find acceptable interval".to_string(),
161        }),
162    }
163}
164
165#[derive(Debug)]
166enum IntervalResult {
167    Found(f64, f64, Array1<f64>),     // alpha, f(alpha), g(alpha)
168    Bracket(f64, f64, f64, f64, f64), // alpha_lo, alpha_hi, f_lo, f_hi, derphi_lo
169    Failed,
170}
171
172/// Phase 1: Find an interval containing acceptable points
173#[allow(clippy::too_many_arguments)]
174#[allow(dead_code)]
175fn find_interval<F, G, S>(
176    fun: &mut F,
177    grad_fun: &mut G,
178    x: &ArrayView1<f64>,
179    f0: f64,
180    derphi0: f64,
181    direction: &ArrayView1<f64>,
182    mut alpha: f64,
183    options: &StrongWolfeOptions,
184    bounds: Option<&Bounds>,
185) -> Result<(IntervalResult, usize, usize), OptimizeError>
186where
187    F: FnMut(&ArrayView1<f64>) -> S,
188    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
189    S: Into<f64>,
190{
191    let mut nfev = 0;
192    let mut ngev = 0;
193    let mut alpha_prev = 0.0;
194    let mut f_prev = f0;
195    let mut derphi_prev = derphi0;
196
197    for i in 0..options.max_fev {
198        // Ensure alpha is within bounds
199        if let Some(bounds) = bounds {
200            alpha = alpha.min(clip_step(x, direction, alpha, &bounds.lower, &bounds.upper));
201        }
202        alpha = alpha.min(options.max_step).max(options.min_step);
203
204        // Evaluate function at alpha
205        let x_alpha = x + alpha * direction;
206        let f_alpha = fun(&x_alpha.view()).into();
207        nfev += 1;
208
209        // Check Armijo condition and sufficient decrease
210        if f_alpha > f0 + options.c1 * alpha * derphi0 || (f_alpha >= f_prev && i > 0) {
211            // Found bracket: [alpha_prev, alpha]
212            return Ok((
213                IntervalResult::Bracket(alpha_prev, alpha, f_prev, f_alpha, derphi_prev),
214                nfev,
215                ngev,
216            ));
217        }
218
219        // Evaluate gradient at alpha
220        let g_alpha = grad_fun(&x_alpha.view());
221        let derphi_alpha = g_alpha.dot(direction);
222        ngev += 1;
223
224        // Check curvature condition (Strong Wolfe conditions)
225        if derphi_alpha.abs() <= -options.c2 * derphi0 {
226            // Found acceptable point
227            return Ok((IntervalResult::Found(alpha, f_alpha, g_alpha), nfev, ngev));
228        }
229
230        // Check if we've found a bracket due to positive derivative
231        if derphi_alpha >= 0.0 {
232            return Ok((
233                IntervalResult::Bracket(alpha, alpha_prev, f_alpha, f_prev, derphi_alpha),
234                nfev,
235                ngev,
236            ));
237        }
238
239        // Update for next iteration
240        alpha_prev = alpha;
241        f_prev = f_alpha;
242        derphi_prev = derphi_alpha;
243
244        // Extrapolate to get next alpha
245        if options.use_extrapolation {
246            alpha = if i == 0 {
247                alpha * 2.0
248            } else {
249                // Use safer extrapolation based on derivative information
250                alpha * (1.0 + 2.0 * derphi_alpha.abs() / derphi0.abs()).min(3.0)
251            };
252        } else {
253            alpha *= 2.0;
254        }
255
256        // Safety check: don't go too far
257        if alpha > options.max_step {
258            alpha = options.max_step;
259        }
260
261        // Check for convergence
262        if (alpha - alpha_prev).abs() < options.tolerance {
263            break;
264        }
265    }
266
267    Ok((IntervalResult::Failed, nfev, ngev))
268}
269
270/// Phase 2: Zoom search within bracket to find exact step
271#[allow(clippy::too_many_arguments)]
272#[allow(dead_code)]
273fn zoom_search<F, G, S>(
274    fun: &mut F,
275    grad_fun: &mut G,
276    x: &ArrayView1<f64>,
277    f0: f64,
278    derphi0: f64,
279    direction: &ArrayView1<f64>,
280    mut alpha_lo: f64,
281    mut alpha_hi: f64,
282    mut f_lo: f64,
283    mut f_hi: f64,
284    mut derphi_lo: f64,
285    options: &StrongWolfeOptions,
286    bounds: Option<&Bounds>,
287) -> Result<ZoomSearchResult, OptimizeError>
288where
289    F: FnMut(&ArrayView1<f64>) -> S,
290    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
291    S: Into<f64>,
292{
293    let mut nfev = 0;
294    let mut ngev = 0;
295
296    for _ in 0..options.max_fev {
297        // Interpolate to find new trial point
298        let alpha = if options.use_safeguarded_interpolation {
299            safeguarded_interpolation(alpha_lo, alpha_hi, f_lo, f_hi, derphi_lo, derphi0)
300        } else {
301            0.5 * (alpha_lo + alpha_hi)
302        };
303
304        // Evaluate function at trial point
305        let x_alpha = x + alpha * direction;
306        let f_alpha = fun(&x_alpha.view()).into();
307        nfev += 1;
308
309        // Check Armijo condition
310        if f_alpha > f0 + options.c1 * alpha * derphi0 || f_alpha >= f_lo {
311            // Trial point violates Armijo condition, shrink interval
312            alpha_hi = alpha;
313            f_hi = f_alpha;
314        } else {
315            // Trial point satisfies Armijo condition, check curvature
316            let g_alpha = grad_fun(&x_alpha.view());
317            let derphi_alpha = g_alpha.dot(direction);
318            ngev += 1;
319
320            // Check Strong Wolfe conditions
321            if derphi_alpha.abs() <= -options.c2 * derphi0 {
322                // Found acceptable point
323                return Ok(((alpha, f_alpha, g_alpha), nfev, ngev));
324            }
325
326            // Update interval based on derivative sign
327            if derphi_alpha * (alpha_hi - alpha_lo) >= 0.0 {
328                alpha_hi = alpha_lo;
329                f_hi = f_lo;
330            }
331
332            alpha_lo = alpha;
333            f_lo = f_alpha;
334            derphi_lo = derphi_alpha;
335        }
336
337        // Check for convergence
338        if (alpha_hi - alpha_lo).abs() < options.tolerance {
339            break;
340        }
341    }
342
343    // If we reach here, return the best point found
344    let alpha = if f_lo < f_hi { alpha_lo } else { alpha_hi };
345    let x_alpha = x + alpha * direction;
346    let f_alpha = fun(&x_alpha.view()).into();
347    let g_alpha = grad_fun(&x_alpha.view());
348    nfev += 1;
349    ngev += 1;
350
351    Ok(((alpha, f_alpha, g_alpha), nfev, ngev))
352}
353
/// Safeguarded quadratic interpolation for choosing the next trial step in the zoom phase.
///
/// Fits the quadratic model matching `phi(alpha_lo) = f_lo`, `phi'(alpha_lo) = derphi_lo`
/// and `phi(alpha_hi) = f_hi`. Its minimiser is returned when two conditions hold:
/// * the model is strictly convex;
/// * the minimiser lies at least 10% of the bracket width away from both endpoints.
///
/// Otherwise the bracket midpoint is returned. Either way the result lies inside the bracket.
///
/// The bracket may be given in either order (`alpha_hi < alpha_lo` is allowed). The
/// bracketing and zoom phases both produce such reversed brackets.
///
/// There is no cubic model. Only three pieces of information are available (no slope at
/// `alpha_hi`), which determine a quadratic but not a cubic.
///
/// `_derphi0` is unused; it is kept for signature compatibility.
#[allow(dead_code)]
fn safeguarded_interpolation(
    alpha_lo: f64,
    alpha_hi: f64,
    f_lo: f64,
    f_hi: f64,
    derphi_lo: f64,
    _derphi0: f64,
) -> f64 {
    let midpoint = 0.5 * (alpha_lo + alpha_hi);
    let delta = alpha_hi - alpha_lo;

    // Safeguard window: trials must stay at least 10% of the bracket width from either end.
    // It is computed independently of orientation so reversed brackets still interpolate.
    let margin = 0.1 * delta.abs();
    let window_lo = alpha_lo.min(alpha_hi) + margin;
    let window_hi = alpha_lo.max(alpha_hi) - margin;

    // With t = alpha - alpha_lo the model is q(t) = f_lo + derphi_lo * t + k * t^2.
    // It satisfies k * delta^2 = f_hi - f_lo - derphi_lo * delta, and has a minimum only
    // when k > 0. The `> 0.0` test also rejects NaN, e.g. a NaN f_hi.
    let curvature_term = f_hi - f_lo - derphi_lo * delta;
    if curvature_term > 0.0 {
        // argmin q = alpha_lo - derphi_lo / (2k)
        let alpha_q = alpha_lo - 0.5 * derphi_lo * delta * delta / curvature_term;
        if (window_lo..=window_hi).contains(&alpha_q) {
            return alpha_q;
        }
    }

    midpoint
}
401
402/// Create Strong Wolfe options optimized for specific optimization methods
403#[allow(dead_code)]
404pub fn create_strong_wolfe_options_for_method(method: &str) -> StrongWolfeOptions {
405    match method.to_lowercase().as_str() {
406        "bfgs" | "lbfgs" | "sr1" | "dfp" => StrongWolfeOptions {
407            c1: 1e-4,
408            c2: 0.9,
409            initial_step: 1.0,
410            max_step: 1e4,
411            min_step: 1e-12,
412            max_fev: 50,
413            tolerance: 1e-10,
414            use_safeguarded_interpolation: true,
415            use_extrapolation: true,
416        },
417        "cg" | "conjugate_gradient" => StrongWolfeOptions {
418            c1: 1e-4,
419            c2: 0.1, // Smaller c2 for CG methods
420            initial_step: 1.0,
421            max_step: 1e4,
422            min_step: 1e-12,
423            max_fev: 50,
424            tolerance: 1e-10,
425            use_safeguarded_interpolation: true,
426            use_extrapolation: true,
427        },
428        "newton" => StrongWolfeOptions {
429            c1: 1e-4,
430            c2: 0.5, // Moderate c2 for Newton methods
431            initial_step: 1.0,
432            max_step: 1e6,
433            min_step: 1e-15,
434            max_fev: 100,
435            tolerance: 1e-12,
436            use_safeguarded_interpolation: true,
437            use_extrapolation: false, // More conservative for Newton
438        },
439        _ => StrongWolfeOptions::default(),
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Evaluates `f0`/`g0` at `x`, then runs an unbounded strong Wolfe search along `direction`.
    fn search_from<F, G>(
        fun: &mut F,
        grad: &mut G,
        x: &Array1<f64>,
        direction: &Array1<f64>,
        options: &StrongWolfeOptions,
    ) -> StrongWolfeResult
    where
        F: FnMut(&ArrayView1<f64>) -> f64,
        G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
    {
        let f0 = fun(&x.view());
        let g0 = grad(&x.view());
        strong_wolfe_line_search(
            fun,
            grad,
            &x.view(),
            f0,
            &g0.view(),
            &direction.view(),
            options,
            None,
        )
        .expect("line search should accept a descent direction")
    }

    #[test]
    fn test_strong_wolfe_quadratic() {
        // f(p) = |p|^2; from (1, 1) along (-1, -1) the exact line minimiser is alpha = 1.
        let mut sphere = |p: &ArrayView1<f64>| -> f64 { p[0] * p[0] + p[1] * p[1] };
        let mut sphere_grad = |p: &ArrayView1<f64>| -> Array1<f64> { p.mapv(|v| 2.0 * v) };

        let start = Array1::from_elem(2, 1.0);
        let direction = Array1::from_elem(2, -1.0);

        let result = search_from(
            &mut sphere,
            &mut sphere_grad,
            &start,
            &direction,
            &StrongWolfeOptions::default(),
        );

        assert!(result.success);
        assert!(result.alpha > 0.0);
        assert_abs_diff_eq!(result.alpha, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_strong_wolfe_rosenbrock() {
        const A: f64 = 1.0;
        const B: f64 = 100.0;

        let mut rosenbrock = |p: &ArrayView1<f64>| -> f64 {
            (A - p[0]).powi(2) + B * (p[1] - p[0].powi(2)).powi(2)
        };
        let mut rosenbrock_grad = |p: &ArrayView1<f64>| -> Array1<f64> {
            let inner = p[1] - p[0].powi(2);
            Array1::from_vec(vec![
                -2.0 * (A - p[0]) - 4.0 * B * p[0] * inner,
                2.0 * B * inner,
            ])
        };

        let origin: Array1<f64> = Array1::zeros(2);
        let f0 = rosenbrock(&origin.view());
        // Steepest descent direction.
        let steepest_descent = -rosenbrock_grad(&origin.view());

        let result = search_from(
            &mut rosenbrock,
            &mut rosenbrock_grad,
            &origin,
            &steepest_descent,
            &create_strong_wolfe_options_for_method("bfgs"),
        );

        assert!(result.success);
        assert!(result.alpha > 0.0);
        assert!(result.f_new < f0, "the step must decrease the objective");
    }

    #[test]
    fn test_safeguarded_interpolation() {
        let (lo, hi) = (0.0, 1.0);
        let trial = safeguarded_interpolation(lo, hi, 1.0, 0.5, -1.0, -1.0);

        let margin = 0.1 * (hi - lo);
        assert!(trial >= lo + margin, "trial {} too close to alpha_lo", trial);
        assert!(trial <= hi - margin, "trial {} too close to alpha_hi", trial);
    }

    #[test]
    fn test_method_specific_options() {
        let expected_c2 = [("bfgs", 0.9), ("cg", 0.1), ("newton", 0.5)];
        for (method, c2) in expected_c2 {
            assert_eq!(
                create_strong_wolfe_options_for_method(method).c2,
                c2,
                "unexpected c2 for {}",
                method
            );
        }
        assert!(!create_strong_wolfe_options_for_method("newton").use_extrapolation);
    }
}