// torsh_functional/optimization/line_search.rs
1//! Line search methods for optimization algorithms
2//!
3//! This module provides various line search methods including backtracking line search
4//! and Wolfe line search, which are commonly used in optimization algorithms to find
5//! appropriate step sizes.
6
7use super::utilities::{dot_product, tensor_add, tensor_scalar_mul};
8use torsh_core::{Result as TorshResult, TorshError};
9use torsh_tensor::Tensor;
10
/// Line search methods for optimization
///
/// Selects the strategy used to choose a step size along a search direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LineSearchMethod {
    /// Backtracking line search with Armijo condition
    Backtracking,
    /// Exact line search (minimize along search direction)
    Exact,
    /// Wolfe conditions line search
    Wolfe,
    /// Strong Wolfe conditions
    StrongWolfe,
}
23
/// Backtracking line search parameters
#[derive(Debug, Clone, PartialEq)]
pub struct BacktrackingParams {
    /// Initial step size tried before any backtracking
    pub alpha0: f32,
    /// Armijo sufficient-decrease parameter (typically 1e-4)
    pub c1: f32,
    /// Multiplicative backtracking factor in (0, 1) (typically 0.5)
    pub rho: f32,
    /// Maximum number of backtracking steps before giving up
    pub max_iter: usize,
}
36
37impl Default for BacktrackingParams {
38    fn default() -> Self {
39        Self {
40            alpha0: 1.0,
41            c1: 1e-4,
42            rho: 0.5,
43            max_iter: 50,
44        }
45    }
46}
47
48/// Perform backtracking line search
49///
50/// Finds step size that satisfies Armijo condition:
51/// f(x + α*p) ≤ f(x) + c1*α*∇f(x)ᵀp
52///
53/// # Arguments
54/// * `objective` - Objective function f(x)
55/// * `gradient` - Gradient function ∇f(x)
56/// * `x` - Current point
57/// * `p` - Search direction
58/// * `params` - Line search parameters
59pub fn backtracking_line_search<F, G>(
60    objective: F,
61    gradient: G,
62    x: &Tensor,
63    p: &Tensor,
64    params: Option<BacktrackingParams>,
65) -> TorshResult<f32>
66where
67    F: Fn(&Tensor) -> TorshResult<f32>,
68    G: Fn(&Tensor) -> TorshResult<Tensor>,
69{
70    let params = params.unwrap_or_default();
71
72    let f0 = objective(x)?;
73    let grad0 = gradient(x)?;
74
75    // Compute directional derivative: ∇f(x)ᵀp
76    let directional_deriv = dot_product(&grad0, p)?;
77
78    if directional_deriv >= 0.0 {
79        return Err(TorshError::InvalidArgument(
80            "Search direction is not a descent direction".to_string(),
81        ));
82    }
83
84    let mut alpha = params.alpha0;
85
86    for _ in 0..params.max_iter {
87        // Compute x + α*p
88        let x_new = tensor_add(x, &tensor_scalar_mul(p, alpha)?)?;
89        let f_new = objective(&x_new)?;
90
91        // Check Armijo condition
92        if f_new <= f0 + params.c1 * alpha * directional_deriv {
93            return Ok(alpha);
94        }
95
96        alpha *= params.rho;
97    }
98
99    // Return the last step size if no convergence
100    Ok(alpha)
101}
102
/// Wolfe line search parameters
#[derive(Debug, Clone, PartialEq)]
pub struct WolfeParams {
    /// Armijo sufficient-decrease parameter (typically 1e-4)
    pub c1: f32,
    /// Curvature parameter, with c1 < c2 < 1 (typically 0.9)
    pub c2: f32,
    /// Initial step size
    pub alpha0: f32,
    /// Maximum step size the search may return
    pub alpha_max: f32,
    /// Maximum number of iterations
    pub max_iter: usize,
}
117
118impl Default for WolfeParams {
119    fn default() -> Self {
120        Self {
121            c1: 1e-4,
122            c2: 0.9,
123            alpha0: 1.0,
124            alpha_max: 100.0,
125            max_iter: 20,
126        }
127    }
128}
129
130/// Perform Wolfe line search
131///
132/// Finds step size that satisfies both Armijo and curvature conditions.
133///
134/// # Arguments
135/// * `objective` - Objective function f(x)
136/// * `gradient` - Gradient function ∇f(x)
137/// * `x` - Current point
138/// * `p` - Search direction
139/// * `params` - Wolfe line search parameters
140pub fn wolfe_line_search<F, G>(
141    objective: F,
142    gradient: G,
143    x: &Tensor,
144    p: &Tensor,
145    params: Option<WolfeParams>,
146) -> TorshResult<f32>
147where
148    F: Fn(&Tensor) -> TorshResult<f32>,
149    G: Fn(&Tensor) -> TorshResult<Tensor>,
150{
151    let params = params.unwrap_or_default();
152
153    let f0 = objective(x)?;
154    let grad0 = gradient(x)?;
155    let directional_deriv0 = dot_product(&grad0, p)?;
156
157    if directional_deriv0 >= 0.0 {
158        return Err(TorshError::InvalidArgument(
159            "Search direction is not a descent direction".to_string(),
160        ));
161    }
162
163    let mut alpha_lo = 0.0;
164    let mut alpha_hi = params.alpha_max;
165    let mut alpha = params.alpha0;
166
167    for _ in 0..params.max_iter {
168        let x_new = tensor_add(x, &tensor_scalar_mul(p, alpha)?)?;
169        let f_new = objective(&x_new)?;
170
171        // Check Armijo condition
172        if f_new > f0 + params.c1 * alpha * directional_deriv0 {
173            alpha_hi = alpha;
174            alpha = (alpha_lo + alpha_hi) / 2.0;
175            continue;
176        }
177
178        let grad_new = gradient(&x_new)?;
179        let directional_deriv_new = dot_product(&grad_new, p)?;
180
181        // Check curvature condition
182        if directional_deriv_new.abs() <= -params.c2 * directional_deriv0 {
183            return Ok(alpha);
184        }
185
186        if directional_deriv_new >= 0.0 {
187            alpha_hi = alpha_lo;
188        }
189
190        alpha_lo = alpha;
191        alpha = if alpha_hi == params.alpha_max {
192            2.0 * alpha
193        } else {
194            (alpha_lo + alpha_hi) / 2.0
195        };
196    }
197
198    Ok(alpha)
199}