// torsh_functional/optimization/gradient_descent.rs
//! Gradient descent variants for optimization
//!
//! This module provides various gradient descent optimization algorithms including
//! basic gradient descent, momentum gradient descent, and Adam optimizer.

use super::line_search::{backtracking_line_search, wolfe_line_search, LineSearchMethod};
use super::utilities::{
    tensor_add, tensor_elementwise_div, tensor_elementwise_mul, tensor_full_like, tensor_norm,
    tensor_scalar_mul, tensor_sqrt, tensor_sub, tensor_zeros_like,
};
use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;

14/// Gradient descent optimizer parameters
15#[derive(Debug, Clone)]
16pub struct GradientDescentParams {
17    /// Learning rate
18    pub learning_rate: f32,
19    /// Maximum number of iterations
20    pub max_iter: usize,
21    /// Convergence tolerance
22    pub tolerance: f32,
23    /// Line search method
24    pub line_search: Option<LineSearchMethod>,
25}
26
27impl Default for GradientDescentParams {
28    fn default() -> Self {
29        Self {
30            learning_rate: 0.01,
31            max_iter: 1000,
32            tolerance: 1e-6,
33            line_search: Some(LineSearchMethod::Backtracking),
34        }
35    }
36}
37
38/// Basic gradient descent optimization
39///
40/// # Arguments
41/// * `objective` - Objective function to minimize
42/// * `gradient` - Gradient function
43/// * `x0` - Initial point
44/// * `params` - Optimization parameters
45pub fn gradient_descent<F, G>(
46    objective: F,
47    gradient: G,
48    x0: &Tensor,
49    params: Option<GradientDescentParams>,
50) -> TorshResult<(Tensor, Vec<f32>)>
51where
52    F: Fn(&Tensor) -> TorshResult<f32>,
53    G: Fn(&Tensor) -> TorshResult<Tensor>,
54{
55    let params = params.unwrap_or_default();
56    let mut x = x0.clone();
57    let mut objective_values = Vec::new();
58
59    for iter in 0..params.max_iter {
60        let f_val = objective(&x)?;
61        objective_values.push(f_val);
62
63        let grad = gradient(&x)?;
64        let grad_norm = tensor_norm(&grad)?;
65
66        if grad_norm < params.tolerance {
67            break;
68        }
69
70        // Search direction is negative gradient
71        let p = tensor_scalar_mul(&grad, -1.0)?;
72
73        // Determine step size
74        let alpha = match params.line_search {
75            Some(LineSearchMethod::Backtracking) => {
76                backtracking_line_search(&objective, &gradient, &x, &p, None)?
77            }
78            Some(LineSearchMethod::Wolfe) => {
79                wolfe_line_search(&objective, &gradient, &x, &p, None)?
80            }
81            _ => params.learning_rate,
82        };
83
84        // Update: x = x + α*p
85        x = tensor_add(&x, &tensor_scalar_mul(&p, alpha)?)?;
86
87        if iter % 100 == 0 {
88            println!(
89                "Iteration {}: f = {:.6e}, |∇f| = {:.6e}, α = {:.6e}",
90                iter, f_val, grad_norm, alpha
91            );
92        }
93    }
94
95    Ok((x, objective_values))
96}
97
/// Configuration for momentum (heavy-ball) gradient descent.
#[derive(Debug, Clone)]
pub struct MomentumParams {
    /// Step size applied to the velocity each iteration
    pub learning_rate: f32,
    /// Velocity decay factor β (typically 0.9)
    pub momentum: f32,
    /// Upper bound on the number of iterations
    pub max_iter: usize,
    /// Stop once the gradient norm drops below this threshold
    pub tolerance: f32,
}

111impl Default for MomentumParams {
112    fn default() -> Self {
113        Self {
114            learning_rate: 0.01,
115            momentum: 0.9,
116            max_iter: 1000,
117            tolerance: 1e-6,
118        }
119    }
120}
121
122/// Momentum gradient descent
123///
124/// Updates: v = β*v + ∇f(x), x = x - α*v
125///
126/// # Arguments
127/// * `objective` - Objective function to minimize
128/// * `gradient` - Gradient function
129/// * `x0` - Initial point
130/// * `params` - Optimization parameters
131pub fn momentum_gradient_descent<F, G>(
132    objective: F,
133    gradient: G,
134    x0: &Tensor,
135    params: Option<MomentumParams>,
136) -> TorshResult<(Tensor, Vec<f32>)>
137where
138    F: Fn(&Tensor) -> TorshResult<f32>,
139    G: Fn(&Tensor) -> TorshResult<Tensor>,
140{
141    let params = params.unwrap_or_default();
142    let mut x = x0.clone();
143    let mut v = tensor_zeros_like(&x)?; // Initialize momentum to zero
144    let mut objective_values = Vec::new();
145
146    for iter in 0..params.max_iter {
147        let f_val = objective(&x)?;
148        objective_values.push(f_val);
149
150        let grad = gradient(&x)?;
151        let grad_norm = tensor_norm(&grad)?;
152
153        if grad_norm < params.tolerance {
154            break;
155        }
156
157        // Update momentum: v = β*v + ∇f(x)
158        v = tensor_add(&tensor_scalar_mul(&v, params.momentum)?, &grad)?;
159
160        // Update parameters: x = x - α*v
161        x = tensor_sub(&x, &tensor_scalar_mul(&v, params.learning_rate)?)?;
162
163        if iter % 100 == 0 {
164            println!(
165                "Iteration {}: f = {:.6e}, |∇f| = {:.6e}",
166                iter, f_val, grad_norm
167            );
168        }
169    }
170
171    Ok((x, objective_values))
172}
173
/// Configuration for the Adam optimizer.
#[derive(Debug, Clone)]
pub struct AdamParams {
    /// Base step size α
    pub learning_rate: f32,
    /// Exponential decay rate for the first moment, β₁ (typically 0.9)
    pub beta1: f32,
    /// Exponential decay rate for the second moment, β₂ (typically 0.999)
    pub beta2: f32,
    /// Small constant ε guarding against division by zero (typically 1e-8)
    pub epsilon: f32,
    /// Upper bound on the number of iterations
    pub max_iter: usize,
    /// Stop once the gradient norm drops below this threshold
    pub tolerance: f32,
}

191impl Default for AdamParams {
192    fn default() -> Self {
193        Self {
194            learning_rate: 0.001,
195            beta1: 0.9,
196            beta2: 0.999,
197            epsilon: 1e-8,
198            max_iter: 1000,
199            tolerance: 1e-6,
200        }
201    }
202}
203
204/// Adam optimizer
205///
206/// Adaptive moment estimation optimizer that combines momentum and RMSprop.
207///
208/// # Arguments
209/// * `objective` - Objective function to minimize
210/// * `gradient` - Gradient function
211/// * `x0` - Initial point
212/// * `params` - Optimization parameters
213pub fn adam_optimizer<F, G>(
214    objective: F,
215    gradient: G,
216    x0: &Tensor,
217    params: Option<AdamParams>,
218) -> TorshResult<(Tensor, Vec<f32>)>
219where
220    F: Fn(&Tensor) -> TorshResult<f32>,
221    G: Fn(&Tensor) -> TorshResult<Tensor>,
222{
223    let params = params.unwrap_or_default();
224    let mut x = x0.clone();
225    let mut m = tensor_zeros_like(&x)?; // First moment
226    let mut v = tensor_zeros_like(&x)?; // Second moment
227    let mut objective_values = Vec::new();
228
229    for iter in 0..params.max_iter {
230        let t = (iter + 1) as f32; // Time step
231
232        let f_val = objective(&x)?;
233        objective_values.push(f_val);
234
235        let grad = gradient(&x)?;
236        let grad_norm = tensor_norm(&grad)?;
237
238        if grad_norm < params.tolerance {
239            break;
240        }
241
242        // Update first moment: m = β1*m + (1-β1)*g
243        m = tensor_add(
244            &tensor_scalar_mul(&m, params.beta1)?,
245            &tensor_scalar_mul(&grad, 1.0 - params.beta1)?,
246        )?;
247
248        // Update second moment: v = β2*v + (1-β2)*g²
249        let grad_squared = tensor_elementwise_mul(&grad, &grad)?;
250        v = tensor_add(
251            &tensor_scalar_mul(&v, params.beta2)?,
252            &tensor_scalar_mul(&grad_squared, 1.0 - params.beta2)?,
253        )?;
254
255        // Bias correction
256        let m_hat = tensor_scalar_mul(&m, 1.0 / (1.0 - params.beta1.powf(t)))?;
257        let v_hat = tensor_scalar_mul(&v, 1.0 / (1.0 - params.beta2.powf(t)))?;
258
259        // Update parameters: x = x - α * m̂ / (√v̂ + ε)
260        let denominator = tensor_add(
261            &tensor_sqrt(&v_hat)?,
262            &tensor_full_like(&x, params.epsilon)?,
263        )?;
264        let update = tensor_elementwise_div(&m_hat, &denominator)?;
265        x = tensor_sub(&x, &tensor_scalar_mul(&update, params.learning_rate)?)?;
266
267        if iter % 100 == 0 {
268            println!(
269                "Iteration {}: f = {:.6e}, |∇f| = {:.6e}",
270                iter, f_val, grad_norm
271            );
272        }
273    }
274
275    Ok((x, objective_values))
276}