ruvector_attention/training/optimizer.rs

pub trait Optimizer: Send + Sync {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]);

    fn reset(&mut self);

    fn learning_rate(&self) -> f32;

    fn set_learning_rate(&mut self, lr: f32);
}

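// Usage sketch (illustrative only, not part of the original API surface):
// a training loop is expected to hold any optimizer behind `dyn Optimizer`
// and call `step` once per batch with matching parameter/gradient slices.
//
//     let mut opt: Box<dyn Optimizer> = Box::new(SGD::new(dim, 0.01));
//     opt.step(&mut params, &grads);
//     opt.set_learning_rate(0.005);
//
// `dim`, `params`, and `grads` above are placeholders, not items in this module.
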
pub struct SGD {
    lr: f32,
    momentum: f32,
    weight_decay: f32,
    velocity: Vec<f32>,
    nesterov: bool,
}

impl SGD {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            velocity: vec![0.0; dim],
            nesterov: false,
        }
    }

    pub fn with_momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn with_nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.velocity.len() != params.len() {
            self.velocity = vec![0.0; params.len()];
        }

        for i in 0..params.len() {
            let mut g = gradients[i];

            if self.weight_decay > 0.0 {
                g += self.weight_decay * params[i];
            }

            self.velocity[i] = self.momentum * self.velocity[i] + g;

            if self.nesterov {
                params[i] -= self.lr * (g + self.momentum * self.velocity[i]);
            } else {
                params[i] -= self.lr * self.velocity[i];
            }
        }
    }

    fn reset(&mut self) {
        self.velocity.fill(0.0);
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}

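// Update rule implemented above, written out for reference (symbols match the
// fields: learning rate lr, momentum mu, weight decay wd):
//
//     g_t = grad + wd * param                  (L2 regularization)
//     v_t = mu * v_{t-1} + g_t                 (velocity)
//     param -= lr * v_t                        (standard momentum)
//     param -= lr * (g_t + mu * v_t)           (Nesterov variant)
//
// The Nesterov form is the common deep-learning formulation that applies the
// look-ahead correction to the gradient instead of re-evaluating it.
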
pub struct Adam {
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
    m: Vec<f32>,
    v: Vec<f32>,
    t: usize,
}

impl Adam {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            m: vec![0.0; dim],
            v: vec![0.0; dim],
            t: 0,
        }
    }

    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    pub fn with_epsilon(mut self, eps: f32) -> Self {
        self.epsilon = eps;
        self
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }
}

impl Optimizer for Adam {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.m.len() != params.len() {
            self.m = vec![0.0; params.len()];
            self.v = vec![0.0; params.len()];
        }

        self.t += 1;
        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for i in 0..params.len() {
            let mut g = gradients[i];

            // Classic (coupled) L2 regularization: fold the decay term into the
            // gradient before the moment estimates. Decoupled decay is provided
            // by `AdamW` below.
            if self.weight_decay > 0.0 {
                g += self.weight_decay * params[i];
            }

            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;

            let m_hat = self.m[i] / bias_correction1;
            let v_hat = self.v[i] / bias_correction2;

            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.epsilon);
        }
    }

    fn reset(&mut self) {
        self.m.fill(0.0);
        self.v.fill(0.0);
        self.t = 0;
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}

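// For reference, the Adam step as implemented above (t is the step count,
// b1/b2 the exponential decay rates, eps the numerical-stability term):
//
//     m_t = b1 * m_{t-1} + (1 - b1) * g_t
//     v_t = b2 * v_{t-1} + (1 - b2) * g_t^2
//     m_hat = m_t / (1 - b1^t),  v_hat = v_t / (1 - b2^t)
//     param -= lr * m_hat / (sqrt(v_hat) + eps)
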
pub struct AdamW {
    inner: Adam,
    weight_decay: f32,
}

impl AdamW {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            inner: Adam::new(dim, lr),
            weight_decay: 0.01,
        }
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.inner = self.inner.with_betas(beta1, beta2);
        self
    }
}

impl Optimizer for AdamW {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.inner.m.len() != params.len() {
            self.inner.m = vec![0.0; params.len()];
            self.inner.v = vec![0.0; params.len()];
        }

        self.inner.t += 1;
        let bias_correction1 = 1.0 - self.inner.beta1.powi(self.inner.t as i32);
        let bias_correction2 = 1.0 - self.inner.beta2.powi(self.inner.t as i32);

        for i in 0..params.len() {
            let g = gradients[i];

            self.inner.m[i] = self.inner.beta1 * self.inner.m[i] + (1.0 - self.inner.beta1) * g;
            self.inner.v[i] = self.inner.beta2 * self.inner.v[i] + (1.0 - self.inner.beta2) * g * g;

            let m_hat = self.inner.m[i] / bias_correction1;
            let v_hat = self.inner.v[i] / bias_correction2;

            // Decoupled weight decay: shrink the parameter directly, independent
            // of the adaptive gradient scaling.
            params[i] *= 1.0 - self.inner.lr * self.weight_decay;

            params[i] -= self.inner.lr * m_hat / (v_hat.sqrt() + self.inner.epsilon);
        }
    }

    fn reset(&mut self) {
        self.inner.reset();
    }

    fn learning_rate(&self) -> f32 {
        self.inner.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.inner.lr = lr;
    }
}

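// AdamW differs from `Adam` only in where weight decay is applied: Adam folds
// wd * param into the gradient (so the decay is rescaled by the adaptive
// denominator), while AdamW multiplies each parameter by (1 - lr * wd) before
// the Adam update, as in Loshchilov & Hutter's decoupled weight decay scheme.
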
pub struct LearningRateScheduler {
    initial_lr: f32,
    warmup_steps: usize,
    decay_steps: usize,
    min_lr: f32,
    current_step: usize,
}

impl LearningRateScheduler {
    pub fn new(initial_lr: f32) -> Self {
        Self {
            initial_lr,
            warmup_steps: 0,
            decay_steps: 100000,
            min_lr: 1e-7,
            current_step: 0,
        }
    }

    pub fn with_warmup(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    pub fn with_decay(mut self, steps: usize) -> Self {
        self.decay_steps = steps;
        self
    }

    pub fn with_min_lr(mut self, min_lr: f32) -> Self {
        self.min_lr = min_lr;
        self
    }

    pub fn step(&mut self) -> f32 {
        let lr = self.get_lr();
        self.current_step += 1;
        lr
    }

    pub fn get_lr(&self) -> f32 {
        if self.current_step < self.warmup_steps {
            self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32
        } else {
            let progress = (self.current_step - self.warmup_steps) as f32 / self.decay_steps as f32;
            let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
            self.min_lr + (self.initial_lr - self.min_lr) * decay
        }
    }

    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}

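// Schedule implemented above: linear warmup to `initial_lr` over `warmup_steps`,
// then cosine decay towards `min_lr` over `decay_steps`. A typical wiring with
// an optimizer (illustrative only; `dim`, `params`, `grads`, and `num_steps`
// are placeholders, not items in this module):
//
//     let mut opt = Adam::new(dim, 1e-3);
//     let mut sched = LearningRateScheduler::new(1e-3).with_warmup(1_000);
//     for _ in 0..num_steps {
//         opt.set_learning_rate(sched.step());
//         opt.step(&mut params, &grads);
//     }
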
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd() {
        let mut opt = SGD::new(4, 0.1);
        let mut params = vec![1.0, 2.0, 3.0, 4.0];
        let gradients = vec![0.1, 0.2, 0.3, 0.4];

        opt.step(&mut params, &gradients);

        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    #[test]
    fn test_sgd_momentum() {
        let mut opt = SGD::new(4, 0.1).with_momentum(0.9);
        let mut params = vec![1.0; 4];
        let gradients = vec![1.0; 4];

        for _ in 0..5 {
            opt.step(&mut params, &gradients);
        }

        assert!(params[0] < 0.0);
    }

    #[test]
    fn test_adam() {
        let mut opt = Adam::new(64, 0.001);
        let mut params = vec![0.5; 64];
        let gradients = vec![0.1; 64];

        for _ in 0..100 {
            opt.step(&mut params, &gradients);
        }

        assert!(params[0] < 0.5);
    }

    #[test]
    fn test_adamw() {
        let mut opt = AdamW::new(32, 0.001).with_weight_decay(0.01);
        let mut params = vec![1.0; 32];
        let gradients = vec![0.0; 32];

        for _ in 0..100 {
            opt.step(&mut params, &gradients);
        }

        assert!(params[0] < 1.0);
    }

    #[test]
    fn test_lr_scheduler_warmup() {
        let mut scheduler = LearningRateScheduler::new(0.001).with_warmup(100);

        let lr_start = scheduler.step();
        assert!(lr_start < 0.001);

        for _ in 0..99 {
            scheduler.step();
        }

        let lr_end_warmup = scheduler.get_lr();
        assert!((lr_end_warmup - 0.001).abs() < 1e-5);
    }

    #[test]
    fn test_lr_scheduler_decay() {
        let mut scheduler = LearningRateScheduler::new(0.001)
            .with_warmup(0)
            .with_decay(100)
            .with_min_lr(0.0001);

        let lr_start = scheduler.step();
        assert!((lr_start - 0.001).abs() < 1e-5);

        for _ in 0..100 {
            scheduler.step();
        }

        let lr_end = scheduler.get_lr();
        assert!((lr_end - 0.0001).abs() < 1e-5);
    }

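    // Illustrative addition (not in the original suite): under a constant
    // gradient, Nesterov momentum should move parameters at least as far as
    // plain momentum because of the extra look-ahead term.
    #[test]
    fn test_sgd_nesterov() {
        let mut plain = SGD::new(2, 0.1).with_momentum(0.9);
        let mut nesterov = SGD::new(2, 0.1).with_momentum(0.9).with_nesterov(true);
        let mut p_plain = vec![1.0_f32; 2];
        let mut p_nesterov = vec![1.0_f32; 2];
        let gradients = vec![1.0_f32; 2];

        for _ in 0..5 {
            plain.step(&mut p_plain, &gradients);
            nesterov.step(&mut p_nesterov, &gradients);
        }

        assert!(p_nesterov[0] < p_plain[0]);
    }

    // Illustrative addition (not in the original suite): wiring the scheduler
    // into an optimizer through `set_learning_rate`, as sketched above.
    #[test]
    fn test_scheduler_drives_optimizer() {
        let mut opt = SGD::new(2, 0.0);
        let mut scheduler = LearningRateScheduler::new(0.01).with_warmup(10);
        let mut params = vec![1.0_f32; 2];
        let gradients = vec![1.0_f32; 2];

        for _ in 0..10 {
            opt.set_learning_rate(scheduler.step());
            opt.step(&mut params, &gradients);
        }

        // During warmup the learning rate grows from 0.001 to 0.01, so the
        // parameters must have moved, but by less than 10 full-lr steps.
        assert!(params[0] < 1.0);
        assert!(params[0] > 1.0 - 10.0 * 0.01);
    }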
}