//! Adam optimizer (`yscv_optim/adam.rs`).
1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_beta1, validate_beta2, validate_epsilon, validate_lr};
8use super::{LearningRate, OptimError};
9
/// Per-parameter Adam accumulator state, keyed by parameter id in [`Adam`].
#[derive(Debug, Clone)]
struct AdamState {
    // Exponential moving average of gradients (`m` in the Adam paper).
    first_moment: Tensor,
    // Exponential moving average of squared gradients (`v`).
    second_moment: Tensor,
    // Number of updates applied so far; drives the bias corrections.
    step: u64,
}
16
17impl AdamState {
18    fn new(shape: &[usize]) -> Result<Self, OptimError> {
19        Ok(Self {
20            first_moment: Tensor::zeros(shape.to_vec())?,
21            second_moment: Tensor::zeros(shape.to_vec())?,
22            step: 0,
23        })
24    }
25
26    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
27        *self = Self::new(shape)?;
28        Ok(())
29    }
30}
31
/// Adam optimizer with optional L2 weight decay.
#[derive(Debug, Clone)]
pub struct Adam {
    // Learning rate; accepted values are whatever `validate_lr` allows.
    lr: f32,
    // First-moment decay factor in `[0, 1)`.
    beta1: f32,
    // Second-moment decay factor in `[0, 1)`.
    beta2: f32,
    // Denominator stabilizer; finite and `> 0`.
    epsilon: f32,
    // L2 weight-decay factor in `[0, +inf)`; `0.0` disables decay.
    weight_decay: f32,
    // Per-parameter moment buffers, keyed by caller-supplied parameter id.
    state: HashMap<u64, AdamState>,
}
42
43impl Adam {
44    /// Creates Adam with required learning rate.
45    pub fn new(lr: f32) -> Result<Self, OptimError> {
46        validate_lr(lr)?;
47        Ok(Self {
48            lr,
49            beta1: 0.9,
50            beta2: 0.999,
51            epsilon: 1e-8,
52            weight_decay: 0.0,
53            state: HashMap::new(),
54        })
55    }
56
57    /// Sets beta1 factor in `[0, 1)`.
58    pub fn with_beta1(mut self, beta1: f32) -> Result<Self, OptimError> {
59        validate_beta1(beta1)?;
60        self.beta1 = beta1;
61        Ok(self)
62    }
63
64    /// Sets beta2 factor in `[0, 1)`.
65    pub fn with_beta2(mut self, beta2: f32) -> Result<Self, OptimError> {
66        validate_beta2(beta2)?;
67        self.beta2 = beta2;
68        Ok(self)
69    }
70
71    /// Sets epsilon value, must be finite and `> 0`.
72    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
73        validate_epsilon(epsilon)?;
74        self.epsilon = epsilon;
75        Ok(self)
76    }
77
78    /// Sets L2 weight decay factor in `[0, +inf)`.
79    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
80        if !weight_decay.is_finite() || weight_decay < 0.0 {
81            return Err(OptimError::InvalidWeightDecay { weight_decay });
82        }
83        self.weight_decay = weight_decay;
84        Ok(self)
85    }
86
87    /// Drops optimizer state (for example when restarting training).
88    pub fn clear_state(&mut self) {
89        self.state.clear();
90    }
91
92    /// Returns current learning rate.
93    pub fn learning_rate(&self) -> f32 {
94        self.lr
95    }
96
97    /// Overrides current learning rate.
98    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
99        validate_lr(lr)?;
100        self.lr = lr;
101        Ok(())
102    }
103
104    /// Applies one update to raw tensor weights.
105    pub fn step(
106        &mut self,
107        parameter_id: u64,
108        weights: &mut Tensor,
109        grad: &Tensor,
110    ) -> Result<(), OptimError> {
111        if weights.shape() != grad.shape() {
112            return Err(OptimError::ShapeMismatch {
113                weights: weights.shape().to_vec(),
114                grad: grad.shape().to_vec(),
115            });
116        }
117
118        let state = match self.state.entry(parameter_id) {
119            Entry::Occupied(entry) => entry.into_mut(),
120            Entry::Vacant(entry) => entry.insert(AdamState::new(weights.shape())?),
121        };
122        if state.first_moment.shape() != weights.shape()
123            || state.second_moment.shape() != weights.shape()
124        {
125            state.reset(weights.shape())?;
126        }
127
128        state.step = state.step.saturating_add(1);
129        let step_f64 = state.step as f64;
130        let bias_correction1 =
131            (1.0 - (self.beta1 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;
132        let bias_correction2 =
133            (1.0 - (self.beta2 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;
134
135        let first_moment = state.first_moment.data_mut();
136        let second_moment = state.second_moment.data_mut();
137        let grad_values = grad.data();
138        let weights_data = weights.data_mut();
139
140        let beta1 = self.beta1;
141        let beta2 = self.beta2;
142        let one_minus_beta1 = 1.0 - beta1;
143        let one_minus_beta2 = 1.0 - beta2;
144        let bias_correction1_inv = 1.0 / bias_correction1;
145        let bias_correction2_inv = 1.0 / bias_correction2;
146        let lr = self.lr;
147        let epsilon = self.epsilon;
148        let weight_decay = self.weight_decay;
149
150        adam_update_inner(
151            weights_data,
152            grad_values,
153            first_moment,
154            second_moment,
155            beta1,
156            beta2,
157            one_minus_beta1,
158            one_minus_beta2,
159            bias_correction1_inv,
160            bias_correction2_inv,
161            lr,
162            epsilon,
163            weight_decay,
164        );
165
166        Ok(())
167    }
168
169    /// Applies one update to a trainable graph node by its `NodeId`.
170    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
171        if !graph.requires_grad(node)? {
172            return Ok(());
173        }
174
175        let grad = match graph.grad(node)? {
176            Some(grad) => grad.clone(),
177            None => return Err(OptimError::MissingGradient { node: node.0 }),
178        };
179        let weights = graph.value_mut(node)?;
180        self.step(node.0 as u64, weights, &grad)
181    }
182}
183
184impl LearningRate for Adam {
185    fn learning_rate(&self) -> f32 {
186        Adam::learning_rate(self)
187    }
188
189    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
190        Adam::set_learning_rate(self, lr)
191    }
192}
193
194/// SIMD-accelerated Adam parameter update.
195#[allow(clippy::too_many_arguments, unsafe_code)]
196fn adam_update_inner(
197    weights: &mut [f32],
198    grad: &[f32],
199    first_moment: &mut [f32],
200    second_moment: &mut [f32],
201    beta1: f32,
202    beta2: f32,
203    one_minus_beta1: f32,
204    one_minus_beta2: f32,
205    bc1_inv: f32,
206    bc2_inv: f32,
207    lr: f32,
208    epsilon: f32,
209    weight_decay: f32,
210) {
211    let len = weights.len();
212
213    #[cfg(target_arch = "aarch64")]
214    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
215        unsafe {
216            adam_update_neon(
217                weights,
218                grad,
219                first_moment,
220                second_moment,
221                beta1,
222                beta2,
223                one_minus_beta1,
224                one_minus_beta2,
225                bc1_inv,
226                bc2_inv,
227                lr,
228                epsilon,
229                weight_decay,
230            );
231        }
232        return;
233    }
234
235    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
236    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
237        unsafe {
238            adam_update_avx(
239                weights,
240                grad,
241                first_moment,
242                second_moment,
243                beta1,
244                beta2,
245                one_minus_beta1,
246                one_minus_beta2,
247                bc1_inv,
248                bc2_inv,
249                lr,
250                epsilon,
251                weight_decay,
252            );
253        }
254        return;
255    }
256
257    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
258    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
259        unsafe {
260            adam_update_sse(
261                weights,
262                grad,
263                first_moment,
264                second_moment,
265                beta1,
266                beta2,
267                one_minus_beta1,
268                one_minus_beta2,
269                bc1_inv,
270                bc2_inv,
271                lr,
272                epsilon,
273                weight_decay,
274            );
275        }
276        return;
277    }
278
279    let wp = weights.as_mut_ptr();
280    let gp = grad.as_ptr();
281    let mp = first_moment.as_mut_ptr();
282    let vp = second_moment.as_mut_ptr();
283    for i in 0..len {
284        unsafe {
285            let g = *gp.add(i) + weight_decay * *wp.add(i);
286            let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
287            let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
288            *mp.add(i) = m;
289            *vp.add(i) = v;
290            let m_hat = m * bc1_inv;
291            let v_hat = v * bc2_inv;
292            *wp.add(i) -= lr * m_hat / (v_hat.sqrt() + epsilon);
293        }
294    }
295}
296
297// ── NEON implementation ─────────────────────────────────────────────────
298
/// NEON (aarch64) Adam kernel: four `f32` lanes per iteration, then a
/// scalar tail for the remaining `len % 4` elements.
///
/// NOTE(review): the vector path uses fused multiply-adds (`vfmaq_f32`),
/// so lanes may round differently (by a final ulp) than the scalar tail —
/// presumably acceptable, but confirm if bit-exactness across the split
/// matters.
///
/// # Safety
/// The caller must ensure that:
/// - the `neon` target feature is available at runtime, and
/// - `weights`, `grad`, `first_moment`, and `second_moment` all have the
///   same length; every buffer is accessed up to `weights.len()` through
///   raw pointers with no bounds checks.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adam_update_neon(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast each scalar hyperparameter across all four lanes once.
    let beta1_v = vdupq_n_f32(beta1);
    let beta2_v = vdupq_n_f32(beta2);
    let omb1_v = vdupq_n_f32(one_minus_beta1);
    let omb2_v = vdupq_n_f32(one_minus_beta2);
    let bc1_v = vdupq_n_f32(bc1_inv);
    let bc2_v = vdupq_n_f32(bc2_inv);
    let lr_v = vdupq_n_f32(lr);
    let eps_v = vdupq_n_f32(epsilon);
    let wd_v = vdupq_n_f32(weight_decay);
    let mut i = 0usize;
    // Main vector loop: `vfmaq_f32(a, b, c)` computes `a + b * c`.
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let raw_g = vld1q_f32(gp.add(i));
        // g = raw_g + weight_decay * w
        let g = vfmaq_f32(raw_g, wd_v, w);
        let m_old = vld1q_f32(mp.add(i));
        let v_old = vld1q_f32(vp.add(i));
        // m_new = (1 - beta1) * g + beta1 * m_old
        let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
        let grad_sq = vmulq_f32(g, g);
        // v_new = (1 - beta2) * g^2 + beta2 * v_old
        let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
        vst1q_f32(mp.add(i), m_new);
        vst1q_f32(vp.add(i), v_new);
        // Bias-corrected moments, then w -= lr * m_hat / (sqrt(v_hat) + eps).
        let m_hat = vmulq_f32(m_new, bc1_v);
        let v_hat = vmulq_f32(v_new, bc2_v);
        let update = vdivq_f32(vmulq_f32(m_hat, lr_v), vaddq_f32(vsqrtq_f32(v_hat), eps_v));
        vst1q_f32(wp.add(i), vsubq_f32(w, update));
        i += 4;
    }
    // Scalar tail: same update rule, one element at a time.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        *wp.add(i) -= lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}
360
361// ── AVX implementation ──────────────────────────────────────────────────
362
/// AVX Adam kernel: eight `f32` lanes per iteration, then a scalar tail
/// for the remaining `len % 8` elements.
///
/// # Safety
/// The caller must ensure that:
/// - the `avx` target feature is available at runtime, and
/// - `weights`, `grad`, `first_moment`, and `second_moment` all have the
///   same length; every buffer is accessed up to `weights.len()` through
///   raw pointers with no bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adam_update_avx(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast each scalar hyperparameter across all eight lanes once.
    let beta1_v = _mm256_set1_ps(beta1);
    let beta2_v = _mm256_set1_ps(beta2);
    let omb1_v = _mm256_set1_ps(one_minus_beta1);
    let omb2_v = _mm256_set1_ps(one_minus_beta2);
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let bc2_v = _mm256_set1_ps(bc2_inv);
    let lr_v = _mm256_set1_ps(lr);
    let eps_v = _mm256_set1_ps(epsilon);
    let wd_v = _mm256_set1_ps(weight_decay);
    let mut i = 0usize;
    // Unaligned loads/stores are used, so no alignment requirement beyond
    // the slices themselves.
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let raw_g = _mm256_loadu_ps(gp.add(i));
        // g = raw_g + weight_decay * w
        let g = _mm256_add_ps(raw_g, _mm256_mul_ps(wd_v, w));
        let m_old = _mm256_loadu_ps(mp.add(i));
        let v_old = _mm256_loadu_ps(vp.add(i));
        // m_new = beta1 * m_old + (1 - beta1) * g
        let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
        let grad_sq = _mm256_mul_ps(g, g);
        // v_new = beta2 * v_old + (1 - beta2) * g^2
        let v_new = _mm256_add_ps(
            _mm256_mul_ps(beta2_v, v_old),
            _mm256_mul_ps(omb2_v, grad_sq),
        );
        _mm256_storeu_ps(mp.add(i), m_new);
        _mm256_storeu_ps(vp.add(i), v_new);
        // Bias-corrected moments, then w -= lr * m_hat / (sqrt(v_hat) + eps).
        let m_hat = _mm256_mul_ps(m_new, bc1_v);
        let v_hat = _mm256_mul_ps(v_new, bc2_v);
        let update = _mm256_div_ps(
            _mm256_mul_ps(m_hat, lr_v),
            _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v),
        );
        _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, update));
        i += 8;
    }
    // Scalar tail: same update rule, one element at a time.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        *wp.add(i) -= lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}
433
434// ── SSE implementation ──────────────────────────────────────────────────
435
/// SSE Adam kernel: four `f32` lanes per iteration, then a scalar tail
/// for the remaining `len % 4` elements. Used when AVX is unavailable.
///
/// # Safety
/// The caller must ensure that:
/// - the `sse` target feature is available at runtime, and
/// - `weights`, `grad`, `first_moment`, and `second_moment` all have the
///   same length; every buffer is accessed up to `weights.len()` through
///   raw pointers with no bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adam_update_sse(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast each scalar hyperparameter across all four lanes once.
    let beta1_v = _mm_set1_ps(beta1);
    let beta2_v = _mm_set1_ps(beta2);
    let omb1_v = _mm_set1_ps(one_minus_beta1);
    let omb2_v = _mm_set1_ps(one_minus_beta2);
    let bc1_v = _mm_set1_ps(bc1_inv);
    let bc2_v = _mm_set1_ps(bc2_inv);
    let lr_v = _mm_set1_ps(lr);
    let eps_v = _mm_set1_ps(epsilon);
    let wd_v = _mm_set1_ps(weight_decay);
    let mut i = 0usize;
    // Unaligned loads/stores are used, so no alignment requirement beyond
    // the slices themselves.
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let raw_g = _mm_loadu_ps(gp.add(i));
        // g = raw_g + weight_decay * w
        let g = _mm_add_ps(raw_g, _mm_mul_ps(wd_v, w));
        let m_old = _mm_loadu_ps(mp.add(i));
        let v_old = _mm_loadu_ps(vp.add(i));
        // m_new = beta1 * m_old + (1 - beta1) * g
        let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
        let grad_sq = _mm_mul_ps(g, g);
        // v_new = beta2 * v_old + (1 - beta2) * g^2
        let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
        _mm_storeu_ps(mp.add(i), m_new);
        _mm_storeu_ps(vp.add(i), v_new);
        // Bias-corrected moments, then w -= lr * m_hat / (sqrt(v_hat) + eps).
        let m_hat = _mm_mul_ps(m_new, bc1_v);
        let v_hat = _mm_mul_ps(v_new, bc2_v);
        let update = _mm_div_ps(
            _mm_mul_ps(m_hat, lr_v),
            _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v),
        );
        _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, update));
        i += 4;
    }
    // Scalar tail: same update rule, one element at a time.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        *wp.add(i) -= lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}