use crate::{Error, Gradients, Mlp, Result};

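/// Optimizer configuration: the algorithm and its hyperparameters.
///
/// An `Optimizer` is a cheap, copyable description; call [`Optimizer::state`]
/// to allocate the per-layer buffers needed to actually take steps.
///
/// A minimal usage sketch, assuming `MlpBuilder`, `Activation`, and the
/// gradient accessors are exported at the crate root as in this module's
/// tests (marked `ignore`, so it is not compiled as a doctest):
///
/// ```ignore
/// let mut mlp = MlpBuilder::new(1).unwrap()
///     .add_layer(1, Activation::Identity).unwrap()
///     .build_with_seed(0).unwrap();
/// let mut grads = mlp.gradients();
/// // ... accumulate gradients into `grads` via backprop ...
/// let mut opt = Optimizer::Adam { beta1: 0.9, beta2: 0.999, eps: 1e-8 }
///     .state(&mlp)
///     .unwrap();
/// opt.step(&mut mlp, &mut grads, 0.01);
/// ```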
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Optimizer {
    /// Plain stochastic gradient descent.
    #[default]
    Sgd,
    /// SGD with heavy-ball momentum on a per-parameter velocity buffer.
    SgdMomentum { momentum: f32 },
    /// Adam with bias-corrected first and second moment estimates.
    Adam { beta1: f32, beta2: f32, eps: f32 },
}

impl Optimizer {
    /// Checks that the hyperparameters are finite and within their valid ranges.
    pub fn validate(self) -> Result<()> {
        match self {
            Optimizer::Sgd => Ok(()),
            Optimizer::SgdMomentum { momentum } => {
                if !(momentum.is_finite() && (0.0..1.0).contains(&momentum)) {
                    return Err(Error::InvalidConfig(format!(
                        "momentum must be finite and in [0,1), got {momentum}"
                    )));
                }
                Ok(())
            }
            Optimizer::Adam { beta1, beta2, eps } => {
                if !(beta1.is_finite() && (0.0..1.0).contains(&beta1)) {
                    return Err(Error::InvalidConfig(format!(
                        "adam beta1 must be finite and in [0,1), got {beta1}"
                    )));
                }
                if !(beta2.is_finite() && (0.0..1.0).contains(&beta2)) {
                    return Err(Error::InvalidConfig(format!(
                        "adam beta2 must be finite and in [0,1), got {beta2}"
                    )));
                }
                if !(eps.is_finite() && eps > 0.0) {
                    return Err(Error::InvalidConfig(format!(
                        "adam eps must be finite and > 0, got {eps}"
                    )));
                }
                Ok(())
            }
        }
    }

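    /// Validates the configuration and allocates zero-initialized velocity /
    /// moment buffers shaped like the model's weights and biases.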
    pub fn state(self, model: &Mlp) -> Result<OptimizerState> {
        self.validate()?;

        match self {
            Optimizer::Sgd => Ok(OptimizerState::Sgd),
            Optimizer::SgdMomentum { momentum } => {
                let (vw, vb) = zeros_like_params(model);
                Ok(OptimizerState::SgdMomentum {
                    momentum,
                    v_weights: vw,
                    v_biases: vb,
                })
            }
            Optimizer::Adam { beta1, beta2, eps } => {
                let (mw, mb) = zeros_like_params(model);
                let (vw, vb) = zeros_like_params(model);
                Ok(OptimizerState::Adam {
                    beta1,
                    beta2,
                    eps,
                    t: 0,
                    beta1_pow: 1.0,
                    beta2_pow: 1.0,
                    m_weights: mw,
                    m_biases: mb,
                    v_weights: vw,
                    v_biases: vb,
                })
            }
        }
    }
}

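/// Mutable optimizer state: the algorithm's hyperparameters plus any per-layer
/// velocity / moment buffers, laid out to mirror the model's layers.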
#[derive(Debug, Clone, Default)]
pub enum OptimizerState {
    /// Stateless SGD.
    #[default]
    Sgd,
    /// Momentum velocities, one `Vec<f32>` per layer for weights and biases.
    SgdMomentum {
        momentum: f32,
        v_weights: Vec<Vec<f32>>,
        v_biases: Vec<Vec<f32>>,
    },
    /// Adam first (`m_*`) and second (`v_*`) moment estimates, the step
    /// counter `t`, and the running powers of `beta1` / `beta2` used for
    /// bias correction.
    Adam {
        beta1: f32,
        beta2: f32,
        eps: f32,
        t: u64,
        beta1_pow: f32,
        beta2_pow: f32,
        m_weights: Vec<Vec<f32>>,
        m_biases: Vec<Vec<f32>>,
        v_weights: Vec<Vec<f32>>,
        v_biases: Vec<Vec<f32>>,
    },
}

impl OptimizerState {
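    /// Applies one optimizer update to `model` in place, using `lr` as the
    /// step size. `grads` may be overwritten: the Adam path rewrites it with
    /// the preconditioned update before delegating to `Mlp::sgd_step`.
    ///
    /// For Adam, the per-parameter update implemented below is:
    ///
    /// ```text
    /// m     = beta1 * m + (1 - beta1) * g
    /// v     = beta2 * v + (1 - beta2) * g^2
    /// m_hat = m / (1 - beta1^t)
    /// v_hat = v / (1 - beta2^t)
    /// param -= lr * m_hat / (sqrt(v_hat) + eps)
    /// ```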
    pub fn step(&mut self, model: &mut Mlp, grads: &mut Gradients, lr: f32) {
        assert!(lr.is_finite() && lr > 0.0, "lr must be finite and > 0");

        match self {
            OptimizerState::Sgd => {
                model.sgd_step(grads, lr);
            }
            OptimizerState::SgdMomentum {
                momentum,
                v_weights,
                v_biases,
            } => {
                debug_assert_eq!(v_weights.len(), model.num_layers());
                debug_assert_eq!(v_biases.len(), model.num_layers());

                for layer_idx in 0..model.num_layers() {
                    let dw = grads.d_weights(layer_idx);
                    let db = grads.d_biases(layer_idx);

                    let vw = &mut v_weights[layer_idx];
                    let vb = &mut v_biases[layer_idx];

                    debug_assert_eq!(vw.len(), dw.len());
                    debug_assert_eq!(vb.len(), db.len());

                    // v = momentum * v + g, then descend along the velocity.
                    for (v, &g) in vw.iter_mut().zip(dw) {
                        *v = (*momentum) * *v + g;
                    }
                    for (v, &g) in vb.iter_mut().zip(db) {
                        *v = (*momentum) * *v + g;
                    }

                    let layer = model.layer_mut(layer_idx).expect("layer idx must be valid");
                    layer.sgd_step(vw, vb, lr);
                }
            }
            OptimizerState::Adam {
                beta1,
                beta2,
                eps,
                t,
                beta1_pow,
                beta2_pow,
                m_weights,
                m_biases,
                v_weights,
                v_biases,
            } => {
                *t += 1;
                *beta1_pow *= *beta1;
                *beta2_pow *= *beta2;

                let one_minus_beta1 = 1.0 - *beta1;
                let one_minus_beta2 = 1.0 - *beta2;
                let corr1 = 1.0 - *beta1_pow;
                let corr2 = 1.0 - *beta2_pow;

                for layer_idx in 0..model.num_layers() {
                    let mw = &mut m_weights[layer_idx];
                    let mb = &mut m_biases[layer_idx];
                    let vw = &mut v_weights[layer_idx];
                    let vb = &mut v_biases[layer_idx];

                    debug_assert_eq!(mw.len(), vw.len());
                    debug_assert_eq!(mb.len(), vb.len());

                    // Rewrite the gradients in place with the bias-corrected
                    // Adam update, then let `sgd_step` apply `-lr * update`.
                    {
                        let upd_w = grads.d_weights_mut(layer_idx);
                        for i in 0..upd_w.len() {
                            let g = upd_w[i];
                            mw[i] = (*beta1) * mw[i] + one_minus_beta1 * g;
                            vw[i] = (*beta2) * vw[i] + one_minus_beta2 * (g * g);

                            let m_hat = mw[i] / corr1;
                            let v_hat = vw[i] / corr2;
                            upd_w[i] = m_hat / (v_hat.sqrt() + *eps);
                        }
                    }
                    {
                        let upd_b = grads.d_biases_mut(layer_idx);
                        for i in 0..upd_b.len() {
                            let g = upd_b[i];
                            mb[i] = (*beta1) * mb[i] + one_minus_beta1 * g;
                            vb[i] = (*beta2) * vb[i] + one_minus_beta2 * (g * g);

                            let m_hat = mb[i] / corr1;
                            let v_hat = vb[i] / corr2;
                            upd_b[i] = m_hat / (v_hat.sqrt() + *eps);
                        }
                    }
                }

                model.sgd_step(grads, lr);
            }
        }
    }
}

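/// Plain SGD with a fixed learning rate, kept as a thin convenience wrapper
/// around [`Mlp::sgd_step`]. A minimal training-loop sketch, assuming `mlp`
/// and `grads` are set up as in this module's tests (marked `ignore`, so it
/// is not compiled as a doctest):
///
/// ```ignore
/// let opt = Sgd::new(0.01).unwrap();
/// for _epoch in 0..10 {
///     // ... fill `grads` via backprop for the current batch ...
///     opt.step(&mut mlp, &grads);
/// }
/// ```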
#[derive(Debug, Clone, Copy)]
pub struct Sgd {
    lr: f32,
}

impl Sgd {
    #[inline]
    pub fn new(lr: f32) -> Result<Self> {
        if !(lr.is_finite() && lr > 0.0) {
            return Err(Error::InvalidConfig(
                "learning rate must be finite and > 0".to_owned(),
            ));
        }
        Ok(Self { lr })
    }

    #[inline]
    pub fn lr(&self) -> f32 {
        self.lr
    }

    #[inline]
    pub fn step(&self, model: &mut Mlp, grads: &Gradients) {
        model.sgd_step(grads, self.lr);
    }
}

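/// Allocates zero-filled buffers matching each layer's weight and bias shapes,
/// used to initialize momentum / moment accumulators.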
fn zeros_like_params(model: &Mlp) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let mut ws = Vec::with_capacity(model.num_layers());
    let mut bs = Vec::with_capacity(model.num_layers());
    for i in 0..model.num_layers() {
        let layer = model.layer(i).expect("layer idx must be valid");
        ws.push(vec![0.0; layer.in_dim() * layer.out_dim()]);
        bs.push(vec![0.0; layer.out_dim()]);
    }
    (ws, bs)
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Activation, MlpBuilder};

    #[test]
    fn sgd_requires_positive_finite_lr() {
        assert!(Sgd::new(0.0).is_err());
        assert!(Sgd::new(-1.0).is_err());
        assert!(Sgd::new(f32::NAN).is_err());
    }

    #[test]
    fn optimizer_validation_rejects_bad_hyperparams() {
        assert!(Optimizer::SgdMomentum { momentum: 1.0 }.validate().is_err());
        assert!(
            Optimizer::SgdMomentum { momentum: -0.1 }
                .validate()
                .is_err()
        );
        assert!(
            Optimizer::Adam {
                beta1: 1.0,
                beta2: 0.999,
                eps: 1e-8
            }
            .validate()
            .is_err()
        );
        assert!(
            Optimizer::Adam {
                beta1: 0.9,
                beta2: 1.0,
                eps: 1e-8
            }
            .validate()
            .is_err()
        );
        assert!(
            Optimizer::Adam {
                beta1: 0.9,
                beta2: 0.999,
                eps: 0.0
            }
            .validate()
            .is_err()
        );
    }

    #[test]
    fn sgd_momentum_updates_like_sgd_on_first_step() {
        let mut mlp = MlpBuilder::new(1)
            .unwrap()
            .add_layer(1, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();

        {
            let layer = mlp.layer_mut(0).unwrap();
            layer.weights_mut()[0] = 1.0;
            layer.biases_mut()[0] = 2.0;
        }

        let mut grads = mlp.gradients();
        grads.d_weights_mut(0)[0] = 3.0;
        grads.d_biases_mut(0)[0] = 4.0;

        let mut opt = Optimizer::SgdMomentum { momentum: 0.9 }
            .state(&mlp)
            .unwrap();
        opt.step(&mut mlp, &mut grads, 0.1);

        // On the first step the velocity equals the gradient (v starts at 0),
        // so the update matches plain SGD: param -= lr * g.
        let (w, b) = {
            let layer = mlp.layer_mut(0).unwrap();
            (layer.weights_mut()[0], layer.biases_mut()[0])
        };
        assert!((w - (1.0 - 0.1 * 3.0)).abs() < 1e-6);
        assert!((b - (2.0 - 0.1 * 4.0)).abs() < 1e-6);
    }

    #[test]
    fn adam_first_step_matches_expected_direction_for_unit_grad() {
        let mut mlp = MlpBuilder::new(1)
            .unwrap()
            .add_layer(1, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();

        {
            let layer = mlp.layer_mut(0).unwrap();
            layer.weights_mut()[0] = 1.0;
            layer.biases_mut()[0] = 1.0;
        }

        let mut grads = mlp.gradients();
        grads.d_weights_mut(0)[0] = 1.0;
        grads.d_biases_mut(0)[0] = 1.0;

        let mut opt = Optimizer::Adam {
            beta1: 0.9,
            beta2: 0.999,
            eps: 1.0,
        }
        .state(&mlp)
        .unwrap();
        opt.step(&mut mlp, &mut grads, 0.1);

        // For a unit gradient the bias-corrected m_hat and v_hat are both 1 on
        // the first step, so the update is 1 / (sqrt(1) + eps) = 0.5 with eps = 1.
        let (w, b) = {
            let layer = mlp.layer_mut(0).unwrap();
            (layer.weights_mut()[0], layer.biases_mut()[0])
        };
        assert!((w - (1.0 - 0.1 * 0.5)).abs() < 1e-6);
        assert!((b - (1.0 - 0.1 * 0.5)).abs() < 1e-6);
    }
}