stepwise/algos/gradient_descent.rs
use crate::{
    metrics::{GradL2Norm, Metric},
    Algo, VectorExt as _,
};
use std::{convert::Infallible, ops::ControlFlow};

///
/// Simple Gradient Descent method.
///
/// See <https://en.wikipedia.org/wiki/Gradient_descent> for details on the algorithm.
/// Each iteration updates the estimate by subtracting the gradient scaled by the learning rate.
///
/// The [`x`](Self::x) returned is a vector.
///
/// # Hyper-parameters
///
/// The only hyper-parameter is the learning rate; typically a value of 0.01 works well.
/// Too high a learning rate can cause the algorithm to diverge, or to overshoot and
/// oscillate around the solution, while too low a learning rate can cause it to converge
/// too slowly.
///
/// The learning rate can be adaptively changed during the optimization process. For example,
/// you can decrease the learning rate as the optimization progresses (see the second example below).
///
/// # Example
///
/// ```rust
/// use stepwise::{assert_approx_eq, problems::sphere_grad};
/// use stepwise::{algos::GradientDescent, Driver, fixed_iters};
///
/// // n-sphere gradient
/// let sphere_grad = |x: &[f64]| x.iter().map(|x| 2.0 * x).collect::<Vec<_>>();
///
/// let learning_rate = 0.1;
/// let initial_estimate = vec![5.5, 6.5];
///
/// let gd = GradientDescent::new(learning_rate, initial_estimate, sphere_grad);
/// let (solved, _step) = fixed_iters(gd, 500)
///     .solve()
///     .expect("failed to solve");
///
/// assert_approx_eq!(solved.x(), [0.0, 0.0].as_slice());
/// ```
///
/// # Example
///
/// Here the learning rate is adaptively decreased during the optimization, via the
/// `on_step` callback.
///
/// ```rust
/// # use stepwise::{
/// #     assert_approx_eq, assert_approx_ne, problems::sphere_grad,
/// #     algos::GradientDescent, Driver, fixed_iters,
/// # };
/// # fn main() {
/// # let x0 = vec![5.55, 5.55];
/// let gd = GradientDescent::new(0.1, x0, sphere_grad);
///
/// let (solved, _step) = fixed_iters(gd, 200)
///     .on_step(|algo, _step| algo.learning_rate *= 0.99)
///     .solve()
///     .expect("failed to solve");
///
/// # let x = solved.x();
/// # assert_approx_eq!([0.0, 0.0].as_slice(), &x, 0.01);
/// # }
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct GradientDescent<G> {
    /// Gradient function supplied by the caller.
    pub gradient_fn: G,
    /// The most recently evaluated gradient (initially all zeros).
    pub gradient: Vec<f64>,
    /// The current estimate of the solution.
    pub x: Vec<f64>,
    /// The learning rate (step size). Can be changed during the optimization.
    pub learning_rate: f64,
}

impl<G> GradientDescent<G>
where
    G: FnMut(&[f64]) -> Vec<f64>,
{
    /// Creates a new `GradientDescent` algo with `gradient_fn` and initial guess `x0`.
    ///
    /// The learning rate can be changed during the optimization process.
    pub fn new(learning_rate: f64, x0: Vec<f64>, gradient_fn: G) -> Self {
        Self {
            gradient_fn,
            gradient: vec![0.0; x0.len()],
            x: x0,
            learning_rate,
        }
    }

    /// Returns the current estimate of the solution.
    pub fn x(&self) -> &[f64] {
        &self.x
    }

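    /// Evaluates the gradient at the current estimate, then takes one descent step:
    /// `x ← x − learning_rate · ∇f(x)`.
    ///
    /// A minimal sketch of calling it directly, outside a driver, using the same
    /// sphere gradient as the examples above:
    ///
    /// ```rust
    /// use stepwise::algos::GradientDescent;
    ///
    /// let sphere_grad = |x: &[f64]| x.iter().map(|x| 2.0 * x).collect::<Vec<_>>();
    /// let mut gd = GradientDescent::new(0.1, vec![3.0, 4.0], sphere_grad);
    ///
    /// gd.update_gradient();
    ///
    /// // the gradient at [3, 4] is [6, 8], and x moves to [3 - 0.1 * 6, 4 - 0.1 * 8]
    /// assert_eq!(gd.gradient, vec![6.0, 8.0]);
    /// ```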
    pub fn update_gradient(&mut self) {
        self.gradient = (self.gradient_fn)(&self.x);
        #[allow(clippy::needless_range_loop)]
        for i in 0..self.x.len() {
            self.x[i] -= self.learning_rate * self.gradient[i];
        }
    }
}

/// Implement the `Algo` trait for `GradientDescent`.
impl<G> Algo for GradientDescent<G>
where
    G: FnMut(&[f64]) -> Vec<f64>,
{
    type Error = Infallible;

    fn step(&mut self) -> (ControlFlow<()>, Result<(), Infallible>) {
        self.update_gradient();
        // allow stepwise Driver to decide when to cease iteration
        (ControlFlow::Continue(()), Ok(()))
    }
}

impl<G> Metric<&GradientDescent<G>> for GradL2Norm {
    type Output = f64;

    fn observe_opt(&mut self, algo: &GradientDescent<G>) -> Option<Self::Output> {
        Some(algo.gradient.norm_l2())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert_approx_eq, fixed_iters, problems::sphere_grad};
    use std::error::Error;

    #[test]
    fn test_doc_gradient_descent() {
        let sphere_grad = |x: &[f64]| x.iter().map(|x| 2.0 * x).collect::<Vec<_>>();

        let learning_rate = 0.1;
        let initial_estimate = vec![5.55, 5.55];
        let gd = GradientDescent::new(learning_rate, initial_estimate, sphere_grad);
        let (solved, _step) = fixed_iters(gd, 500).solve().expect("failed to solve");
        assert_approx_eq!(solved.x(), &[0.0, 0.0]);
    }
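
    // Mirrors the second doc example: the learning rate is decayed by 1% after
    // every step via the `on_step` callback.
    #[test]
    fn test_doc_gradient_descent_adaptive() {
        let gd = GradientDescent::new(0.1, vec![5.55, 5.55], sphere_grad);
        let (solved, _step) = fixed_iters(gd, 200)
            .on_step(|algo, _step| algo.learning_rate *= 0.99)
            .solve()
            .expect("failed to solve");
        assert_approx_eq!(solved.x(), &[0.0, 0.0], 0.01);
    }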

    #[test]
    fn gradient_descent_core() -> Result<(), Box<dyn Error>> {
        let gd = GradientDescent::new(0.01, vec![5.0, 5.0], sphere_grad);
        let driver = fixed_iters(gd, 1000);
        let (solved, _step) = driver
            .on_step(|v, s| {
                if s.iteration() % 100 == 0 {
                    println!("{s:?} x: {:.9?}", v.x())
                }
            })
            .solve()?;

        let x = solved.x();
        assert_approx_eq!(x, &[0.0, 0.0], 1e-5);
        Ok(())
    }
}
155}