#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum OptimizerKind {
    /// Stochastic gradient descent with optional (Nesterov) momentum.
    Sgd {
        /// Momentum coefficient in `[0, 1)`; `0.0` disables momentum.
        momentum: f64,
        /// Apply the Nesterov accelerated-gradient variant of momentum.
        nesterov: bool,
    },
    /// Adam (adaptive moment estimation) with bias-corrected moments.
    Adam {
        /// Exponential decay rate for the first-moment (mean) estimate.
        beta1: f64,
        /// Exponential decay rate for the second-moment (uncentered variance) estimate.
        beta2: f64,
        /// Small constant added to the denominator for numerical stability.
        epsilon: f64,
    },
}

impl Default for OptimizerKind {
    fn default() -> Self {
        Self::Adam {
            beta1: crate::constants::ADAM_BETA1,
            beta2: crate::constants::ADAM_BETA2,
            epsilon: crate::constants::ADAM_EPSILON,
        }
    }
}

impl OptimizerKind {
    /// SGD preset: the crate's default momentum with Nesterov acceleration enabled.
    pub fn sgd() -> Self {
        Self::Sgd {
            momentum: crate::constants::SGD_MOMENTUM,
            nesterov: true,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LearningRateSchedule {
    /// Keep the learning rate fixed at its initial value.
    #[default]
    Constant,
    /// Multiply the learning rate by `factor` whenever the loss has not
    /// improved for `patience` consecutive epochs.
    Adaptive {
        /// Multiplier applied on plateau; should be in `(0, 1)`.
        factor: f64,
        /// Number of non-improving epochs to tolerate before decaying.
        patience: usize,
    },
    /// Inverse scaling: `lr = initial_lr / epoch^power`.
    InvScaling {
        /// Exponent controlling how quickly the rate decays.
        power: f64,
    },
}

impl LearningRateSchedule {
    /// Adaptive preset: multiply the rate by 0.2 after 10 stagnant epochs.
    pub fn adaptive() -> Self {
        Self::Adaptive {
            factor: 0.2,
            patience: 10,
        }
    }
}

/// Internal optimizer state: per-group update buffers plus learning-rate
/// bookkeeping for the configured schedule.
pub(crate) struct OptimizerState {
    kind: OptimizerKind,
    lr: f64,
    initial_lr: f64,
    t: u64,
    /// Momentum buffers, one per parameter group (used by SGD).
    velocity: Vec<Vec<f64>>,
    /// First-moment estimates, one per parameter group (used by Adam).
    m: Vec<Vec<f64>>,
    /// Second-moment estimates, one per parameter group (used by Adam).
    v: Vec<Vec<f64>>,
    schedule: LearningRateSchedule,
    best_loss: f64,
    plateau_count: usize,
    epoch_count: usize,
}

impl OptimizerState {
    /// Creates state with a constant learning rate.
    pub fn new(kind: OptimizerKind, lr: f64, group_sizes: &[usize]) -> Self {
        Self::new_with_schedule(kind, lr, group_sizes, LearningRateSchedule::Constant)
    }

    pub fn new_with_schedule(
        kind: OptimizerKind,
        lr: f64,
        group_sizes: &[usize],
        schedule: LearningRateSchedule,
    ) -> Self {
        let n = group_sizes.len();
        let zeros =
            |sizes: &[usize]| -> Vec<Vec<f64>> { sizes.iter().map(|&s| vec![0.0; s]).collect() };

        Self {
            kind,
            lr,
            initial_lr: lr,
            t: 0,
            velocity: zeros(group_sizes),
            // Adam's moment buffers are only allocated when they will be used.
            m: if matches!(kind, OptimizerKind::Adam { .. }) {
                zeros(group_sizes)
            } else {
                Vec::with_capacity(n)
            },
            v: if matches!(kind, OptimizerKind::Adam { .. }) {
                zeros(group_sizes)
            } else {
                Vec::with_capacity(n)
            },
            schedule,
            best_loss: f64::INFINITY,
            plateau_count: 0,
            epoch_count: 0,
        }
    }

    /// Applies one update to parameter group `idx` using `grads`.
    pub fn step(&mut self, idx: usize, params: &mut [f64], grads: &[f64]) {
        debug_assert_eq!(params.len(), grads.len());
        debug_assert!(idx < self.velocity.len());

        match self.kind {
            OptimizerKind::Sgd { momentum, nesterov } => {
                self.step_sgd(idx, params, grads, momentum, nesterov);
            }
            OptimizerKind::Adam {
                beta1,
                beta2,
                epsilon,
            } => {
                self.step_adam(idx, params, grads, beta1, beta2, epsilon);
            }
        }
    }

    /// Advances the optimizer timestep; call once per update so Adam's bias
    /// correction sees the right `t`.
    pub fn tick(&mut self) {
        self.t += 1;
    }

    /// The learning rate currently in effect.
    pub fn current_lr(&self) -> f64 {
        self.lr
    }

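    /// Updates the learning rate once per epoch according to the schedule:
    /// `Adaptive` decays `lr` by `factor` after `patience` epochs without a
    /// loss improvement; `InvScaling` sets `lr = initial_lr / epoch^power`.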
    pub fn adjust_lr(&mut self, epoch_loss: f64) {
        self.epoch_count += 1;

        match self.schedule {
            LearningRateSchedule::Constant => {}
            LearningRateSchedule::Adaptive { factor, patience } => {
                if epoch_loss < self.best_loss - 1e-10 {
                    self.best_loss = epoch_loss;
                    self.plateau_count = 0;
                } else {
                    self.plateau_count += 1;
                    if self.plateau_count >= patience {
                        self.lr *= factor;
                        self.plateau_count = 0;
                        // Reset the baseline so the next decay needs a fresh plateau.
                        self.best_loss = epoch_loss;
                    }
                }
            }
            LearningRateSchedule::InvScaling { power } => {
                self.lr = self.initial_lr / (self.epoch_count as f64).powf(power);
            }
        }
    }

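    /// SGD update. With `momentum == 0` this is plain `p -= lr * g`; otherwise
    /// the velocity is `v = momentum * v + g`, and the step uses `v` directly
    /// (classical) or `g + momentum * v` (Nesterov, PyTorch-style).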
    fn step_sgd(
        &mut self,
        idx: usize,
        params: &mut [f64],
        grads: &[f64],
        momentum: f64,
        nesterov: bool,
    ) {
        let vel = &mut self.velocity[idx];
        let lr = self.lr;

        if momentum == 0.0 {
            for (p, g) in params.iter_mut().zip(grads.iter()) {
                *p -= lr * g;
            }
        } else if nesterov {
            for i in 0..params.len() {
                vel[i] = momentum * vel[i] + grads[i];
                params[i] -= lr * (grads[i] + momentum * vel[i]);
            }
        } else {
            for i in 0..params.len() {
                vel[i] = momentum * vel[i] + grads[i];
                params[i] -= lr * vel[i];
            }
        }
    }

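    /// Adam update with bias correction:
    /// `m = beta1 * m + (1 - beta1) * g`, `v = beta2 * v + (1 - beta2) * g^2`,
    /// `m_hat = m / (1 - beta1^t)`, `v_hat = v / (1 - beta2^t)`,
    /// then `p -= lr * m_hat / (sqrt(v_hat) + epsilon)`.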
    fn step_adam(
        &mut self,
        idx: usize,
        params: &mut [f64],
        grads: &[f64],
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    ) {
        let lr = self.lr;
        // Guard against `step` being called before the first `tick()`;
        // `t = 0` would zero the bias-correction denominators below.
        let t = self.t.max(1) as f64;
        let m = &mut self.m[idx];
        let v = &mut self.v[idx];

        let bc1 = 1.0 - beta1.powf(t);
        let bc2 = 1.0 - beta2.powf(t);

        for i in 0..params.len() {
            m[i] = beta1 * m[i] + (1.0 - beta1) * grads[i];
            v[i] = beta2 * v[i] + (1.0 - beta2) * grads[i] * grads[i];
            let m_hat = m[i] / bc1;
            let v_hat = v[i] / bc2;
            params[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sgd_no_momentum() {
        let kind = OptimizerKind::Sgd {
            momentum: 0.0,
            nesterov: false,
        };
        let mut opt = OptimizerState::new(kind, 0.1, &[3]);
        let mut params = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, -0.5, 1.0];
        opt.tick();
        opt.step(0, &mut params, &grads);
        assert!((params[0] - 0.95).abs() < 1e-10);
        assert!((params[1] - 2.05).abs() < 1e-10);
        assert!((params[2] - 2.9).abs() < 1e-10);
    }

    #[test]
    fn sgd_with_momentum() {
        let kind = OptimizerKind::Sgd {
            momentum: 0.9,
            nesterov: false,
        };
        let mut opt = OptimizerState::new(kind, 0.01, &[2]);
        let mut params = vec![1.0, 2.0];
        let grads = vec![1.0, -1.0];
        opt.tick();
        opt.step(0, &mut params, &grads);
        assert!((params[0] - 0.99).abs() < 1e-10);
        assert!((params[1] - 2.01).abs() < 1e-10);
    }
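
    // A sketch of Nesterov coverage (not in the original suite): with
    // v = 0.9 * 0 + 1 = 1, the first step is lr * (g + momentum * v)
    // = 0.01 * 1.9, so 1.0 -> 0.981.
    #[test]
    fn sgd_nesterov_first_step() {
        let kind = OptimizerKind::Sgd {
            momentum: 0.9,
            nesterov: true,
        };
        let mut opt = OptimizerState::new(kind, 0.01, &[1]);
        let mut params = vec![1.0];
        let grads = vec![1.0];
        opt.tick();
        opt.step(0, &mut params, &grads);
        assert!((params[0] - 0.981).abs() < 1e-10);
    }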

    #[test]
    fn adam_basic() {
        let kind = OptimizerKind::default();
        let mut opt = OptimizerState::new(kind, 0.001, &[2]);
        let mut params = vec![1.0, 2.0];
        let grads = vec![0.5, -0.5];
        opt.tick();
        opt.step(0, &mut params, &grads);
        assert!(params[0] < 1.0);
        assert!(params[1] > 2.0);
    }
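
    // A sketch (not in the original suite) of Adam's scale invariance on the
    // first step: at t = 1, m_hat = g and v_hat = g^2 exactly, so the update
    // is lr * g / (|g| + epsilon) ~= lr * sign(g) whenever epsilon << |g|.
    // Assumes the crate's ADAM_EPSILON constant is small (e.g. <= 1e-4).
    #[test]
    fn adam_first_step_magnitude_is_lr() {
        let kind = OptimizerKind::default();
        let mut opt = OptimizerState::new(kind, 0.001, &[1]);
        let mut params = vec![1.0];
        let grads = vec![0.5];
        opt.tick();
        opt.step(0, &mut params, &grads);
        assert!((params[0] - 0.999).abs() < 1e-5);
    }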

    #[test]
    fn adam_converges_toward_minimum() {
        let kind = OptimizerKind::default();
        let mut opt = OptimizerState::new(kind, 0.1, &[1]);
        let mut params = vec![5.0];

        // Minimize f(x) = x^2 by following its gradient 2x.
        for _ in 0..500 {
            let grads = vec![2.0 * params[0]];
            opt.tick();
            opt.step(0, &mut params, &grads);
        }
        assert!(
            params[0].abs() < 0.1,
            "should converge near 0, got {}",
            params[0]
        );
    }

    #[test]
    fn multiple_groups() {
        let kind = OptimizerKind::default();
        let mut opt = OptimizerState::new(kind, 0.001, &[3, 2]);
        let mut p1 = vec![1.0, 2.0, 3.0];
        let mut p2 = vec![4.0, 5.0];
        let g1 = vec![0.1, 0.2, 0.3];
        let g2 = vec![0.4, 0.5];
        opt.tick();
        opt.step(0, &mut p1, &g1);
        opt.step(1, &mut p2, &g2);
        assert!(p1[0] < 1.0);
        assert!(p2[0] < 4.0);
    }
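
    // Sketches (not in the original suite) exercising the LR schedules through
    // `new_with_schedule` and `adjust_lr`, using the semantics implemented above.
    #[test]
    fn adaptive_schedule_decays_on_plateau() {
        let mut opt = OptimizerState::new_with_schedule(
            OptimizerKind::default(),
            0.1,
            &[1],
            LearningRateSchedule::adaptive(),
        );
        // The first epoch sets the baseline loss; ten stagnant epochs then
        // trigger one decay by the preset factor of 0.2.
        opt.adjust_lr(1.0);
        for _ in 0..10 {
            opt.adjust_lr(1.0);
        }
        assert!((opt.current_lr() - 0.02).abs() < 1e-12);
    }

    #[test]
    fn inv_scaling_schedule() {
        let mut opt = OptimizerState::new_with_schedule(
            OptimizerKind::default(),
            0.1,
            &[1],
            LearningRateSchedule::InvScaling { power: 0.5 },
        );
        for _ in 0..4 {
            opt.adjust_lr(1.0);
        }
        // After four epochs: lr = 0.1 / 4^0.5 = 0.05.
        assert!((opt.current_lr() - 0.05).abs() < 1e-12);
    }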
}