scirs2_optimize/unconstrained/
line_search.rs1use crate::unconstrained::utils::clip_step;
4use crate::unconstrained::Bounds;
5use scirs2_core::ndarray::{Array1, ArrayView1};
6
7#[allow(clippy::too_many_arguments)]
9#[allow(dead_code)]
10pub fn backtracking_line_search<F, S>(
11 fun: &mut F,
12 x: &ArrayView1<f64>,
13 f0: f64,
14 direction: &ArrayView1<f64>,
15 grad: &ArrayView1<f64>,
16 alpha_init: f64,
17 c1: f64,
18 rho: f64,
19 bounds: Option<&Bounds>,
20) -> (f64, f64)
21where
22 F: FnMut(&ArrayView1<f64>) -> S,
23 S: Into<f64>,
24{
25 let mut alpha = alpha_init;
26 let slope = grad.dot(direction);
27
28 if let Some(bounds) = bounds {
30 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
31 }
32
33 for _ in 0..50 {
35 let x_new = x + alpha * direction;
36 let f_new = fun(&x_new.view()).into();
37
38 if f_new <= f0 + c1 * alpha * slope {
39 return (alpha, f_new);
40 }
41
42 alpha *= rho;
43
44 if alpha < 1e-16 {
45 break;
46 }
47 }
48
49 (alpha, f0)
50}
51
52#[allow(clippy::too_many_arguments)]
54#[allow(dead_code)]
55pub fn strong_wolfe_line_search<F, S, G>(
56 fun: &mut F,
57 grad_fun: &mut G,
58 x: &ArrayView1<f64>,
59 f0: f64,
60 direction: &ArrayView1<f64>,
61 grad0: &ArrayView1<f64>,
62 alpha_init: f64,
63 c1: f64,
64 c2: f64,
65 bounds: Option<&Bounds>,
66) -> Result<(f64, f64, Array1<f64>), &'static str>
67where
68 F: FnMut(&ArrayView1<f64>) -> S,
69 S: Into<f64>,
70 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
71{
72 let mut alpha = alpha_init;
73 let phi0 = f0;
74 let dphi0 = grad0.dot(direction);
75
76 if dphi0 >= 0.0 {
77 return Err("Search direction must be a descent direction");
78 }
79
80 if let Some(bounds) = bounds {
82 alpha = clip_step(x, direction, alpha, &bounds.lower, &bounds.upper);
83 }
84
85 let mut alpha_lo = 0.0;
86 let alpha_hi = alpha;
87 let mut phi_lo = phi0;
88 let mut dphi_lo = dphi0;
89
90 for _ in 0..20 {
91 let x_new = x + alpha * direction;
92 let phi = fun(&x_new.view()).into();
93
94 if phi > phi0 + c1 * alpha * dphi0 || (phi >= phi_lo && alpha_lo > 0.0) {
95 return zoom(
96 fun, grad_fun, x, direction, alpha_lo, alpha, phi_lo, phi, dphi_lo, phi0, dphi0,
97 c1, c2,
98 );
99 }
100
101 let grad_new = grad_fun(&x_new.view());
102 let dphi = grad_new.dot(direction);
103
104 if dphi.abs() <= -c2 * dphi0 {
105 return Ok((alpha, phi, grad_new));
106 }
107
108 if dphi >= 0.0 {
109 return zoom(
110 fun, grad_fun, x, direction, alpha, alpha_lo, phi, phi_lo, dphi, phi0, dphi0, c1,
111 c2,
112 );
113 }
114
115 alpha_lo = alpha;
116 phi_lo = phi;
117 dphi_lo = dphi;
118
119 alpha = 0.5 * (alpha + alpha_hi);
120 }
121
122 Err("Line search failed to find a step satisfying the strong Wolfe conditions")
123}
124
125#[allow(clippy::too_many_arguments)]
126#[allow(dead_code)]
127fn zoom<F, S, G>(
128 fun: &mut F,
129 grad_fun: &mut G,
130 x: &ArrayView1<f64>,
131 direction: &ArrayView1<f64>,
132 mut alpha_lo: f64,
133 mut alpha_hi: f64,
134 mut phi_lo: f64,
135 mut _phi_hi: f64,
136 mut _dphi_lo: f64,
137 phi0: f64,
138 dphi0: f64,
139 c1: f64,
140 c2: f64,
141) -> Result<(f64, f64, Array1<f64>), &'static str>
142where
143 F: FnMut(&ArrayView1<f64>) -> S,
144 S: Into<f64>,
145 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
146{
147 for _ in 0..10 {
148 let alpha = 0.5 * (alpha_lo + alpha_hi);
149 let x_new = x + alpha * direction;
150 let phi = fun(&x_new.view()).into();
151
152 if phi > phi0 + c1 * alpha * dphi0 || phi >= phi_lo {
153 alpha_hi = alpha;
154 _phi_hi = phi;
155 } else {
156 let grad_new = grad_fun(&x_new.view());
157 let dphi = grad_new.dot(direction);
158
159 if dphi.abs() <= -c2 * dphi0 {
160 return Ok((alpha, phi, grad_new));
161 }
162
163 if dphi * (alpha_hi - alpha_lo) >= 0.0 {
164 alpha_hi = alpha_lo;
165 _phi_hi = phi_lo;
166 }
167
168 alpha_lo = alpha;
169 phi_lo = phi;
170 _dphi_lo = dphi;
171 }
172
173 if (alpha_hi - alpha_lo).abs() < 1e-8 {
174 break;
175 }
176 }
177
178 Err("Zoom failed to find acceptable step")
179}
180
181#[allow(dead_code)]
183pub fn bracketing_line_search<F, S>(
184 fun: &mut F,
185 x: &ArrayView1<f64>,
186 direction: &ArrayView1<f64>,
187 bounds: Option<&Bounds>,
188) -> f64
189where
190 F: FnMut(&ArrayView1<f64>) -> S,
191 S: Into<f64>,
192{
193 let mut a = 0.0;
194 let mut b = 1.0;
195
196 if let Some(bounds) = bounds {
198 b = clip_step(x, direction, b, &bounds.lower, &bounds.upper);
199 }
200
201 let fa = fun(&x.view()).into();
202 let mut x_new = x + b * direction;
203 let fb = fun(&x_new.view()).into();
204
205 if fb < fa {
207 while b < 10.0 {
208 let c = 2.0 * b;
209 x_new = x + c * direction;
210 let fc = fun(&x_new.view()).into();
211
212 if fc >= fb {
213 break;
214 }
215
216 a = b;
217 b = c;
218
219 if let Some(bounds) = bounds {
220 let max_step = clip_step(x, direction, b, &bounds.lower, &bounds.upper);
221 if max_step <= b {
222 b = max_step;
223 break;
224 }
225 }
226 }
227 }
228
229 for _ in 0..20 {
231 let mid = 0.5 * (a + b);
232 x_new = x + mid * direction;
233 let fmid = fun(&x_new.view()).into();
234
235 x_new = x + (mid - 0.01) * direction;
236 let fleft = fun(&x_new.view()).into();
237
238 if fleft < fmid {
239 b = mid;
240 } else {
241 a = mid;
242 }
243
244 if (b - a) < 1e-6 {
245 break;
246 }
247 }
248
249 0.5 * (a + b)
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// f(x) = x0^2 + x1^2, whose gradient is 2x.
    fn quadratic(x: &ArrayView1<f64>) -> f64 {
        x[0] * x[0] + x[1] * x[1]
    }

    /// Runs `backtracking_line_search` on `quadratic` from `start` along
    /// `direction` with alpha_init = 1, c1 = 1e-4, rho = 0.5 and no bounds.
    /// Returns `(alpha, f_new, f0)`.
    fn backtrack_on_quadratic(start: [f64; 2], direction: [f64; 2]) -> (f64, f64, f64) {
        let x = Array1::from_vec(start.to_vec());
        let d = Array1::from_vec(direction.to_vec());
        let grad = Array1::from_vec(vec![2.0 * start[0], 2.0 * start[1]]);
        let f0 = quadratic(&x.view());
        let mut fun = quadratic;

        let (alpha, f_new) = backtracking_line_search(
            &mut fun,
            &x.view(),
            f0,
            &d.view(),
            &grad.view(),
            1.0,
            0.0001,
            0.5,
            None,
        );
        (alpha, f_new, f0)
    }

    #[test]
    fn test_backtracking_line_search() {
        // The full step lands exactly on the minimiser and is accepted at once.
        let (alpha, f_new, _) = backtrack_on_quadratic([1.0, 1.0], [-1.0, -1.0]);
        assert_eq!(alpha, 1.0);
        assert_eq!(f_new, 0.0);
    }

    #[test]
    fn test_backtracking_line_search_shrinks_step() {
        // alpha = 1 overshoots to (-2, -2) with f = 8; alpha = 0.5 gives f = 0.5.
        let (alpha, f_new, _) = backtrack_on_quadratic([1.0, 1.0], [-3.0, -3.0]);
        assert_eq!(alpha, 0.5);
        assert_eq!(f_new, 0.5);
    }

    #[test]
    fn test_backtracking_line_search_ascent_direction_falls_back() {
        // Every trial along an ascent direction is rejected, so the search gives
        // back a vanishing step and the reference value f0.
        let (alpha, f_new, f0) = backtrack_on_quadratic([1.0, 1.0], [1.0, 1.0]);
        assert!(alpha > 0.0);
        assert!(alpha < 1e-12);
        assert_eq!(f_new, f0);
    }
}