// rust_mlp/optim.rs
//! Optimizers.
//!
//! This module provides small, allocation-free-per-step optimizers that update an `Mlp`
//! given a set of `Gradients`.
//!
//! Design notes:
//! - Optimizer *state* (momentum/Adam moments) lives outside the model.
//! - The training loop owns the optimizer state and reuses it across steps.

use crate::{Error, Gradients, Mlp, Result};

/// Optimizer choice for training.
///
/// This is pure configuration: hyperparameters only, no per-parameter buffers.
/// Call [`Optimizer::state`] to allocate the matching mutable state for a model.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Optimizer {
    /// Plain SGD (no extra state).
    #[default]
    Sgd,
    /// SGD with momentum; `momentum` is the velocity decay factor in `[0, 1)`.
    SgdMomentum { momentum: f32 },
    /// Adam with bias-corrected first/second moment estimates.
    Adam { beta1: f32, beta2: f32, eps: f32 },
}
23
24impl Optimizer {
25    /// Validate optimizer hyperparameters.
26    pub fn validate(self) -> Result<()> {
27        match self {
28            Optimizer::Sgd => Ok(()),
29            Optimizer::SgdMomentum { momentum } => {
30                if !(momentum.is_finite() && (0.0..1.0).contains(&momentum)) {
31                    return Err(Error::InvalidConfig(format!(
32                        "momentum must be finite and in [0,1), got {momentum}"
33                    )));
34                }
35                Ok(())
36            }
37            Optimizer::Adam { beta1, beta2, eps } => {
38                if !(beta1.is_finite() && (0.0..1.0).contains(&beta1)) {
39                    return Err(Error::InvalidConfig(format!(
40                        "adam beta1 must be finite and in [0,1), got {beta1}"
41                    )));
42                }
43                if !(beta2.is_finite() && (0.0..1.0).contains(&beta2)) {
44                    return Err(Error::InvalidConfig(format!(
45                        "adam beta2 must be finite and in [0,1), got {beta2}"
46                    )));
47                }
48                if !(eps.is_finite() && eps > 0.0) {
49                    return Err(Error::InvalidConfig(format!(
50                        "adam eps must be finite and > 0, got {eps}"
51                    )));
52                }
53                Ok(())
54            }
55        }
56    }
57
58    /// Allocate optimizer state for `model`.
59    pub fn state(self, model: &Mlp) -> Result<OptimizerState> {
60        self.validate()?;
61
62        match self {
63            Optimizer::Sgd => Ok(OptimizerState::Sgd),
64            Optimizer::SgdMomentum { momentum } => {
65                let (vw, vb) = zeros_like_params(model);
66                Ok(OptimizerState::SgdMomentum {
67                    momentum,
68                    v_weights: vw,
69                    v_biases: vb,
70                })
71            }
72            Optimizer::Adam { beta1, beta2, eps } => {
73                let (mw, mb) = zeros_like_params(model);
74                let (vw, vb) = zeros_like_params(model);
75                Ok(OptimizerState::Adam {
76                    beta1,
77                    beta2,
78                    eps,
79                    t: 0,
80                    beta1_pow: 1.0,
81                    beta2_pow: 1.0,
82                    m_weights: mw,
83                    m_biases: mb,
84                    v_weights: vw,
85                    v_biases: vb,
86                })
87            }
88        }
89    }
90}
91
/// Owned optimizer state.
///
/// Holds the per-parameter buffers (one `Vec<f32>` per layer for weights and
/// biases) that persist across training steps; the model itself stays state-free.
#[derive(Debug, Clone, Default)]
pub enum OptimizerState {
    /// Plain SGD (no state).
    #[default]
    Sgd,
    /// SGD with momentum state.
    SgdMomentum {
        // Velocity decay factor in [0, 1).
        momentum: f32,
        // Per-layer velocity buffers, shaped like the model's weights.
        v_weights: Vec<Vec<f32>>,
        // Per-layer velocity buffers, shaped like the model's biases.
        v_biases: Vec<Vec<f32>>,
    },
    /// Adam state.
    Adam {
        beta1: f32,
        beta2: f32,
        eps: f32,
        // Step counter (number of `step` calls performed).
        t: u64,
        // Running beta1^t / beta2^t, maintained incrementally for bias correction.
        beta1_pow: f32,
        beta2_pow: f32,
        // First-moment (mean) estimates per layer.
        m_weights: Vec<Vec<f32>>,
        m_biases: Vec<Vec<f32>>,
        // Second-moment (uncentered variance) estimates per layer.
        v_weights: Vec<Vec<f32>>,
        v_biases: Vec<Vec<f32>>,
    },
}
118
119impl OptimizerState {
120    /// Apply one optimizer step.
121    ///
122    /// `lr` is passed in from the training loop to support learning rate schedules.
123    pub fn step(&mut self, model: &mut Mlp, grads: &mut Gradients, lr: f32) {
124        assert!(lr.is_finite() && lr > 0.0, "lr must be finite and > 0");
125
126        match self {
127            OptimizerState::Sgd => {
128                model.sgd_step(grads, lr);
129            }
130            OptimizerState::SgdMomentum {
131                momentum,
132                v_weights,
133                v_biases,
134            } => {
135                debug_assert_eq!(v_weights.len(), model.num_layers());
136                debug_assert_eq!(v_biases.len(), model.num_layers());
137
138                for layer_idx in 0..model.num_layers() {
139                    let dw = grads.d_weights(layer_idx);
140                    let db = grads.d_biases(layer_idx);
141
142                    let vw = &mut v_weights[layer_idx];
143                    let vb = &mut v_biases[layer_idx];
144
145                    debug_assert_eq!(vw.len(), dw.len());
146                    debug_assert_eq!(vb.len(), db.len());
147
148                    for (v, &g) in vw.iter_mut().zip(dw) {
149                        *v = (*momentum) * *v + g;
150                    }
151                    for (v, &g) in vb.iter_mut().zip(db) {
152                        *v = (*momentum) * *v + g;
153                    }
154
155                    let layer = model.layer_mut(layer_idx).expect("layer idx must be valid");
156                    layer.sgd_step(vw, vb, lr);
157                }
158            }
159            OptimizerState::Adam {
160                beta1,
161                beta2,
162                eps,
163                t,
164                beta1_pow,
165                beta2_pow,
166                m_weights,
167                m_biases,
168                v_weights,
169                v_biases,
170            } => {
171                *t += 1;
172                *beta1_pow *= *beta1;
173                *beta2_pow *= *beta2;
174
175                let one_minus_beta1 = 1.0 - *beta1;
176                let one_minus_beta2 = 1.0 - *beta2;
177                let corr1 = 1.0 - *beta1_pow;
178                let corr2 = 1.0 - *beta2_pow;
179
180                // Overwrite `grads` with the Adam update direction and then reuse `sgd_step`.
181                for layer_idx in 0..model.num_layers() {
182                    let mw = &mut m_weights[layer_idx];
183                    let mb = &mut m_biases[layer_idx];
184                    let vw = &mut v_weights[layer_idx];
185                    let vb = &mut v_biases[layer_idx];
186
187                    debug_assert_eq!(mw.len(), vw.len());
188                    debug_assert_eq!(mb.len(), vb.len());
189
190                    {
191                        let upd_w = grads.d_weights_mut(layer_idx);
192                        for i in 0..upd_w.len() {
193                            let g = upd_w[i];
194                            mw[i] = (*beta1) * mw[i] + one_minus_beta1 * g;
195                            vw[i] = (*beta2) * vw[i] + one_minus_beta2 * (g * g);
196
197                            let m_hat = mw[i] / corr1;
198                            let v_hat = vw[i] / corr2;
199                            upd_w[i] = m_hat / (v_hat.sqrt() + *eps);
200                        }
201                    }
202                    {
203                        let upd_b = grads.d_biases_mut(layer_idx);
204                        for i in 0..upd_b.len() {
205                            let g = upd_b[i];
206                            mb[i] = (*beta1) * mb[i] + one_minus_beta1 * g;
207                            vb[i] = (*beta2) * vb[i] + one_minus_beta2 * (g * g);
208
209                            let m_hat = mb[i] / corr1;
210                            let v_hat = vb[i] / corr2;
211                            upd_b[i] = m_hat / (v_hat.sqrt() + *eps);
212                        }
213                    }
214                }
215
216                model.sgd_step(grads, lr);
217            }
218        }
219    }
220}
221
/// Stochastic gradient descent with a fixed learning rate.
#[derive(Debug, Clone, Copy)]
pub struct Sgd {
    // Step size applied as `param -= lr * grad`; validated in `Sgd::new`.
    lr: f32,
}
227
228impl Sgd {
229    #[inline]
230    /// Construct an SGD optimizer.
231    ///
232    /// Returns an error if `lr` is not finite or `lr <= 0`.
233    pub fn new(lr: f32) -> Result<Self> {
234        if !(lr.is_finite() && lr > 0.0) {
235            return Err(Error::InvalidConfig(
236                "learning rate must be finite and > 0".to_owned(),
237            ));
238        }
239        Ok(Self { lr })
240    }
241
242    #[inline]
243    /// Returns the learning rate.
244    pub fn lr(&self) -> f32 {
245        self.lr
246    }
247
248    #[inline]
249    /// Apply one optimizer step: `param -= lr * d_param`.
250    pub fn step(&self, model: &mut Mlp, grads: &Gradients) {
251        model.sgd_step(grads, self.lr);
252    }
253}
254
255fn zeros_like_params(model: &Mlp) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
256    let mut ws = Vec::with_capacity(model.num_layers());
257    let mut bs = Vec::with_capacity(model.num_layers());
258    for i in 0..model.num_layers() {
259        let layer = model.layer(i).expect("layer idx must be valid");
260        ws.push(vec![0.0; layer.in_dim() * layer.out_dim()]);
261        bs.push(vec![0.0; layer.out_dim()]);
262    }
263    (ws, bs)
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Activation, MlpBuilder};

    /// Build a 1 -> 1 identity-activation MLP with known weight/bias values
    /// so update arithmetic can be checked exactly.
    fn tiny_mlp(w: f32, b: f32) -> Mlp {
        let mut mlp = MlpBuilder::new(1)
            .unwrap()
            .add_layer(1, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();
        let layer = mlp.layer_mut(0).unwrap();
        layer.weights_mut()[0] = w;
        layer.biases_mut()[0] = b;
        mlp
    }

    #[test]
    fn sgd_requires_positive_finite_lr() {
        for bad_lr in [0.0, -1.0, f32::NAN] {
            assert!(Sgd::new(bad_lr).is_err());
        }
    }

    #[test]
    fn optimizer_validation_rejects_bad_hyperparams() {
        let invalid = [
            Optimizer::SgdMomentum { momentum: 1.0 },
            Optimizer::SgdMomentum { momentum: -0.1 },
            Optimizer::Adam {
                beta1: 1.0,
                beta2: 0.999,
                eps: 1e-8,
            },
            Optimizer::Adam {
                beta1: 0.9,
                beta2: 1.0,
                eps: 1e-8,
            },
            Optimizer::Adam {
                beta1: 0.9,
                beta2: 0.999,
                eps: 0.0,
            },
        ];
        for opt in invalid {
            assert!(opt.validate().is_err());
        }
    }

    #[test]
    fn sgd_momentum_updates_like_sgd_on_first_step() {
        let mut mlp = tiny_mlp(1.0, 2.0);

        let mut grads = mlp.gradients();
        grads.d_weights_mut(0)[0] = 3.0;
        grads.d_biases_mut(0)[0] = 4.0;

        let mut state = Optimizer::SgdMomentum { momentum: 0.9 }
            .state(&mlp)
            .unwrap();
        state.step(&mut mlp, &mut grads, 0.1);

        // Velocity starts at zero, so the first step is identical to plain SGD.
        let layer = mlp.layer_mut(0).unwrap();
        assert!((layer.weights_mut()[0] - (1.0 - 0.1 * 3.0)).abs() < 1e-6);
        assert!((layer.biases_mut()[0] - (2.0 - 0.1 * 4.0)).abs() < 1e-6);
    }

    #[test]
    fn adam_first_step_matches_expected_direction_for_unit_grad() {
        let mut mlp = tiny_mlp(1.0, 1.0);

        let mut grads = mlp.gradients();
        grads.d_weights_mut(0)[0] = 1.0;
        grads.d_biases_mut(0)[0] = 1.0;

        let mut state = Optimizer::Adam {
            beta1: 0.9,
            beta2: 0.999,
            eps: 1.0,
        }
        .state(&mlp)
        .unwrap();
        state.step(&mut mlp, &mut grads, 0.1);

        // With eps=1.0 and unit grad, the first bias-corrected step has update ~= 1/(1+eps) = 0.5.
        let layer = mlp.layer_mut(0).unwrap();
        assert!((layer.weights_mut()[0] - (1.0 - 0.1 * 0.5)).abs() < 1e-6);
        assert!((layer.biases_mut()[0] - (1.0 - 0.1 * 0.5)).abs() < 1e-6);
    }
}