ruvector_attention/training/optimizer.rs

//! Optimizers for attention training
//!
//! Provides standard optimizers with momentum and adaptive learning rates.

/// Optimizer trait for parameter updates
pub trait Optimizer: Send + Sync {
    /// Update parameters using gradients
    fn step(&mut self, params: &mut [f32], gradients: &[f32]);

    /// Reset optimizer state
    fn reset(&mut self);

    /// Get current learning rate
    fn learning_rate(&self) -> f32;

    /// Set learning rate
    fn set_learning_rate(&mut self, lr: f32);
}
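
// Typical usage (sketch; the training loop, parameter buffer, and gradient buffer
// are the caller's responsibility and not defined in this module):
//
//     let mut opt: Box<dyn Optimizer> = Box::new(SGD::new(params.len(), 0.01));
//     for _ in 0..num_steps {
//         // ... compute `grads` for the current batch ...
//         opt.step(&mut params, &grads);
//     }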

/// Stochastic Gradient Descent with momentum
pub struct SGD {
    lr: f32,
    momentum: f32,
    weight_decay: f32,
    velocity: Vec<f32>,
    nesterov: bool,
}

impl SGD {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            velocity: vec![0.0; dim],
            nesterov: false,
        }
    }

    pub fn with_momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn with_nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }
}
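
// Construction goes through the builders above, e.g.
// `SGD::new(dim, 0.01).with_momentum(0.9).with_nesterov(true)`,
// where `dim` is the number of parameters being optimized.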

impl Optimizer for SGD {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.velocity.len() != params.len() {
            self.velocity = vec![0.0; params.len()];
        }

        for i in 0..params.len() {
            let mut g = gradients[i];

            // Weight decay
            if self.weight_decay > 0.0 {
                g += self.weight_decay * params[i];
            }

            // Update velocity
            self.velocity[i] = self.momentum * self.velocity[i] + g;

            // Update parameters
            if self.nesterov {
                params[i] -= self.lr * (g + self.momentum * self.velocity[i]);
            } else {
                params[i] -= self.lr * self.velocity[i];
            }
        }
    }

    fn reset(&mut self) {
        self.velocity.fill(0.0);
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}

/// Adam optimizer with bias correction
pub struct Adam {
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
    m: Vec<f32>,    // First moment
    v: Vec<f32>,    // Second moment
    t: usize,       // Timestep
}
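
// Per-coordinate update performed by `step` below (standard Adam with bias correction;
// the optional weight decay term is added directly to the update, scaled by `lr`):
//     m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//     v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//     m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
//     theta_t = theta_{t-1} - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * theta_{t-1})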

impl Adam {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            m: vec![0.0; dim],
            v: vec![0.0; dim],
            t: 0,
        }
    }

    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    pub fn with_epsilon(mut self, eps: f32) -> Self {
        self.epsilon = eps;
        self
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }
}

impl Optimizer for Adam {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.m.len() != params.len() {
            self.m = vec![0.0; params.len()];
            self.v = vec![0.0; params.len()];
        }

        self.t += 1;
        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for i in 0..params.len() {
            let g = gradients[i];

            // Update moments
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;

            // Bias-corrected estimates
            let m_hat = self.m[i] / bias_correction1;
            let v_hat = self.v[i] / bias_correction2;

            // Update with optional weight decay
            let update = m_hat / (v_hat.sqrt() + self.epsilon);
            params[i] -= self.lr * (update + self.weight_decay * params[i]);
        }
    }

    fn reset(&mut self) {
        self.m.fill(0.0);
        self.v.fill(0.0);
        self.t = 0;
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}

/// AdamW optimizer (decoupled weight decay)
pub struct AdamW {
    inner: Adam,
    weight_decay: f32,
}
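
// Update performed by `step` below: the weight decay multiplies the parameter directly
// (decoupled from the gradient moments), followed by the usual Adam step:
//     theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)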

impl AdamW {
    pub fn new(dim: usize, lr: f32) -> Self {
        Self {
            inner: Adam::new(dim, lr),
            weight_decay: 0.01,
        }
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.inner = self.inner.with_betas(beta1, beta2);
        self
    }
}

impl Optimizer for AdamW {
    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.inner.m.len() != params.len() {
            self.inner.m = vec![0.0; params.len()];
            self.inner.v = vec![0.0; params.len()];
        }

        self.inner.t += 1;
        let bias_correction1 = 1.0 - self.inner.beta1.powi(self.inner.t as i32);
        let bias_correction2 = 1.0 - self.inner.beta2.powi(self.inner.t as i32);

        for i in 0..params.len() {
            let g = gradients[i];

            // Update moments
            self.inner.m[i] = self.inner.beta1 * self.inner.m[i] + (1.0 - self.inner.beta1) * g;
            self.inner.v[i] =
                self.inner.beta2 * self.inner.v[i] + (1.0 - self.inner.beta2) * g * g;

            // Bias-corrected estimates
            let m_hat = self.inner.m[i] / bias_correction1;
            let v_hat = self.inner.v[i] / bias_correction2;

            // Decoupled weight decay (applied to params directly, not through gradient)
            params[i] *= 1.0 - self.inner.lr * self.weight_decay;

            // Adam update
            params[i] -= self.inner.lr * m_hat / (v_hat.sqrt() + self.inner.epsilon);
        }
    }

    fn reset(&mut self) {
        self.inner.reset();
    }

    fn learning_rate(&self) -> f32 {
        self.inner.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.inner.lr = lr;
    }
}

/// Learning rate scheduler
pub struct LearningRateScheduler {
    initial_lr: f32,
    warmup_steps: usize,
    decay_steps: usize,
    min_lr: f32,
    current_step: usize,
}
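
// Schedule implemented by `get_lr` below: linear warmup for `warmup_steps`, then
// cosine decay over `decay_steps` toward `min_lr`:
//     warmup:  lr(t) = initial_lr * (t + 1) / warmup_steps
//     decay:   lr(t) = min_lr + (initial_lr - min_lr) * 0.5 * (1 + cos(pi * p)),
//              where p = min((t - warmup_steps) / decay_steps, 1)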

impl LearningRateScheduler {
    pub fn new(initial_lr: f32) -> Self {
        Self {
            initial_lr,
            warmup_steps: 0,
            decay_steps: 100000,
            min_lr: 1e-7,
            current_step: 0,
        }
    }

    pub fn with_warmup(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    pub fn with_decay(mut self, steps: usize) -> Self {
        self.decay_steps = steps;
        self
    }

    pub fn with_min_lr(mut self, min_lr: f32) -> Self {
        self.min_lr = min_lr;
        self
    }

    /// Get current learning rate and advance step
    pub fn step(&mut self) -> f32 {
        let lr = self.get_lr();
        self.current_step += 1;
        lr
    }

    /// Get learning rate without advancing
    pub fn get_lr(&self) -> f32 {
        if self.current_step < self.warmup_steps {
            // Linear warmup
            self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32
        } else {
            // Cosine decay
            let progress =
                (self.current_step - self.warmup_steps) as f32 / self.decay_steps as f32;
            let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
            self.min_lr + (self.initial_lr - self.min_lr) * decay
        }
    }

    /// Reset scheduler
    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd() {
        let mut opt = SGD::new(4, 0.1);
        let mut params = vec![1.0, 2.0, 3.0, 4.0];
        let gradients = vec![0.1, 0.2, 0.3, 0.4];

        opt.step(&mut params, &gradients);

        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    #[test]
    fn test_sgd_momentum() {
        let mut opt = SGD::new(4, 0.1).with_momentum(0.9);
        let mut params = vec![1.0; 4];
        let gradients = vec![1.0; 4];

        // Multiple steps should accumulate momentum
        for _ in 0..5 {
            opt.step(&mut params, &gradients);
        }

        assert!(params[0] < 0.0);
    }

    #[test]
    fn test_adam() {
        let mut opt = Adam::new(64, 0.001);
        let mut params = vec![0.5; 64];
        let gradients = vec![0.1; 64];

        for _ in 0..100 {
            opt.step(&mut params, &gradients);
        }

        // Should have moved toward 0
        assert!(params[0] < 0.5);
    }

    #[test]
    fn test_adamw() {
        let mut opt = AdamW::new(32, 0.001).with_weight_decay(0.01);
        let mut params = vec![1.0; 32];
        let gradients = vec![0.0; 32]; // No gradient, only weight decay

        for _ in 0..100 {
            opt.step(&mut params, &gradients);
        }

        // Weight decay should shrink params
        assert!(params[0] < 1.0);
    }

    #[test]
    fn test_lr_scheduler_warmup() {
        let mut scheduler = LearningRateScheduler::new(0.001).with_warmup(100);

        let lr_start = scheduler.step();
        assert!(lr_start < 0.001); // Still warming up

        for _ in 0..99 {
            scheduler.step();
        }

        let lr_end_warmup = scheduler.get_lr();
        assert!((lr_end_warmup - 0.001).abs() < 1e-5);
    }

    #[test]
    fn test_lr_scheduler_decay() {
        let mut scheduler = LearningRateScheduler::new(0.001)
            .with_warmup(0)
            .with_decay(100)
            .with_min_lr(0.0001);

        let lr_start = scheduler.step();
        assert!((lr_start - 0.001).abs() < 1e-5);

        for _ in 0..100 {
            scheduler.step();
        }

        let lr_end = scheduler.get_lr();
        assert!((lr_end - 0.0001).abs() < 1e-5);
    }
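
    // One way to pair the scheduler with an optimizer (illustrative wiring only):
    // refresh the learning rate from the scheduler before every optimizer step.
    #[test]
    fn test_scheduler_drives_optimizer() {
        let mut scheduler = LearningRateScheduler::new(0.01).with_warmup(10).with_decay(50);
        let mut opt = SGD::new(4, scheduler.get_lr()).with_momentum(0.9);
        let mut params = vec![1.0; 4];
        let gradients = vec![0.5; 4];

        for _ in 0..60 {
            opt.set_learning_rate(scheduler.step());
            opt.step(&mut params, &gradients);
        }

        // Constant positive gradients should drive the parameters downward
        assert!(params[0] < 1.0);
    }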
}