torsh_functional/optimization/line_search.rs

use super::utilities::{dot_product, tensor_add, tensor_scalar_mul};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;

/// Line-search strategies for selecting a step size along a search direction.
#[derive(Debug, Clone, Copy)]
pub enum LineSearchMethod {
    /// Backtracking (Armijo) search: shrink the step until sufficient decrease holds.
    Backtracking,
    /// Exact minimization of the objective along the search direction.
    Exact,
    /// Search satisfying the (weak) Wolfe conditions.
    Wolfe,
    /// Search satisfying the strong Wolfe conditions.
    StrongWolfe,
}

/// Parameters for `backtracking_line_search`.
#[derive(Debug, Clone)]
pub struct BacktrackingParams {
    /// Initial trial step size.
    pub alpha0: f32,
    /// Sufficient-decrease (Armijo) constant; typically a small value such as 1e-4.
    pub c1: f32,
    /// Contraction factor in (0, 1) applied to the step on each rejection.
    pub rho: f32,
    /// Maximum number of backtracking iterations.
    pub max_iter: usize,
}

impl Default for BacktrackingParams {
    fn default() -> Self {
        Self {
            alpha0: 1.0,
            c1: 1e-4,
            rho: 0.5,
            max_iter: 50,
        }
    }
}

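// Hedged usage sketch (illustration only, not part of the original source):
// individual fields can be overridden while keeping the remaining defaults
// via struct-update syntax.
//
// let params = BacktrackingParams { rho: 0.8, ..Default::default() };
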
/// Backtracking (Armijo) line search.
///
/// Starting from `alpha0`, repeatedly shrinks the trial step by `rho` until
/// the sufficient-decrease condition
/// `f(x + alpha * p) <= f(x) + c1 * alpha * grad_f(x)^T p`
/// holds, or `max_iter` trials are exhausted.
///
/// Returns the accepted step size, or an error if `p` is not a descent
/// direction at `x`.
pub fn backtracking_line_search<F, G>(
    objective: F,
    gradient: G,
    x: &Tensor,
    p: &Tensor,
    params: Option<BacktrackingParams>,
) -> TorshResult<f32>
where
    F: Fn(&Tensor) -> TorshResult<f32>,
    G: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let params = params.unwrap_or_default();

    let f0 = objective(x)?;
    let grad0 = gradient(x)?;

    // Directional derivative grad_f(x)^T p; must be negative for descent.
    let directional_deriv = dot_product(&grad0, p)?;

    if directional_deriv >= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Search direction is not a descent direction".to_string(),
        ));
    }

    let mut alpha = params.alpha0;

    for _ in 0..params.max_iter {
        let x_new = tensor_add(x, &tensor_scalar_mul(p, alpha)?)?;
        let f_new = objective(&x_new)?;

        // Armijo sufficient-decrease test.
        if f_new <= f0 + params.c1 * alpha * directional_deriv {
            return Ok(alpha);
        }

        alpha *= params.rho;
    }

    // Sufficient decrease was never certified; return the last (smallest) trial step.
    Ok(alpha)
}
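
// Hedged worked example (illustration only; `Tensor::from_vec` is an assumed
// constructor, substitute whatever torsh_tensor actually provides):
// minimizing f(x) = x^T x from x = (1, -2) along the steepest-descent
// direction p = -grad f(x) = (-2, 4).
//
// let x = Tensor::from_vec(vec![1.0_f32, -2.0], &[2])?;
// let p = tensor_scalar_mul(&x, -2.0)?;  // -grad f(x) = -2x
// let alpha = backtracking_line_search(
//     |x| dot_product(x, x),             // f(x) = x^T x
//     |x| tensor_scalar_mul(x, 2.0),     // grad f(x) = 2x
//     &x,
//     &p,
//     None,                              // use default parameters
// )?;
// // Here f0 = 5 and grad0^T p = -20, so the unit step (f_new = 5) narrowly
// // fails sufficient decrease; one halving gives alpha = 0.5, which lands
// // exactly on the minimizer x + 0.5 * p = 0 and is accepted.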

/// Parameters for `wolfe_line_search`.
#[derive(Debug, Clone)]
pub struct WolfeParams {
    /// Sufficient-decrease (Armijo) constant; typically 1e-4.
    pub c1: f32,
    /// Curvature constant with c1 < c2 < 1; 0.9 is a common choice for
    /// Newton and quasi-Newton directions.
    pub c2: f32,
    /// Initial trial step size.
    pub alpha0: f32,
    /// Upper bound on the step size.
    pub alpha_max: f32,
    /// Maximum number of bracketing iterations.
    pub max_iter: usize,
}

impl Default for WolfeParams {
    fn default() -> Self {
        Self {
            c1: 1e-4,
            c2: 0.9,
            alpha0: 1.0,
            alpha_max: 100.0,
            max_iter: 20,
        }
    }
}

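// Hedged usage sketch (illustration only, not part of the original source):
// a tighter curvature constant, as is conventional for nonlinear
// conjugate-gradient directions, while keeping the other defaults.
//
// let params = WolfeParams { c2: 0.1, ..Default::default() };
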
/// Line search satisfying the Wolfe conditions.
///
/// Accepts a step `alpha` once both the sufficient-decrease condition
/// `f(x + alpha * p) <= f(x) + c1 * alpha * grad_f(x)^T p`
/// and the strong curvature condition
/// `|grad_f(x + alpha * p)^T p| <= -c2 * grad_f(x)^T p`
/// hold, maintaining a bracket `[alpha_lo, alpha_hi]` that is bisected or
/// expanded as needed.
///
/// Returns the accepted step size, or an error if `p` is not a descent
/// direction at `x`.
pub fn wolfe_line_search<F, G>(
    objective: F,
    gradient: G,
    x: &Tensor,
    p: &Tensor,
    params: Option<WolfeParams>,
) -> TorshResult<f32>
where
    F: Fn(&Tensor) -> TorshResult<f32>,
    G: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let params = params.unwrap_or_default();

    let f0 = objective(x)?;
    let grad0 = gradient(x)?;
    let directional_deriv0 = dot_product(&grad0, p)?;

    if directional_deriv0 >= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Search direction is not a descent direction".to_string(),
        ));
    }

    let mut alpha_lo = 0.0;
    let mut alpha_hi = params.alpha_max;
    let mut alpha = params.alpha0;

    for _ in 0..params.max_iter {
        let x_new = tensor_add(x, &tensor_scalar_mul(p, alpha)?)?;
        let f_new = objective(&x_new)?;

        // Sufficient decrease failed: the step is too long, so shrink the
        // bracket from above and bisect.
        if f_new > f0 + params.c1 * alpha * directional_deriv0 {
            alpha_hi = alpha;
            alpha = (alpha_lo + alpha_hi) / 2.0;
            continue;
        }

        let grad_new = gradient(&x_new)?;
        let directional_deriv_new = dot_product(&grad_new, p)?;

        // Strong curvature condition holds: accept the step.
        if directional_deriv_new.abs() <= -params.c2 * directional_deriv0 {
            return Ok(alpha);
        }

        // The slope turned positive: we overshot the minimizer along p, so
        // the old lower endpoint becomes the far end of the bracket.
        if directional_deriv_new >= 0.0 {
            alpha_hi = alpha_lo;
        }

        alpha_lo = alpha;
        alpha = if alpha_hi == params.alpha_max {
            // No genuine upper bracket yet: expand, but never beyond alpha_max.
            (2.0 * alpha).min(params.alpha_max)
        } else {
            (alpha_lo + alpha_hi) / 2.0
        };
    }

    // The curvature condition was never certified; return the last trial step.
    Ok(alpha)
}
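
// Hedged worked example (illustration only; `Tensor::from_vec` is an assumed
// constructor): the same quadratic f(x) = x^T x from x = (1, -2) with
// p = -grad f(x) = (-2, 4).
//
// let x = Tensor::from_vec(vec![1.0_f32, -2.0], &[2])?;
// let p = tensor_scalar_mul(&x, -2.0)?;
// let alpha = wolfe_line_search(
//     |x| dot_product(x, x),
//     |x| tensor_scalar_mul(x, 2.0),
//     &x,
//     &p,
//     None,
// )?;
// // The unit step fails sufficient decrease, so the bracket shrinks to
// // [0, 1]; the bisected step alpha = 0.5 hits the minimizer, where the
// // gradient (and hence the curvature term) vanishes, so alpha == 0.5.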