stepwise/algos/gradient_descent.rs

use crate::{
    metrics::{GradL2Norm, Metric},
    Algo, VectorExt as _,
};
use std::{convert::Infallible, ops::ControlFlow};

///
/// Simple Gradient Descent method.
///
/// See <https://en.wikipedia.org/wiki/Gradient_descent> for details on the algorithm.
///
/// The [`x`](Self::x) returned is a vector.
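///
/// Each iteration updates the estimate in place as
/// `x ← x − learning_rate · ∇f(x)`, where `∇f` is the supplied gradient function.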
///
/// # Hyper-parameters
///
/// The only hyper-parameter is the learning rate. Typically a value of 0.01 works well. Too high
/// a learning rate can cause the algorithm to diverge, or to overshoot and oscillate around the
/// solution, while too low a learning rate causes the algorithm to converge slowly.
///
/// The learning rate can be adaptively changed during the optimization process. For example,
/// you can decrease the learning rate as the optimization progresses.
///
/// # Example
///
/// ```rust
/// use stepwise::assert_approx_eq;
/// use stepwise::{algos::GradientDescent, Driver, fixed_iters};
///
/// // n-sphere gradient
/// let sphere_grad = |x: &[f64]| x.iter().map(|x| 2.0 * x).collect::<Vec<_>>();
///
/// let learning_rate = 0.1;
/// let initial_estimate = vec![5.5, 6.5];
///
/// let gd = GradientDescent::new(learning_rate, initial_estimate, sphere_grad);
/// let (solved, _step) = fixed_iters(gd, 500)
///     .solve()
///     .expect("failed to solve");
///
/// assert_approx_eq!(solved.x(), [0.0, 0.0].as_slice());
/// ```
///
/// # Example
/// Here, the learning rate is decreased adaptively as the optimization progresses.
///
/// ```rust
/// # use stepwise::{
/// #    assert_approx_eq, problems::sphere_grad,
/// #    algos::GradientDescent, Driver, fixed_iters,
/// # };
/// # fn main() {
/// # let x0 = vec![5.55, 5.55];
/// let gd = GradientDescent::new(0.1, x0, sphere_grad);
///
/// let (solved, _step) = fixed_iters(gd, 200)
///     .on_step(|algo, _step| algo.learning_rate *= 0.99)
///     .solve()
///     .expect("failed to solve");
///
/// # let x = solved.x();
/// # assert_approx_eq!([0.0, 0.0].as_slice(), &x, 0.01);
/// # }
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct GradientDescent<G> {
    /// The user-supplied gradient function, called once per step.
    pub gradient_fn: G,
    /// The gradient most recently returned by [`gradient_fn`](Self::gradient_fn).
    pub gradient: Vec<f64>,
    /// The current estimate of the solution.
    pub x: Vec<f64>,
    /// The step size applied to the gradient. Can be adjusted between steps.
    pub learning_rate: f64,
}

impl<G> GradientDescent<G>
where
    G: FnMut(&[f64]) -> Vec<f64>,
{
    /// Creates a new `GradientDescent` algo from a `gradient_fn` and an initial guess `x0`.
    /// The learning rate can be changed during the optimization process.
    pub fn new(learning_rate: f64, x0: Vec<f64>, gradient_fn: G) -> Self {
        Self {
            gradient_fn,
            gradient: vec![0.0; x0.len()],
            x: x0,
            learning_rate,
        }
    }

    /// Returns the current estimate of the solution.
    pub fn x(&self) -> &[f64] {
        &self.x
    }

    /// Evaluates the gradient at the current estimate and takes one descent step,
    /// `x[i] -= learning_rate * gradient[i]`.
    pub fn update_gradient(&mut self) {
        self.gradient = (self.gradient_fn)(&self.x);
        #[allow(clippy::needless_range_loop)]
        for i in 0..self.x.len() {
            self.x[i] -= self.learning_rate * self.gradient[i];
        }
    }
}

/// Implement the [`Algo`] trait for `GradientDescent`.
impl<G> Algo for GradientDescent<G>
where
    G: FnMut(&[f64]) -> Vec<f64>,
{
    type Error = Infallible;

    fn step(&mut self) -> (ControlFlow<()>, Result<(), Infallible>) {
        self.update_gradient();
        // allow the stepwise Driver to decide when to cease iteration
        (ControlFlow::Continue(()), Ok(()))
    }
}

/// [`GradL2Norm`] observes the L2 norm of the most recently computed gradient.
impl<G> Metric<&GradientDescent<G>> for GradL2Norm {
    type Output = f64;

    fn observe_opt(&mut self, algo: &GradientDescent<G>) -> Option<Self::Output> {
        Some(algo.gradient.norm_l2())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert_approx_eq, fixed_iters, problems::sphere_grad};
    use std::error::Error;

    #[test]
    fn test_doc_gradient_descent() {
        let sphere_grad = |x: &[f64]| x.iter().map(|x| 2.0 * x).collect::<Vec<_>>();

        let learning_rate = 0.1;
        let initial_estimate = vec![5.55, 5.55];
        let gd = GradientDescent::new(learning_rate, initial_estimate, sphere_grad);
        let (solved, _step) = fixed_iters(gd, 500).solve().expect("failed to solve");
        assert_approx_eq!(solved.x(), &[0.0, 0.0]);
    }

    #[test]
    fn gradient_descent_core() -> Result<(), Box<dyn Error>> {
        let gd = GradientDescent::new(0.01, vec![5.0, 5.0], sphere_grad);
        let driver = fixed_iters(gd, 1000);
        let (solved, _step) = driver
            .on_step(|v, s| {
                if s.iteration() % 100 == 0 {
                    println!("{s:?} x: {:.9?}", v.x())
                }
            })
            .solve()?;

        let x = solved.x();
        assert_approx_eq!(x, &[0.0, 0.0], 1e-5);
        Ok(())
    }
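
    // A minimal sketch exercising the `GradL2Norm` metric directly, assuming
    // `GradL2Norm` is a unit struct constructed as `GradL2Norm` and that
    // `problems::sphere_grad` is the gradient of f(x) = Σ xᵢ², i.e. 2x
    // (as in the doc example above). Adjust if either assumption is wrong.
    #[test]
    fn grad_l2_norm_metric() {
        use crate::metrics::{GradL2Norm, Metric};

        let mut gd = GradientDescent::new(0.1, vec![3.0, 4.0], sphere_grad);
        // One step evaluates the sphere gradient at (3, 4), which is (6, 8).
        gd.update_gradient();
        let mut metric = GradL2Norm;
        let norm = metric.observe_opt(&gd).expect("metric should observe a value");
        // ||(6, 8)||_2 = 10
        assert!((norm - 10.0).abs() < 1e-9);
    }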
}