ruvector_attention/training/optimizer.rs
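
/// Common interface shared by the optimizers in this module.
///
/// `step` applies one in-place update to `params` from `gradients`,
/// `reset` clears any accumulated state (momentum buffers, moment
/// estimates, step counters), and the learning rate can be read or
/// overridden between steps, e.g. by a scheduler.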
pub trait Optimizer: Send + Sync {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]);

    fn reset(&mut self);

    fn learning_rate(&self) -> f32;

    fn set_learning_rate(&mut self, lr: f32);
}

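/// Stochastic gradient descent with optional momentum, Nesterov
/// momentum, and L2 weight decay (the decay term is added to the gradient).
///
/// Defaults from `new`: momentum 0.0, weight decay 0.0, Nesterov off;
/// the builder methods below override them.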
pub struct SGD {
    lr: f32,
    momentum: f32,
    weight_decay: f32,
    velocity: Vec<f32>,
    nesterov: bool,
}

impl SGD {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            velocity: vec![0.0; dim],
            nesterov: false,
        }
    }

    pub fn with_momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn with_nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }
}

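// Per-parameter update, matching the fields above:
//   g     <- grad + weight_decay * param        (if weight_decay > 0)
//   v     <- momentum * v + g
//   param <- param - lr * v                     (standard momentum)
//   param <- param - lr * (g + momentum * v)    (Nesterov variant)
// The velocity buffer is lazily resized if the parameter count changes.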
impl Optimizer for SGD {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.velocity.len() != params.len() {
            self.velocity = vec![0.0; params.len()];
        }

        for i in 0..params.len() {
            let mut g = gradients[i];

            if self.weight_decay > 0.0 {
                g += self.weight_decay * params[i];
            }

            self.velocity[i] = self.momentum * self.velocity[i] + g;

            if self.nesterov {
                params[i] -= self.lr * (g + self.momentum * self.velocity[i]);
            } else {
                params[i] -= self.lr * self.velocity[i];
            }
        }
    }

    fn reset(&mut self) {
        self.velocity.fill(0.0);
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}

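/// Adam with bias-corrected first and second moment estimates
/// (defaults: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8).
///
/// Note that `weight_decay` here is applied directly to the parameter
/// inside `step`, i.e. decoupled from the adaptive update, rather than
/// being added to the gradient before the moment updates.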
pub struct Adam {
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
    m: Vec<f32>,
    v: Vec<f32>,
    t: usize,
}

impl Adam {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            m: vec![0.0; dim],
            v: vec![0.0; dim],
            t: 0,
        }
    }

    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    pub fn with_epsilon(mut self, eps: f32) -> Self {
        self.epsilon = eps;
        self
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }
}

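// Update at step t (per parameter):
//   m <- beta1 * m + (1 - beta1) * g
//   v <- beta2 * v + (1 - beta2) * g^2
//   m_hat = m / (1 - beta1^t),  v_hat = v / (1 - beta2^t)
//   param <- param - lr * (m_hat / (sqrt(v_hat) + epsilon) + weight_decay * param)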
impl Optimizer for Adam {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.m.len() != params.len() {
            self.m = vec![0.0; params.len()];
            self.v = vec![0.0; params.len()];
        }

        self.t += 1;
        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for i in 0..params.len() {
            let g = gradients[i];

            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;

            let m_hat = self.m[i] / bias_correction1;
            let v_hat = self.v[i] / bias_correction2;

            let update = m_hat / (v_hat.sqrt() + self.epsilon);
            params[i] -= self.lr * (update + self.weight_decay * params[i]);
        }
    }

    fn reset(&mut self) {
        self.m.fill(0.0);
        self.v.fill(0.0);
        self.t = 0;
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}

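/// AdamW: Adam with decoupled weight decay (default 0.01).
///
/// The decay multiplies the parameter by `1 - lr * weight_decay` before
/// the bias-corrected Adam update is subtracted; the inner Adam's own
/// `weight_decay` field is left at 0.0.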
pub struct AdamW {
    inner: Adam,
    weight_decay: f32,
}

impl AdamW {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            inner: Adam::new(dim, lr),
            weight_decay: 0.01,
        }
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.inner = self.inner.with_betas(beta1, beta2);
        self
    }
}

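// Same moment updates and bias correction as `Adam`, but the weight
// decay is applied as a separate multiplicative shrink of the parameter
// (param *= 1 - lr * wd) before the adaptive step is subtracted.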
impl Optimizer for AdamW {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.inner.m.len() != params.len() {
            self.inner.m = vec![0.0; params.len()];
            self.inner.v = vec![0.0; params.len()];
        }

        self.inner.t += 1;
        let bias_correction1 = 1.0 - self.inner.beta1.powi(self.inner.t as i32);
        let bias_correction2 = 1.0 - self.inner.beta2.powi(self.inner.t as i32);

        for i in 0..params.len() {
            let g = gradients[i];

            self.inner.m[i] = self.inner.beta1 * self.inner.m[i] + (1.0 - self.inner.beta1) * g;
            self.inner.v[i] =
                self.inner.beta2 * self.inner.v[i] + (1.0 - self.inner.beta2) * g * g;

            let m_hat = self.inner.m[i] / bias_correction1;
            let v_hat = self.inner.v[i] / bias_correction2;

            params[i] *= 1.0 - self.inner.lr * self.weight_decay;

            params[i] -= self.inner.lr * m_hat / (v_hat.sqrt() + self.inner.epsilon);
        }
    }

    fn reset(&mut self) {
        self.inner.reset();
    }

    fn learning_rate(&self) -> f32 {
        self.inner.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.inner.lr = lr;
    }
}

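/// Learning-rate schedule: linear warmup from `initial_lr / warmup_steps`
/// up to `initial_lr` over `warmup_steps`, followed by cosine decay
/// towards `min_lr` over `decay_steps`, after which the rate stays at
/// `min_lr`.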
pub struct LearningRateScheduler {
    initial_lr: f32,
    warmup_steps: usize,
    decay_steps: usize,
    min_lr: f32,
    current_step: usize,
}

impl LearningRateScheduler {
    pub fn new(initial_lr: f32) -> Self {
        Self {
            initial_lr,
            warmup_steps: 0,
            decay_steps: 100000,
            min_lr: 1e-7,
            current_step: 0,
        }
    }

    pub fn with_warmup(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    pub fn with_decay(mut self, steps: usize) -> Self {
        self.decay_steps = steps;
        self
    }

    pub fn with_min_lr(mut self, min_lr: f32) -> Self {
        self.min_lr = min_lr;
        self
    }

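    /// Returns the learning rate for the current step, then advances
    /// the internal step counter.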
    pub fn step(&mut self) -> f32 {
        let lr = self.get_lr();
        self.current_step += 1;
        lr
    }

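    /// Learning rate at the current step without advancing the counter:
    /// linear warmup while `current_step < warmup_steps`, cosine decay
    /// afterwards, bounded below by `min_lr`.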
    pub fn get_lr(&self) -> f32 {
        if self.current_step < self.warmup_steps {
            self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32
        } else {
            let progress =
                (self.current_step - self.warmup_steps) as f32 / self.decay_steps as f32;
            let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
            self.min_lr + (self.initial_lr - self.min_lr) * decay
        }
    }

    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd() {
        let mut opt = SGD::new(4, 0.1);
        let mut params = vec![1.0, 2.0, 3.0, 4.0];
        let gradients = vec![0.1, 0.2, 0.3, 0.4];

        opt.step(&mut params, &gradients);

        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    #[test]
    fn test_sgd_momentum() {
        let mut opt = SGD::new(4, 0.1).with_momentum(0.9);
        let mut params = vec![1.0; 4];
        let gradients = vec![1.0; 4];

        for _ in 0..5 {
            opt.step(&mut params, &gradients);
        }

        assert!(params[0] < 0.0);
    }

    #[test]
    fn test_adam() {
        let mut opt = Adam::new(64, 0.001);
        let mut params = vec![0.5; 64];
        let gradients = vec![0.1; 64];

        for _ in 0..100 {
            opt.step(&mut params, &gradients);
        }

        assert!(params[0] < 0.5);
    }

    #[test]
    fn test_adamw() {
        let mut opt = AdamW::new(32, 0.001).with_weight_decay(0.01);
        let mut params = vec![1.0; 32];
        let gradients = vec![0.0; 32];

        for _ in 0..100 {
            opt.step(&mut params, &gradients);
        }

        assert!(params[0] < 1.0);
    }

    #[test]
    fn test_lr_scheduler_warmup() {
        let mut scheduler = LearningRateScheduler::new(0.001).with_warmup(100);

        let lr_start = scheduler.step();
        assert!(lr_start < 0.001);

        for _ in 0..99 {
            scheduler.step();
        }

        let lr_end_warmup = scheduler.get_lr();
        assert!((lr_end_warmup - 0.001).abs() < 1e-5);
    }

    #[test]
    fn test_lr_scheduler_decay() {
        let mut scheduler = LearningRateScheduler::new(0.001)
            .with_warmup(0)
            .with_decay(100)
            .with_min_lr(0.0001);

        let lr_start = scheduler.step();
        assert!((lr_start - 0.001).abs() < 1e-5);

        for _ in 0..100 {
            scheduler.step();
        }

        let lr_end = scheduler.get_lr();
        assert!((lr_end - 0.0001).abs() < 1e-5);
    }
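
    // A minimal sketch of how the scheduler can drive an optimizer's
    // learning rate in a training loop; the test name and constants are
    // chosen for illustration only.
    #[test]
    fn test_scheduler_drives_sgd() {
        let mut scheduler = LearningRateScheduler::new(0.01)
            .with_warmup(10)
            .with_decay(90);
        let mut opt = SGD::new(4, 0.01);
        let mut params = vec![1.0; 4];
        let gradients = vec![0.5; 4];

        for _ in 0..100 {
            // Query the schedule first, then take a step at that rate.
            let lr = scheduler.step();
            opt.set_learning_rate(lr);
            opt.step(&mut params, &gradients);
        }

        // Every scheduled rate is positive, so a constant positive
        // gradient must have moved the parameters downhill.
        assert!(params[0] < 1.0);
    }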
}