scry_learn/neural/optimizer.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Optimizers for neural network training.
//!
//! Provides SGD with Nesterov momentum and Adam, matching sklearn defaults.
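//!
//! # Example
//!
//! A minimal configuration sketch (the import path assumes these types are
//! reachable under `scry_learn::neural`; adjust to the crate's actual layout):
//!
//! ```ignore
//! use scry_learn::neural::{LearningRateSchedule, OptimizerKind};
//!
//! // Adam with the crate's default hyperparameters.
//! let adam = OptimizerKind::default();
//! // SGD with Nesterov momentum 0.9.
//! let sgd = OptimizerKind::sgd();
//! // Reduce the learning rate when the loss plateaus.
//! let schedule = LearningRateSchedule::adaptive();
//! ```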

/// Available optimizer algorithms.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum OptimizerKind {
    /// Stochastic gradient descent with optional Nesterov momentum.
    Sgd {
        /// Momentum coefficient (0.0 = no momentum). Default: 0.9.
        momentum: f64,
        /// Use Nesterov accelerated gradient. Default: true.
        nesterov: bool,
    },
    /// Adaptive moment estimation (Adam).
    ///
    /// Defaults: β₁=0.9, β₂=0.999, ε=1e-8.
    Adam {
        /// Exponential decay rate for first moment. Default: 0.9.
        beta1: f64,
        /// Exponential decay rate for second moment. Default: 0.999.
        beta2: f64,
        /// Small constant for numerical stability. Default: 1e-8.
        epsilon: f64,
    },
}

impl Default for OptimizerKind {
    fn default() -> Self {
        Self::Adam {
            beta1: crate::constants::ADAM_BETA1,
            beta2: crate::constants::ADAM_BETA2,
            epsilon: crate::constants::ADAM_EPSILON,
        }
    }
}

impl OptimizerKind {
    /// SGD with default momentum (0.9, Nesterov).
    pub fn sgd() -> Self {
        Self::Sgd {
            momentum: crate::constants::SGD_MOMENTUM,
            nesterov: true,
        }
    }
}

/// Learning rate schedule for neural network training.
///
/// Controls how the learning rate changes over epochs.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LearningRateSchedule {
    /// Fixed learning rate throughout training. Default.
    #[default]
    Constant,
    /// Multiply the learning rate by `factor` when the loss has plateaued for `patience` epochs.
    ///
    /// Comparable to sklearn's `learning_rate='adaptive'` behavior.
    Adaptive {
        /// Multiplicative factor to reduce LR (default: 0.2).
        factor: f64,
        /// Number of plateau epochs before reducing (default: 10).
        patience: usize,
    },
    /// Inverse scaling: `lr(t) = initial_lr / t^power`.
    InvScaling {
        /// Exponent for inverse scaling (default: 0.5).
        power: f64,
    },
}

impl LearningRateSchedule {
    /// Adaptive schedule with sklearn-like defaults (factor=0.2, patience=10).
    pub fn adaptive() -> Self {
        Self::Adaptive {
            factor: 0.2,
            patience: 10,
        }
    }
}

/// Per-parameter optimizer state.
///
/// Tracks the moving averages needed by each optimizer algorithm.
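///
/// # Usage sketch
///
/// The intended call pattern (marked `ignore` since this type is crate-private;
/// `weights`, `biases`, `grads_for`, `batches`, `max_epochs`, and `epoch_loss`
/// are placeholders):
///
/// ```ignore
/// let mut opt = OptimizerState::new_with_schedule(
///     OptimizerKind::default(),
///     1e-3,
///     &[weights.len(), biases.len()],
///     LearningRateSchedule::adaptive(),
/// );
///
/// for _epoch in 0..max_epochs {
///     for _batch in &batches {
///         opt.tick(); // once per mini-batch
///         opt.step(0, &mut weights, &grads_for(0));
///         opt.step(1, &mut biases, &grads_for(1));
///     }
///     opt.adjust_lr(epoch_loss); // once per epoch
/// }
/// ```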
pub(crate) struct OptimizerState {
    kind: OptimizerKind,
    lr: f64,
    initial_lr: f64,
    t: u64,
    // SGD momentum buffers (one per parameter group)
    velocity: Vec<Vec<f64>>,
    // Adam first moment (mean)
    m: Vec<Vec<f64>>,
    // Adam second moment (variance)
    v: Vec<Vec<f64>>,
    // ── Learning rate schedule state ──
    schedule: LearningRateSchedule,
    best_loss: f64,
    plateau_count: usize,
    epoch_count: usize,
}

impl OptimizerState {
    /// Create a new optimizer state with a constant learning rate and one
    /// parameter group per entry in `group_sizes`.
    pub fn new(kind: OptimizerKind, lr: f64, group_sizes: &[usize]) -> Self {
        Self::new_with_schedule(kind, lr, group_sizes, LearningRateSchedule::Constant)
    }

    /// Create a new optimizer state with a learning rate schedule.
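    ///
    /// `group_sizes` gives the flattened length of each parameter group
    /// (typically one entry per weight matrix and one per bias vector); a
    /// group's position in the slice is the `idx` later passed to [`Self::step`].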
    pub fn new_with_schedule(
        kind: OptimizerKind,
        lr: f64,
        group_sizes: &[usize],
        schedule: LearningRateSchedule,
    ) -> Self {
        let n = group_sizes.len();
        let zeros =
            |sizes: &[usize]| -> Vec<Vec<f64>> { sizes.iter().map(|&s| vec![0.0; s]).collect() };

        Self {
            kind,
            lr,
            initial_lr: lr,
            t: 0,
            velocity: zeros(group_sizes),
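            // Adam's moment buffers; left empty (length 0) for SGD, where they are unused.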
            m: if matches!(kind, OptimizerKind::Adam { .. }) {
                zeros(group_sizes)
            } else {
                Vec::with_capacity(n)
            },
            v: if matches!(kind, OptimizerKind::Adam { .. }) {
                zeros(group_sizes)
            } else {
                Vec::with_capacity(n)
            },
            schedule,
            best_loss: f64::INFINITY,
            plateau_count: 0,
            epoch_count: 0,
        }
    }

    /// Apply one optimization step to parameter group `idx`.
    ///
    /// `params` are modified in-place. `grads` are the computed gradients.
    pub fn step(&mut self, idx: usize, params: &mut [f64], grads: &[f64]) {
        debug_assert_eq!(params.len(), grads.len());
        debug_assert!(idx < self.velocity.len());

        match self.kind {
            OptimizerKind::Sgd { momentum, nesterov } => {
                self.step_sgd(idx, params, grads, momentum, nesterov);
            }
            OptimizerKind::Adam {
                beta1,
                beta2,
                epsilon,
            } => {
                self.step_adam(idx, params, grads, beta1, beta2, epsilon);
            }
        }
    }

    /// Increment the global step counter. Call once per mini-batch.
    pub fn tick(&mut self) {
        self.t += 1;
    }

    /// Current learning rate (may differ from initial after scheduling).
    pub fn current_lr(&self) -> f64 {
        self.lr
    }

    /// Adjust learning rate based on the schedule after each epoch.
    ///
    /// Call this at the end of each epoch with the epoch's average loss.
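    ///
    /// For example, with [`LearningRateSchedule::Adaptive`] at `factor = 0.2` and
    /// `patience = 10`, ten consecutive epochs without improvement shrink an
    /// initial rate of `1e-3` to `2e-4`.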
    pub fn adjust_lr(&mut self, epoch_loss: f64) {
        self.epoch_count += 1;

        match self.schedule {
            LearningRateSchedule::Constant => {}
            LearningRateSchedule::Adaptive { factor, patience } => {
                if epoch_loss < self.best_loss - 1e-10 {
                    self.best_loss = epoch_loss;
                    self.plateau_count = 0;
                } else {
                    self.plateau_count += 1;
                    if self.plateau_count >= patience {
                        self.lr *= factor;
                        self.plateau_count = 0;
                        self.best_loss = epoch_loss;
                    }
                }
            }
            LearningRateSchedule::InvScaling { power } => {
                self.lr = self.initial_lr / (self.epoch_count as f64).powf(power);
            }
        }
    }

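    /// SGD update with optional momentum.
    ///
    /// With momentum `μ` and learning rate `η`, the heavy-ball form is
    /// `v ← μ·v + g`, `p ← p − η·v`; the Nesterov form applies
    /// `p ← p − η·(g + μ·v)` after the same velocity update.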
    fn step_sgd(
        &mut self,
        idx: usize,
        params: &mut [f64],
        grads: &[f64],
        momentum: f64,
        nesterov: bool,
    ) {
        let vel = &mut self.velocity[idx];
        let lr = self.lr;

        if momentum == 0.0 {
            for (p, g) in params.iter_mut().zip(grads.iter()) {
                *p -= lr * g;
            }
        } else if nesterov {
            for i in 0..params.len() {
                vel[i] = momentum * vel[i] + grads[i];
                params[i] -= lr * (grads[i] + momentum * vel[i]);
            }
        } else {
            for i in 0..params.len() {
                vel[i] = momentum * vel[i] + grads[i];
                params[i] -= lr * vel[i];
            }
        }
    }

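    /// Adam update.
    ///
    /// Moments: `m ← β₁·m + (1−β₁)·g`, `v ← β₂·v + (1−β₂)·g²`; with bias
    /// corrections `m̂ = m / (1 − β₁^t)`, `v̂ = v / (1 − β₂^t)` the parameters
    /// move by `p ← p − η·m̂ / (√v̂ + ε)`.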
    fn step_adam(
        &mut self,
        idx: usize,
        params: &mut [f64],
        grads: &[f64],
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    ) {
        let lr = self.lr;
        let t = self.t.max(1) as f64;
        let m = &mut self.m[idx];
        let v = &mut self.v[idx];

        // Bias correction
        let bc1 = 1.0 - beta1.powf(t);
        let bc2 = 1.0 - beta2.powf(t);

        for i in 0..params.len() {
            // Update biased first moment
            m[i] = beta1 * m[i] + (1.0 - beta1) * grads[i];
            // Update biased second moment
            v[i] = beta2 * v[i] + (1.0 - beta2) * grads[i] * grads[i];
            // Bias-corrected estimates
            let m_hat = m[i] / bc1;
            let v_hat = v[i] / bc2;
            // Parameter update
            params[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sgd_no_momentum() {
        let kind = OptimizerKind::Sgd {
            momentum: 0.0,
            nesterov: false,
        };
        let mut opt = OptimizerState::new(kind, 0.1, &[3]);
        let mut params = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, -0.5, 1.0];
        opt.tick();
        opt.step(0, &mut params, &grads);
        assert!((params[0] - 0.95).abs() < 1e-10);
        assert!((params[1] - 2.05).abs() < 1e-10);
        assert!((params[2] - 2.9).abs() < 1e-10);
    }

    #[test]
    fn sgd_with_momentum() {
        let kind = OptimizerKind::Sgd {
            momentum: 0.9,
            nesterov: false,
        };
        let mut opt = OptimizerState::new(kind, 0.01, &[2]);
        let mut params = vec![1.0, 2.0];
        let grads = vec![1.0, -1.0];
        opt.tick();
        opt.step(0, &mut params, &grads);
        // velocity = 0.9*0 + 1.0 = 1.0, param = 1.0 - 0.01*1.0 = 0.99
        assert!((params[0] - 0.99).abs() < 1e-10);
        assert!((params[1] - 2.01).abs() < 1e-10);
    }

    #[test]
    fn adam_basic() {
        let kind = OptimizerKind::default(); // Adam
        let mut opt = OptimizerState::new(kind, 0.001, &[2]);
        let mut params = vec![1.0, 2.0];
        let grads = vec![0.5, -0.5];
        opt.tick();
        opt.step(0, &mut params, &grads);
        // After one step, params should have moved toward zero gradient
        assert!(params[0] < 1.0);
        assert!(params[1] > 2.0);
    }

    #[test]
    fn adam_converges_toward_minimum() {
        // Minimize f(x) = x^2, gradient = 2x
        let kind = OptimizerKind::default();
        let mut opt = OptimizerState::new(kind, 0.1, &[1]);
        let mut params = vec![5.0];

        for _ in 0..500 {
            let grads = vec![2.0 * params[0]];
            opt.tick();
            opt.step(0, &mut params, &grads);
        }
        assert!(
            params[0].abs() < 0.1,
            "should converge near 0, got {}",
            params[0]
        );
    }

    #[test]
    fn multiple_groups() {
        let kind = OptimizerKind::default();
        let mut opt = OptimizerState::new(kind, 0.001, &[3, 2]);
        let mut p1 = vec![1.0, 2.0, 3.0];
        let mut p2 = vec![4.0, 5.0];
        let g1 = vec![0.1, 0.2, 0.3];
        let g2 = vec![0.4, 0.5];
        opt.tick();
        opt.step(0, &mut p1, &g1);
        opt.step(1, &mut p2, &g2);
        // Just verify no panic and params changed
        assert!(p1[0] < 1.0);
        assert!(p2[0] < 4.0);
    }
}