scirs2_optimize/unconstrained/
line_search.rs

1//! Line search algorithms for optimization
2
3use crate::unconstrained::utils::clip_step;
4use crate::unconstrained::Bounds;
5use scirs2_core::ndarray::{Array1, ArrayView1};
6
7/// Backtracking line search with Armijo condition
8#[allow(clippy::too_many_arguments)]
9#[allow(dead_code)]
10pub fn backtracking_line_search<F, S>(
11    fun: &mut F,
12    x: &ArrayView1<f64>,
13    f0: f64,
14    direction: &ArrayView1<f64>,
15    grad: &ArrayView1<f64>,
16    alpha_init: f64,
17    c1: f64,
18    rho: f64,
19    bounds: Option<&Bounds>,
20) -> (f64, f64)
21where
22    F: FnMut(&ArrayView1<f64>) -> S,
23    S: Into<f64>,
24{
25    let mut alpha = alpha_init;
26    let slope = grad.dot(direction);
27
28    // Handle bounds constraints
29    if let Some(bounds) = bounds {
30        alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
31    }
32
33    // Backtracking loop
34    for _ in 0..50 {
35        let x_new = x + alpha * direction;
36        let f_new = fun(&x_new.view()).into();
37
38        if f_new <= f0 + c1 * alpha * slope {
39            return (alpha, f_new);
40        }
41
42        alpha *= rho;
43
44        if alpha < 1e-16 {
45            break;
46        }
47    }
48
49    (alpha, f0)
50}
51
52/// Strong Wolfe line search
53#[allow(clippy::too_many_arguments)]
54#[allow(dead_code)]
55pub fn strong_wolfe_line_search<F, S, G>(
56    fun: &mut F,
57    grad_fun: &mut G,
58    x: &ArrayView1<f64>,
59    f0: f64,
60    direction: &ArrayView1<f64>,
61    grad0: &ArrayView1<f64>,
62    alpha_init: f64,
63    c1: f64,
64    c2: f64,
65    bounds: Option<&Bounds>,
66) -> Result<(f64, f64, Array1<f64>), &'static str>
67where
68    F: FnMut(&ArrayView1<f64>) -> S,
69    S: Into<f64>,
70    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
71{
72    let mut alpha = alpha_init;
73    let phi0 = f0;
74    let dphi0 = grad0.dot(direction);
75
76    if dphi0 >= 0.0 {
77        return Err("Search direction must be a descent direction");
78    }
79
80    // Handle bounds constraints
81    if let Some(bounds) = bounds {
82        alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
83    }
84
85    let mut alpha_lo = 0.0;
86    let alpha_hi = alpha;
87    let mut phi_lo = phi0;
88    let mut dphi_lo = dphi0;
89
90    for _ in 0..20 {
91        let x_new = x + alpha * direction;
92        let phi = fun(&x_new.view()).into();
93
94        if phi > phi0 + c1 * alpha * dphi0 || (phi >= phi_lo && alpha_lo > 0.0) {
95            return zoom(
96                fun, grad_fun, x, direction, alpha_lo, alpha, phi_lo, phi, dphi_lo, phi0, dphi0,
97                c1, c2,
98            );
99        }
100
101        let grad_new = grad_fun(&x_new.view());
102        let dphi = grad_new.dot(direction);
103
104        if dphi.abs() <= -c2 * dphi0 {
105            return Ok((alpha, phi, grad_new));
106        }
107
108        if dphi >= 0.0 {
109            return zoom(
110                fun, grad_fun, x, direction, alpha, alpha_lo, phi, phi_lo, dphi, phi0, dphi0, c1,
111                c2,
112            );
113        }
114
115        alpha_lo = alpha;
116        phi_lo = phi;
117        dphi_lo = dphi;
118
119        alpha = 0.5 * (alpha + alpha_hi);
120    }
121
122    Err("Line search failed to find a step satisfying the strong Wolfe conditions")
123}
124
125#[allow(clippy::too_many_arguments)]
126#[allow(dead_code)]
127fn zoom<F, S, G>(
128    fun: &mut F,
129    grad_fun: &mut G,
130    x: &ArrayView1<f64>,
131    direction: &ArrayView1<f64>,
132    mut alpha_lo: f64,
133    mut alpha_hi: f64,
134    mut phi_lo: f64,
135    mut _phi_hi: f64,
136    mut _dphi_lo: f64,
137    phi0: f64,
138    dphi0: f64,
139    c1: f64,
140    c2: f64,
141) -> Result<(f64, f64, Array1<f64>), &'static str>
142where
143    F: FnMut(&ArrayView1<f64>) -> S,
144    S: Into<f64>,
145    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
146{
147    for _ in 0..10 {
148        let alpha = 0.5 * (alpha_lo + alpha_hi);
149        let x_new = x + alpha * direction;
150        let phi = fun(&x_new.view()).into();
151
152        if phi > phi0 + c1 * alpha * dphi0 || phi >= phi_lo {
153            alpha_hi = alpha;
154            _phi_hi = phi;
155        } else {
156            let grad_new = grad_fun(&x_new.view());
157            let dphi = grad_new.dot(direction);
158
159            if dphi.abs() <= -c2 * dphi0 {
160                return Ok((alpha, phi, grad_new));
161            }
162
163            if dphi * (alpha_hi - alpha_lo) >= 0.0 {
164                alpha_hi = alpha_lo;
165                _phi_hi = phi_lo;
166            }
167
168            alpha_lo = alpha;
169            phi_lo = phi;
170            _dphi_lo = dphi;
171        }
172
173        if (alpha_hi - alpha_lo).abs() < 1e-8 {
174            break;
175        }
176    }
177
178    Err("Zoom failed to find acceptable step")
179}
180
181/// Simple bracketing line search
182#[allow(dead_code)]
183pub fn bracketing_line_search<F, S>(
184    fun: &mut F,
185    x: &ArrayView1<f64>,
186    direction: &ArrayView1<f64>,
187    bounds: Option<&Bounds>,
188) -> f64
189where
190    F: FnMut(&ArrayView1<f64>) -> S,
191    S: Into<f64>,
192{
193    let mut a = 0.0;
194    let mut b = 1.0;
195
196    // Handle bounds
197    if let Some(bounds) = bounds {
198        b = clip_step(x, direction, b, &bounds.lower, &bounds.upper);
199    }
200
201    let fa = fun(&x.view()).into();
202    let mut x_new = x + b * direction;
203    let fb = fun(&x_new.view()).into();
204
205    // If b is better than a, expand the interval
206    if fb < fa {
207        while b < 10.0 {
208            let c = 2.0 * b;
209            x_new = x + c * direction;
210            let fc = fun(&x_new.view()).into();
211
212            if fc >= fb {
213                break;
214            }
215
216            a = b;
217            b = c;
218
219            if let Some(bounds) = bounds {
220                let max_step = clip_step(x, direction, b, &bounds.lower, &bounds.upper);
221                if max_step <= b {
222                    b = max_step;
223                    break;
224                }
225            }
226        }
227    }
228
229    // Binary search for the minimum
230    for _ in 0..20 {
231        let mid = 0.5 * (a + b);
232        x_new = x + mid * direction;
233        let fmid = fun(&x_new.view()).into();
234
235        x_new = x + (mid - 0.01) * direction;
236        let fleft = fun(&x_new.view()).into();
237
238        if fleft < fmid {
239            b = mid;
240        } else {
241            a = mid;
242        }
243
244        if (b - a) < 1e-6 {
245            break;
246        }
247    }
248
249    0.5 * (a + b)
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backtracking_line_search() {
        let mut quadratic = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };

        let x = Array1::from_vec(vec![1.0, 1.0]);
        let f0 = quadratic(&x.view());
        let direction = Array1::from_vec(vec![-1.0, -1.0]);
        let grad = Array1::from_vec(vec![2.0, 2.0]);

        let (alpha, f_new) = backtracking_line_search(
            &mut quadratic,
            &x.view(),
            f0,
            &direction.view(),
            &grad.view(),
            1.0,
            0.0001,
            0.5,
            None,
        );

        // The full step lands exactly on the minimiser and is accepted immediately.
        assert_eq!(alpha, 1.0);
        assert_eq!(f_new, 0.0);
    }

    #[test]
    fn test_backtracking_line_search_shrinks_step() {
        let mut quadratic = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };

        let x = Array1::from_vec(vec![1.0, 1.0]);
        let f0 = quadratic(&x.view());
        // Steps 1.0 and 0.5 overshoot (f = 18 and f = 2); step 0.25 reaches the origin.
        let direction = Array1::from_vec(vec![-4.0, -4.0]);
        let grad = Array1::from_vec(vec![2.0, 2.0]);

        let (alpha, f_new) = backtracking_line_search(
            &mut quadratic,
            &x.view(),
            f0,
            &direction.view(),
            &grad.view(),
            1.0,
            0.0001,
            0.5,
            None,
        );

        assert_eq!(alpha, 0.25);
        assert_eq!(f_new, 0.0);
    }

    #[test]
    fn test_strong_wolfe_accepts_newton_step() {
        let mut fun = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] };
        let mut grad_fun = |x: &ArrayView1<f64>| -> Array1<f64> { x.mapv(|v| 2.0 * v) };

        let x = Array1::from_vec(vec![10.0]);
        let direction = Array1::from_vec(vec![-10.0]);
        let grad0 = grad_fun(&x.view());

        let (alpha, phi, grad) = strong_wolfe_line_search(
            &mut fun,
            &mut grad_fun,
            &x.view(),
            100.0,
            &direction.view(),
            &grad0.view(),
            1.0,
            1e-4,
            0.9,
            None,
        )
        .expect("the exact Newton step satisfies the strong Wolfe conditions");

        assert_eq!(alpha, 1.0);
        assert_eq!(phi, 0.0);
        assert_eq!(grad[0], 0.0);
    }

    #[test]
    fn test_strong_wolfe_rejects_ascent_direction() {
        let mut fun = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] };
        let mut grad_fun = |x: &ArrayView1<f64>| -> Array1<f64> { x.mapv(|v| 2.0 * v) };

        let x = Array1::from_vec(vec![10.0]);
        let direction = Array1::from_vec(vec![1.0]);
        let grad0 = Array1::from_vec(vec![20.0]);

        let result = strong_wolfe_line_search(
            &mut fun,
            &mut grad_fun,
            &x.view(),
            100.0,
            &direction.view(),
            &grad0.view(),
            1.0,
            1e-4,
            0.9,
            None,
        );

        assert_eq!(
            result.unwrap_err(),
            "Search direction must be a descent direction"
        );
    }

    #[test]
    fn test_bracketing_line_search_finds_interior_minimum() {
        let mut quadratic = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };

        let x = Array1::from_vec(vec![1.0, 1.0]);
        // Along this direction f(x + a d) = 8 (a - 0.5)^2, minimised at a = 0.5.
        let direction = Array1::from_vec(vec![-2.0, -2.0]);

        let alpha = bracketing_line_search(&mut quadratic, &x.view(), &direction.view(), None);

        assert!((alpha - 0.5).abs() < 1e-2, "alpha = {}", alpha);
    }
}