Skip to main content

yscv_optim/
sgd.rs

1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_dampening, validate_lr, validate_momentum};
8use super::{LearningRate, OptimError};
9
10/// Stochastic gradient descent optimizer with optional momentum and weight decay.
11#[derive(Debug, Clone)]
12pub struct Sgd {
13    lr: f32,
14    momentum: f32,
15    dampening: f32,
16    weight_decay: f32,
17    nesterov: bool,
18    velocity: HashMap<u64, Tensor>,
19}
20
21impl Sgd {
22    /// Creates SGD with required learning rate.
23    pub fn new(lr: f32) -> Result<Self, OptimError> {
24        validate_lr(lr)?;
25        Ok(Self {
26            lr,
27            momentum: 0.0,
28            dampening: 0.0,
29            weight_decay: 0.0,
30            nesterov: false,
31            velocity: HashMap::new(),
32        })
33    }
34
35    /// Sets momentum factor in `[0, 1)`.
36    pub fn with_momentum(mut self, momentum: f32) -> Result<Self, OptimError> {
37        validate_momentum(momentum)?;
38        self.momentum = momentum;
39        self.validate_nesterov_constraints()?;
40        Ok(self)
41    }
42
43    /// Sets dampening factor in `[0, 1]`.
44    pub fn with_dampening(mut self, dampening: f32) -> Result<Self, OptimError> {
45        validate_dampening(dampening)?;
46        self.dampening = dampening;
47        Ok(self)
48    }
49
50    /// Sets L2 weight decay factor in `[0, +inf)`.
51    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
52        if !weight_decay.is_finite() || weight_decay < 0.0 {
53            return Err(OptimError::InvalidWeightDecay { weight_decay });
54        }
55        self.weight_decay = weight_decay;
56        Ok(self)
57    }
58
59    /// Enables/disables Nesterov update rule.
60    pub fn with_nesterov(mut self, nesterov: bool) -> Result<Self, OptimError> {
61        self.nesterov = nesterov;
62        self.validate_nesterov_constraints()?;
63        Ok(self)
64    }
65
66    /// Drops optimizer state (for example when restarting training).
67    pub fn clear_state(&mut self) {
68        self.velocity.clear();
69    }
70
71    /// Returns current learning rate.
72    pub fn learning_rate(&self) -> f32 {
73        self.lr
74    }
75
76    /// Overrides current learning rate.
77    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
78        validate_lr(lr)?;
79        self.lr = lr;
80        Ok(())
81    }
82
83    /// Applies one update to raw tensor weights.
84    pub fn step(
85        &mut self,
86        parameter_id: u64,
87        weights: &mut Tensor,
88        grad: &Tensor,
89    ) -> Result<(), OptimError> {
90        if weights.shape() != grad.shape() {
91            return Err(OptimError::ShapeMismatch {
92                weights: weights.shape().to_vec(),
93                grad: grad.shape().to_vec(),
94            });
95        }
96
97        // Fast path: no weight decay, no momentum — just axpy in-place.
98        if self.weight_decay == 0.0 && self.momentum == 0.0 {
99            axpy_neg(weights.data_mut(), grad.data(), self.lr);
100            return Ok(());
101        }
102
103        // When weight_decay != 0 we need adjusted gradients.
104        let has_wd = self.weight_decay != 0.0;
105        // Build adjusted_grad only when weight_decay is non-zero; otherwise
106        // we can reference grad.data() directly.
107        let adjusted_grad_buf: Vec<f32>;
108        let grad_slice: &[f32] = if has_wd {
109            let mut buf = grad.data().to_vec();
110            let wd = self.weight_decay;
111            fma_inplace(&mut buf, weights.data(), wd);
112            adjusted_grad_buf = buf;
113            &adjusted_grad_buf
114        } else {
115            grad.data()
116        };
117
118        if self.momentum != 0.0 {
119            let velocity = match self.velocity.entry(parameter_id) {
120                Entry::Occupied(entry) => entry.into_mut(),
121                Entry::Vacant(entry) => {
122                    let initial = Tensor::zeros(weights.shape().to_vec())?;
123                    entry.insert(initial)
124                }
125            };
126
127            if velocity.shape() != weights.shape() {
128                *velocity = Tensor::zeros(weights.shape().to_vec())?;
129            }
130
131            // velocity = momentum * velocity + (1 - dampening) * grad
132            // Done in-place on velocity's buffer to avoid allocation.
133            let mom = self.momentum;
134            let grad_scale = 1.0 - self.dampening;
135            momentum_update(velocity.data_mut(), grad_slice, mom, grad_scale);
136
137            if self.nesterov {
138                // update = grad + momentum * velocity
139                // weights -= lr * update
140                // => weights -= lr * grad + lr * momentum * velocity
141                axpy_neg(weights.data_mut(), grad_slice, self.lr);
142                axpy_neg(weights.data_mut(), velocity.data(), self.lr * mom);
143            } else {
144                // weights -= lr * velocity
145                axpy_neg(weights.data_mut(), velocity.data(), self.lr);
146            }
147        } else {
148            axpy_neg(weights.data_mut(), grad_slice, self.lr);
149        }
150        Ok(())
151    }
152
153    /// Applies one update to a trainable graph node by its `NodeId`.
154    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
155        if !graph.requires_grad(node)? {
156            return Ok(());
157        }
158
159        let grad = match graph.grad(node)? {
160            Some(grad) => grad.clone(),
161            None => return Err(OptimError::MissingGradient { node: node.0 }),
162        };
163        let weights = graph.value_mut(node)?;
164        self.step(node.0 as u64, weights, &grad)
165    }
166
167    fn validate_nesterov_constraints(&self) -> Result<(), OptimError> {
168        if self.nesterov && self.momentum == 0.0 {
169            return Err(OptimError::NesterovRequiresMomentum);
170        }
171        Ok(())
172    }
173}
174
175impl LearningRate for Sgd {
176    fn learning_rate(&self) -> f32 {
177        Sgd::learning_rate(self)
178    }
179
180    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
181        Sgd::set_learning_rate(self, lr)
182    }
183}
184
185/// weights[i] -= lr * grads[i]  — SIMD-accelerated axpy(negative).
186#[allow(unsafe_code)]
187fn axpy_neg(weights: &mut [f32], grads: &[f32], lr: f32) {
188    debug_assert_eq!(weights.len(), grads.len());
189    let len = weights.len();
190
191    #[cfg(target_arch = "aarch64")]
192    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
193        unsafe { axpy_neg_neon(weights, grads, lr) };
194        return;
195    }
196
197    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
198    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
199        unsafe { axpy_neg_avx(weights, grads, lr) };
200        return;
201    }
202
203    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
204    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
205        unsafe { axpy_neg_sse(weights, grads, lr) };
206        return;
207    }
208
209    let w_ptr = weights.as_mut_ptr();
210    let g_ptr = grads.as_ptr();
211    unsafe {
212        let mut i = 0usize;
213        while i + 4 <= len {
214            *w_ptr.add(i) -= lr * *g_ptr.add(i);
215            *w_ptr.add(i + 1) -= lr * *g_ptr.add(i + 1);
216            *w_ptr.add(i + 2) -= lr * *g_ptr.add(i + 2);
217            *w_ptr.add(i + 3) -= lr * *g_ptr.add(i + 3);
218            i += 4;
219        }
220        while i < len {
221            *w_ptr.add(i) -= lr * *g_ptr.add(i);
222            i += 1;
223        }
224    }
225}
226
227/// dst[i] += src[i] * scale  — SIMD-accelerated fused multiply-add.
228#[allow(unsafe_code)]
229fn fma_inplace(dst: &mut [f32], src: &[f32], scale: f32) {
230    debug_assert_eq!(dst.len(), src.len());
231    let len = dst.len();
232
233    #[cfg(target_arch = "aarch64")]
234    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
235        unsafe { fma_inplace_neon(dst, src, scale) };
236        return;
237    }
238
239    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
240    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
241        unsafe { fma_inplace_avx(dst, src, scale) };
242        return;
243    }
244
245    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
246    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
247        unsafe { fma_inplace_sse(dst, src, scale) };
248        return;
249    }
250
251    for i in 0..len {
252        dst[i] += src[i] * scale;
253    }
254}
255
256/// velocity[i] = momentum * velocity[i] + grad_scale * grad[i]  — in-place.
257#[allow(unsafe_code)]
258fn momentum_update(velocity: &mut [f32], grad: &[f32], momentum: f32, grad_scale: f32) {
259    debug_assert_eq!(velocity.len(), grad.len());
260    let len = velocity.len();
261
262    #[cfg(target_arch = "aarch64")]
263    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
264        unsafe { momentum_update_neon(velocity, grad, momentum, grad_scale) };
265        return;
266    }
267
268    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
269    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
270        unsafe { momentum_update_avx(velocity, grad, momentum, grad_scale) };
271        return;
272    }
273
274    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
275    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
276        unsafe { momentum_update_sse(velocity, grad, momentum, grad_scale) };
277        return;
278    }
279
280    for i in 0..len {
281        velocity[i] = momentum * velocity[i] + grad_scale * grad[i];
282    }
283}
284
// ── NEON implementations ────────────────────────────────────────────────
286
/// NEON kernel for `axpy_neg`: `weights[i] -= lr * grads[i]`.
///
/// Processes four lanes per iteration with a fused multiply-subtract
/// (`vfmsq_f32`), then finishes the last `len % 4` elements with scalar code
/// (which rounds the product separately).
///
/// # Safety
/// - The CPU must support NEON.
/// - `grads.len()` must be at least `weights.len()`: that many elements are
///   read from `grads` through a raw pointer without bounds checks.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn axpy_neg_neon(weights: &mut [f32], grads: &[f32], lr: f32) {
    use std::arch::aarch64::*;
    debug_assert!(
        grads.len() >= weights.len(),
        "axpy_neg_neon: grads shorter than weights"
    );
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grads.as_ptr();
    let vlr = vdupq_n_f32(lr);
    let mut i = 0usize;
    // SAFETY: `i + 4 <= len` keeps each 4-lane load/store inside `weights`,
    // and inside `grads` by the length precondition.
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let g = vld1q_f32(gp.add(i));
        vst1q_f32(wp.add(i), vfmsq_f32(w, g, vlr));
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    while i < len {
        *wp.add(i) -= lr * *gp.add(i);
        i += 1;
    }
}
308
/// NEON kernel for `fma_inplace`: `dst[i] += src[i] * scale`.
///
/// Processes four lanes per iteration with a fused multiply-add
/// (`vfmaq_f32`), then finishes the last `len % 4` elements with scalar code.
///
/// # Safety
/// - The CPU must support NEON.
/// - `src.len()` must be at least `dst.len()`: that many elements are read
///   from `src` through a raw pointer without bounds checks.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn fma_inplace_neon(dst: &mut [f32], src: &[f32], scale: f32) {
    use std::arch::aarch64::*;
    debug_assert!(
        src.len() >= dst.len(),
        "fma_inplace_neon: src shorter than dst"
    );
    let len = dst.len();
    let dp = dst.as_mut_ptr();
    let sp = src.as_ptr();
    let vs = vdupq_n_f32(scale);
    let mut i = 0usize;
    // SAFETY: `i + 4 <= len` keeps each 4-lane load/store inside `dst`,
    // and inside `src` by the length precondition.
    while i + 4 <= len {
        let d = vld1q_f32(dp.add(i));
        let s = vld1q_f32(sp.add(i));
        vst1q_f32(dp.add(i), vfmaq_f32(d, s, vs));
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    while i < len {
        *dp.add(i) += *sp.add(i) * scale;
        i += 1;
    }
}
330
/// NEON kernel for `momentum_update`:
/// `velocity[i] = momentum * velocity[i] + grad_scale * grad[i]`.
///
/// Four lanes per iteration: `momentum * v` is formed with `vmulq_f32`, then
/// `grad_scale * g` is fused onto it with `vfmaq_f32`; the last `len % 4`
/// elements use scalar code.
///
/// # Safety
/// - The CPU must support NEON.
/// - `grad.len()` must be at least `velocity.len()`: that many elements are
///   read from `grad` through a raw pointer without bounds checks.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn momentum_update_neon(velocity: &mut [f32], grad: &[f32], momentum: f32, grad_scale: f32) {
    use std::arch::aarch64::*;
    debug_assert!(
        grad.len() >= velocity.len(),
        "momentum_update_neon: grad shorter than velocity"
    );
    let len = velocity.len();
    let vp = velocity.as_mut_ptr();
    let gp = grad.as_ptr();
    let vmom = vdupq_n_f32(momentum);
    let vgs = vdupq_n_f32(grad_scale);
    let mut i = 0usize;
    // SAFETY: `i + 4 <= len` keeps each 4-lane load/store inside `velocity`,
    // and inside `grad` by the length precondition.
    while i + 4 <= len {
        let v = vld1q_f32(vp.add(i));
        let g = vld1q_f32(gp.add(i));
        // momentum * v + grad_scale * g
        let result = vfmaq_f32(vmulq_f32(vmom, v), g, vgs);
        vst1q_f32(vp.add(i), result);
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    while i < len {
        *vp.add(i) = momentum * *vp.add(i) + grad_scale * *gp.add(i);
        i += 1;
    }
}
355
// ── AVX implementations ─────────────────────────────────────────────────
357
/// AVX kernel for `axpy_neg`: `weights[i] -= lr * grads[i]`.
///
/// Processes eight lanes per iteration (separate multiply and subtract; no
/// FMA is assumed), then finishes the last `len % 8` elements with scalar code.
///
/// # Safety
/// - The CPU must support AVX (e.g. `is_x86_feature_detected!("avx")`).
/// - `grads.len()` must be at least `weights.len()`: that many elements are
///   read from `grads` through a raw pointer without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn axpy_neg_avx(weights: &mut [f32], grads: &[f32], lr: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    debug_assert!(
        grads.len() >= weights.len(),
        "axpy_neg_avx: grads shorter than weights"
    );
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grads.as_ptr();
    let vlr = _mm256_set1_ps(lr);
    let mut i = 0usize;
    // SAFETY: `i + 8 <= len` keeps each unaligned 8-lane load/store inside
    // `weights`, and inside `grads` by the length precondition.
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let g = _mm256_loadu_ps(gp.add(i));
        _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, _mm256_mul_ps(g, vlr)));
        i += 8;
    }
    // Scalar tail for the remaining `len % 8` elements.
    while i < len {
        *wp.add(i) -= lr * *gp.add(i);
        i += 1;
    }
}
382
/// AVX kernel for `fma_inplace`: `dst[i] += src[i] * scale`.
///
/// Despite the name, this uses a separate multiply and add (AVX alone has no
/// FMA), eight lanes per iteration, then a scalar tail for `len % 8` elements.
///
/// # Safety
/// - The CPU must support AVX (e.g. `is_x86_feature_detected!("avx")`).
/// - `src.len()` must be at least `dst.len()`: that many elements are read
///   from `src` through a raw pointer without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn fma_inplace_avx(dst: &mut [f32], src: &[f32], scale: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    debug_assert!(
        src.len() >= dst.len(),
        "fma_inplace_avx: src shorter than dst"
    );
    let len = dst.len();
    let dp = dst.as_mut_ptr();
    let sp = src.as_ptr();
    let vs = _mm256_set1_ps(scale);
    let mut i = 0usize;
    // SAFETY: `i + 8 <= len` keeps each unaligned 8-lane load/store inside
    // `dst`, and inside `src` by the length precondition.
    while i + 8 <= len {
        let d = _mm256_loadu_ps(dp.add(i));
        let s = _mm256_loadu_ps(sp.add(i));
        _mm256_storeu_ps(dp.add(i), _mm256_add_ps(d, _mm256_mul_ps(s, vs)));
        i += 8;
    }
    // Scalar tail for the remaining `len % 8` elements.
    while i < len {
        *dp.add(i) += *sp.add(i) * scale;
        i += 1;
    }
}
407
/// AVX kernel for `momentum_update`:
/// `velocity[i] = momentum * velocity[i] + grad_scale * grad[i]`.
///
/// Eight lanes per iteration with separate multiplies and an add, then a
/// scalar tail for the last `len % 8` elements.
///
/// # Safety
/// - The CPU must support AVX (e.g. `is_x86_feature_detected!("avx")`).
/// - `grad.len()` must be at least `velocity.len()`: that many elements are
///   read from `grad` through a raw pointer without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn momentum_update_avx(velocity: &mut [f32], grad: &[f32], momentum: f32, grad_scale: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    debug_assert!(
        grad.len() >= velocity.len(),
        "momentum_update_avx: grad shorter than velocity"
    );
    let len = velocity.len();
    let vp = velocity.as_mut_ptr();
    let gp = grad.as_ptr();
    let vmom = _mm256_set1_ps(momentum);
    let vgs = _mm256_set1_ps(grad_scale);
    let mut i = 0usize;
    // SAFETY: `i + 8 <= len` keeps each unaligned 8-lane load/store inside
    // `velocity`, and inside `grad` by the length precondition.
    while i + 8 <= len {
        let v = _mm256_loadu_ps(vp.add(i));
        let g = _mm256_loadu_ps(gp.add(i));
        let result = _mm256_add_ps(_mm256_mul_ps(vmom, v), _mm256_mul_ps(g, vgs));
        _mm256_storeu_ps(vp.add(i), result);
        i += 8;
    }
    // Scalar tail for the remaining `len % 8` elements.
    while i < len {
        *vp.add(i) = momentum * *vp.add(i) + grad_scale * *gp.add(i);
        i += 1;
    }
}
434
// ── SSE implementations ─────────────────────────────────────────────────
436
/// SSE kernel for `axpy_neg`: `weights[i] -= lr * grads[i]`.
///
/// Four lanes per iteration (separate multiply and subtract), then a scalar
/// tail for the last `len % 4` elements.
///
/// # Safety
/// - The CPU must support SSE (e.g. `is_x86_feature_detected!("sse")`).
/// - `grads.len()` must be at least `weights.len()`: that many elements are
///   read from `grads` through a raw pointer without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn axpy_neg_sse(weights: &mut [f32], grads: &[f32], lr: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    debug_assert!(
        grads.len() >= weights.len(),
        "axpy_neg_sse: grads shorter than weights"
    );
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grads.as_ptr();
    let vlr = _mm_set1_ps(lr);
    let mut i = 0usize;
    // SAFETY: `i + 4 <= len` keeps each unaligned 4-lane load/store inside
    // `weights`, and inside `grads` by the length precondition.
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let g = _mm_loadu_ps(gp.add(i));
        _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, _mm_mul_ps(g, vlr)));
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    while i < len {
        *wp.add(i) -= lr * *gp.add(i);
        i += 1;
    }
}
461
/// SSE kernel for `fma_inplace`: `dst[i] += src[i] * scale`.
///
/// Despite the name, this uses a separate multiply and add, four lanes per
/// iteration, then a scalar tail for the last `len % 4` elements.
///
/// # Safety
/// - The CPU must support SSE (e.g. `is_x86_feature_detected!("sse")`).
/// - `src.len()` must be at least `dst.len()`: that many elements are read
///   from `src` through a raw pointer without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn fma_inplace_sse(dst: &mut [f32], src: &[f32], scale: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    debug_assert!(
        src.len() >= dst.len(),
        "fma_inplace_sse: src shorter than dst"
    );
    let len = dst.len();
    let dp = dst.as_mut_ptr();
    let sp = src.as_ptr();
    let vs = _mm_set1_ps(scale);
    let mut i = 0usize;
    // SAFETY: `i + 4 <= len` keeps each unaligned 4-lane load/store inside
    // `dst`, and inside `src` by the length precondition.
    while i + 4 <= len {
        let d = _mm_loadu_ps(dp.add(i));
        let s = _mm_loadu_ps(sp.add(i));
        _mm_storeu_ps(dp.add(i), _mm_add_ps(d, _mm_mul_ps(s, vs)));
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    while i < len {
        *dp.add(i) += *sp.add(i) * scale;
        i += 1;
    }
}
486
/// SSE kernel for `momentum_update`:
/// `velocity[i] = momentum * velocity[i] + grad_scale * grad[i]`.
///
/// Four lanes per iteration with separate multiplies and an add, then a
/// scalar tail for the last `len % 4` elements.
///
/// # Safety
/// - The CPU must support SSE (e.g. `is_x86_feature_detected!("sse")`).
/// - `grad.len()` must be at least `velocity.len()`: that many elements are
///   read from `grad` through a raw pointer without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn momentum_update_sse(velocity: &mut [f32], grad: &[f32], momentum: f32, grad_scale: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    debug_assert!(
        grad.len() >= velocity.len(),
        "momentum_update_sse: grad shorter than velocity"
    );
    let len = velocity.len();
    let vp = velocity.as_mut_ptr();
    let gp = grad.as_ptr();
    let vmom = _mm_set1_ps(momentum);
    let vgs = _mm_set1_ps(grad_scale);
    let mut i = 0usize;
    // SAFETY: `i + 4 <= len` keeps each unaligned 4-lane load/store inside
    // `velocity`, and inside `grad` by the length precondition.
    while i + 4 <= len {
        let v = _mm_loadu_ps(vp.add(i));
        let g = _mm_loadu_ps(gp.add(i));
        let result = _mm_add_ps(_mm_mul_ps(vmom, v), _mm_mul_ps(g, vgs));
        _mm_storeu_ps(vp.add(i), result);
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    while i < len {
        *vp.add(i) = momentum * *vp.add(i) + grad_scale * *gp.add(i);
        i += 1;
    }
}