// ruvector_attention/training/optimizer.rs

//! Optimizers for attention training
//!
//! Provides standard optimizers with momentum and adaptive learning rates.
4
/// Optimizer trait for parameter updates.
///
/// Implementations in this module keep per-parameter state (momentum or
/// moment buffers) that is lazily resized inside `step` to match the
/// length of the `params` slice.
pub trait Optimizer: Send + Sync {
    /// Perform one in-place update of `params` using `gradients`.
    ///
    /// `params` and `gradients` are indexed position-by-position, so they
    /// are expected to have the same length; implementations here panic if
    /// `gradients` is shorter than `params`.
    fn step(&mut self, params: &mut [f32], gradients: &[f32]);

    /// Reset all accumulated optimizer state (e.g. momentum / moment buffers).
    fn reset(&mut self);

    /// Get current learning rate.
    fn learning_rate(&self) -> f32;

    /// Set learning rate (e.g. from an external scheduler).
    fn set_learning_rate(&mut self, lr: f32);
}
19
/// Stochastic Gradient Descent with momentum
///
/// Plain SGD by default; momentum, Nesterov lookahead, and L2 weight decay
/// are enabled via the `with_*` builder methods.
pub struct SGD {
    lr: f32,            // step size
    momentum: f32,      // momentum coefficient (0.0 = no momentum)
    weight_decay: f32,  // L2 penalty folded into the gradient when > 0
    velocity: Vec<f32>, // per-parameter momentum buffer, resized lazily in `step`
    nesterov: bool,     // apply the Nesterov-style lookahead update
}
28
29impl SGD {
30    pub fn new(dim: usize, lr: f32) -> Self {
31        Self {
32            lr,
33            momentum: 0.0,
34            weight_decay: 0.0,
35            velocity: vec![0.0; dim],
36            nesterov: false,
37        }
38    }
39
40    pub fn with_momentum(mut self, momentum: f32) -> Self {
41        self.momentum = momentum;
42        self
43    }
44
45    pub fn with_weight_decay(mut self, wd: f32) -> Self {
46        self.weight_decay = wd;
47        self
48    }
49
50    pub fn with_nesterov(mut self, nesterov: bool) -> Self {
51        self.nesterov = nesterov;
52        self
53    }
54}
55
56impl Optimizer for SGD {
57    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
58        if self.velocity.len() != params.len() {
59            self.velocity = vec![0.0; params.len()];
60        }
61
62        for i in 0..params.len() {
63            let mut g = gradients[i];
64
65            // Weight decay
66            if self.weight_decay > 0.0 {
67                g += self.weight_decay * params[i];
68            }
69
70            // Update velocity
71            self.velocity[i] = self.momentum * self.velocity[i] + g;
72
73            // Update parameters
74            if self.nesterov {
75                params[i] -= self.lr * (g + self.momentum * self.velocity[i]);
76            } else {
77                params[i] -= self.lr * self.velocity[i];
78            }
79        }
80    }
81
82    fn reset(&mut self) {
83        self.velocity.fill(0.0);
84    }
85
86    fn learning_rate(&self) -> f32 {
87        self.lr
88    }
89
90    fn set_learning_rate(&mut self, lr: f32) {
91        self.lr = lr;
92    }
93}
94
/// Adam optimizer with bias correction
///
/// Maintains exponential moving averages of the gradient (first moment)
/// and of the squared gradient (second moment); the buffers are resized
/// lazily in `step` if the parameter count changes.
pub struct Adam {
    lr: f32,           // step size
    beta1: f32,        // decay rate of the first-moment average (default 0.9)
    beta2: f32,        // decay rate of the second-moment average (default 0.999)
    epsilon: f32,      // numerical-stability term in the denominator (default 1e-8)
    weight_decay: f32, // weight-decay coefficient used by `step` (default 0.0)
    m: Vec<f32>, // First moment
    v: Vec<f32>, // Second moment
    t: usize,    // Timestep
}
106
107impl Adam {
108    pub fn new(dim: usize, lr: f32) -> Self {
109        Self {
110            lr,
111            beta1: 0.9,
112            beta2: 0.999,
113            epsilon: 1e-8,
114            weight_decay: 0.0,
115            m: vec![0.0; dim],
116            v: vec![0.0; dim],
117            t: 0,
118        }
119    }
120
121    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
122        self.beta1 = beta1;
123        self.beta2 = beta2;
124        self
125    }
126
127    pub fn with_epsilon(mut self, eps: f32) -> Self {
128        self.epsilon = eps;
129        self
130    }
131
132    pub fn with_weight_decay(mut self, wd: f32) -> Self {
133        self.weight_decay = wd;
134        self
135    }
136}
137
138impl Optimizer for Adam {
139    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
140        if self.m.len() != params.len() {
141            self.m = vec![0.0; params.len()];
142            self.v = vec![0.0; params.len()];
143        }
144
145        self.t += 1;
146        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
147        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);
148
149        for i in 0..params.len() {
150            let g = gradients[i];
151
152            // Update moments
153            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
154            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
155
156            // Bias-corrected estimates
157            let m_hat = self.m[i] / bias_correction1;
158            let v_hat = self.v[i] / bias_correction2;
159
160            // Update with optional weight decay
161            let update = m_hat / (v_hat.sqrt() + self.epsilon);
162            params[i] -= self.lr * (update + self.weight_decay * params[i]);
163        }
164    }
165
166    fn reset(&mut self) {
167        self.m.fill(0.0);
168        self.v.fill(0.0);
169        self.t = 0;
170    }
171
172    fn learning_rate(&self) -> f32 {
173        self.lr
174    }
175
176    fn set_learning_rate(&mut self, lr: f32) {
177        self.lr = lr;
178    }
179}
180
/// AdamW optimizer (decoupled weight decay)
///
/// Wraps [`Adam`] for its moment/bias-correction state, but applies weight
/// decay directly to the parameters (a multiplicative shrink in `step`)
/// instead of routing it through the gradient.
pub struct AdamW {
    inner: Adam,       // Adam hyperparameters and moment buffers; its own weight_decay is not used by this impl
    weight_decay: f32, // decoupled decay coefficient (default 0.01)
}
186
187impl AdamW {
188    pub fn new(dim: usize, lr: f32) -> Self {
189        Self {
190            inner: Adam::new(dim, lr),
191            weight_decay: 0.01,
192        }
193    }
194
195    pub fn with_weight_decay(mut self, wd: f32) -> Self {
196        self.weight_decay = wd;
197        self
198    }
199
200    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
201        self.inner = self.inner.with_betas(beta1, beta2);
202        self
203    }
204}
205
206impl Optimizer for AdamW {
207    fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
208        if self.inner.m.len() != params.len() {
209            self.inner.m = vec![0.0; params.len()];
210            self.inner.v = vec![0.0; params.len()];
211        }
212
213        self.inner.t += 1;
214        let bias_correction1 = 1.0 - self.inner.beta1.powi(self.inner.t as i32);
215        let bias_correction2 = 1.0 - self.inner.beta2.powi(self.inner.t as i32);
216
217        for i in 0..params.len() {
218            let g = gradients[i];
219
220            // Update moments
221            self.inner.m[i] = self.inner.beta1 * self.inner.m[i] + (1.0 - self.inner.beta1) * g;
222            self.inner.v[i] = self.inner.beta2 * self.inner.v[i] + (1.0 - self.inner.beta2) * g * g;
223
224            // Bias-corrected estimates
225            let m_hat = self.inner.m[i] / bias_correction1;
226            let v_hat = self.inner.v[i] / bias_correction2;
227
228            // Decoupled weight decay (applied to params directly, not through gradient)
229            params[i] *= 1.0 - self.inner.lr * self.weight_decay;
230
231            // Adam update
232            params[i] -= self.inner.lr * m_hat / (v_hat.sqrt() + self.inner.epsilon);
233        }
234    }
235
236    fn reset(&mut self) {
237        self.inner.reset();
238    }
239
240    fn learning_rate(&self) -> f32 {
241        self.inner.lr
242    }
243
244    fn set_learning_rate(&mut self, lr: f32) {
245        self.inner.lr = lr;
246    }
247}
248
/// Learning rate scheduler
///
/// Linear warmup followed by a cosine decay from `initial_lr` down to
/// `min_lr` over `decay_steps` steps.
pub struct LearningRateScheduler {
    initial_lr: f32,     // peak learning rate (reached at the end of warmup)
    warmup_steps: usize, // number of linear-warmup steps (0 = no warmup)
    decay_steps: usize,  // length of the cosine-decay window after warmup
    min_lr: f32,         // floor the cosine decay settles at
    current_step: usize, // steps consumed so far
}

impl LearningRateScheduler {
    /// Scheduler with no warmup, a 100 000-step decay window, and a 1e-7 floor.
    pub fn new(initial_lr: f32) -> Self {
        Self {
            initial_lr,
            warmup_steps: 0,
            decay_steps: 100000,
            min_lr: 1e-7,
            current_step: 0,
        }
    }

    /// Builder: set the number of linear-warmup steps.
    pub fn with_warmup(self, warmup_steps: usize) -> Self {
        Self {
            warmup_steps,
            ..self
        }
    }

    /// Builder: set the length of the cosine-decay window.
    pub fn with_decay(self, decay_steps: usize) -> Self {
        Self {
            decay_steps,
            ..self
        }
    }

    /// Builder: set the minimum learning rate.
    pub fn with_min_lr(self, min_lr: f32) -> Self {
        Self { min_lr, ..self }
    }

    /// Return the learning rate for the current step, then advance one step.
    pub fn step(&mut self) -> f32 {
        let lr = self.get_lr();
        self.current_step += 1;
        lr
    }

    /// Learning rate for the current step (does not advance).
    pub fn get_lr(&self) -> f32 {
        if self.current_step < self.warmup_steps {
            // Linear warmup: ramps from initial_lr / warmup_steps up to initial_lr.
            return self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32;
        }
        // Cosine decay; progress is clamped so the rate stays at min_lr
        // once the decay window has been fully consumed.
        let progress =
            ((self.current_step - self.warmup_steps) as f32 / self.decay_steps as f32).min(1.0);
        let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
        self.min_lr + (self.initial_lr - self.min_lr) * cosine
    }

    /// Rewind the scheduler to step zero.
    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd() {
        let mut sgd = SGD::new(4, 0.1);
        let mut weights = vec![1.0, 2.0, 3.0, 4.0];
        let grads = vec![0.1, 0.2, 0.3, 0.4];

        sgd.step(&mut weights, &grads);

        // Positive gradients must push each weight downward.
        assert!(weights[0] < 1.0);
        assert!(weights[1] < 2.0);
    }

    #[test]
    fn test_sgd_momentum() {
        let mut sgd = SGD::new(4, 0.1).with_momentum(0.9);
        let mut weights = vec![1.0; 4];
        let grads = vec![1.0; 4];

        // Repeated identical gradients build momentum and overshoot zero.
        for _ in 0..5 {
            sgd.step(&mut weights, &grads);
        }

        assert!(weights[0] < 0.0);
    }

    #[test]
    fn test_adam() {
        let mut adam = Adam::new(64, 0.001);
        let mut weights = vec![0.5; 64];
        let grads = vec![0.1; 64];

        for _ in 0..100 {
            adam.step(&mut weights, &grads);
        }

        // A constant positive gradient should pull the weights below start.
        assert!(weights[0] < 0.5);
    }

    #[test]
    fn test_adamw() {
        let mut adamw = AdamW::new(32, 0.001).with_weight_decay(0.01);
        let mut weights = vec![1.0; 32];
        let grads = vec![0.0; 32]; // Zero gradient isolates the decay term.

        for _ in 0..100 {
            adamw.step(&mut weights, &grads);
        }

        // Decoupled weight decay alone should shrink the parameters.
        assert!(weights[0] < 1.0);
    }

    #[test]
    fn test_lr_scheduler_warmup() {
        let mut scheduler = LearningRateScheduler::new(0.001).with_warmup(100);

        // First step is still inside the warmup ramp.
        let first = scheduler.step();
        assert!(first < 0.001);

        for _ in 0..99 {
            scheduler.step();
        }

        // After the full warmup the rate reaches the peak.
        let at_peak = scheduler.get_lr();
        assert!((at_peak - 0.001).abs() < 1e-5);
    }

    #[test]
    fn test_lr_scheduler_decay() {
        let mut scheduler = LearningRateScheduler::new(0.001)
            .with_warmup(0)
            .with_decay(100)
            .with_min_lr(0.0001);

        // Without warmup, decay starts at the full initial rate.
        let first = scheduler.step();
        assert!((first - 0.001).abs() < 1e-5);

        for _ in 0..100 {
            scheduler.step();
        }

        // Past the decay window the rate sits at the floor.
        let floor = scheduler.get_lr();
        assert!((floor - 0.0001).abs() < 1e-5);
    }
}