// yscv_optim/lamb.rs
1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_beta1, validate_beta2, validate_epsilon, validate_lr};
8use super::{LearningRate, OptimError};
9
/// Per-parameter LAMB state: Adam-style moment buffers plus a step count.
#[derive(Debug, Clone)]
struct LambState {
    /// Exponential moving average of the gradient (Adam's `m`).
    first_moment: Tensor,
    /// Exponential moving average of the squared gradient (Adam's `v`).
    second_moment: Tensor,
    /// Number of updates applied so far; drives bias correction.
    step: u64,
}
16
17impl LambState {
18    fn new(shape: &[usize]) -> Result<Self, OptimError> {
19        Ok(Self {
20            first_moment: Tensor::zeros(shape.to_vec())?,
21            second_moment: Tensor::zeros(shape.to_vec())?,
22            step: 0,
23        })
24    }
25
26    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
27        *self = Self::new(shape)?;
28        Ok(())
29    }
30}
31
/// Layer-wise Adaptive Moments optimizer for Batch training (LAMB).
///
/// Combines Adam-style adaptive moment estimation with layer-wise trust ratio
/// scaling for stable large-batch training.
#[derive(Debug, Clone)]
pub struct Lamb {
    /// Base step size, scaled per parameter by the trust ratio in `step`.
    lr: f32,
    /// Decay rate of the first-moment (gradient) running average.
    beta1: f32,
    /// Decay rate of the second-moment (squared-gradient) running average.
    beta2: f32,
    /// Denominator offset guarding the division in the Adam step.
    epsilon: f32,
    /// L2 weight decay coefficient added to the raw Adam step.
    weight_decay: f32,
    /// Per-parameter moment state, keyed by caller-supplied parameter id.
    state: HashMap<u64, LambState>,
}
45
impl Lamb {
    /// Creates LAMB with required learning rate.
    ///
    /// Defaults: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-6`, no
    /// weight decay, and empty per-parameter state.
    ///
    /// # Errors
    /// Returns an error if `lr` is rejected by `validate_lr`.
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-6,
            weight_decay: 0.0,
            state: HashMap::new(),
        })
    }

    /// Sets beta coefficients `(beta1, beta2)` used for computing running averages.
    ///
    /// # Errors
    /// Returns an error if either coefficient fails validation.
    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Result<Self, OptimError> {
        validate_beta1(beta1)?;
        validate_beta2(beta2)?;
        self.beta1 = beta1;
        self.beta2 = beta2;
        Ok(self)
    }

    /// Sets L2 weight decay factor in `[0, +inf)`.
    ///
    /// # Errors
    /// Returns [`OptimError::InvalidWeightDecay`] for NaN, infinite, or
    /// negative values.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Sets epsilon value, must be finite and `> 0`.
    ///
    /// # Errors
    /// Returns an error if `epsilon` is rejected by `validate_epsilon`.
    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        self.epsilon = epsilon;
        Ok(self)
    }

    /// Drops optimizer state (for example when restarting training).
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

    /// Returns current learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    /// Overrides current learning rate.
    ///
    /// # Errors
    /// Returns an error if `lr` is rejected by `validate_lr`.
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

    /// Applies one update to raw tensor weights.
    ///
    /// `parameter_id` keys the per-parameter moment state; callers must
    /// pass a stable id per distinct parameter tensor.
    ///
    /// # Errors
    /// Returns [`OptimError::ShapeMismatch`] if `weights` and `grad`
    /// disagree in shape, or propagates failures from allocating state.
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        // Lazily create the moment state on first sight of this id.
        let state = match self.state.entry(parameter_id) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(LambState::new(weights.shape())?),
        };
        // If the parameter changed shape since the last step (e.g. the id
        // was reused), rebuild the state rather than indexing stale buffers.
        if state.first_moment.shape() != weights.shape()
            || state.second_moment.shape() != weights.shape()
        {
            state.reset(weights.shape())?;
        }

        state.step = state.step.saturating_add(1);
        // Bias corrections are computed in f64 for precision, then clamped
        // to a strictly positive value so the reciprocal below is finite.
        let step_f64 = state.step as f64;
        let bias_correction1 =
            (1.0 - (self.beta1 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;
        let bias_correction2 =
            (1.0 - (self.beta2 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;

        // Flat f32 views of the buffers (assumed contiguous — provided by
        // yscv_tensor; TODO(review) confirm layout guarantee there).
        let first_moment = state.first_moment.data_mut();
        let second_moment = state.second_moment.data_mut();
        let grad_data = grad.data();
        let weights_data = weights.data_mut();

        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let one_minus_beta1 = 1.0 - beta1;
        let one_minus_beta2 = 1.0 - beta2;
        let bias_correction1_inv = 1.0 / bias_correction1;
        let bias_correction2_inv = 1.0 / bias_correction2;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;

        // Pass 1: update both moment buffers and accumulate the squared
        // norms of the weights and of the proposed (weight-decayed) step.
        let (w_norm_sq, step_norm_sq) = lamb_pass1_inner(
            weights_data,
            grad_data,
            first_moment,
            second_moment,
            beta1,
            beta2,
            one_minus_beta1,
            one_minus_beta2,
            bias_correction1_inv,
            bias_correction2_inv,
            epsilon,
            weight_decay,
        );

        // LAMB trust ratio: layer-wise scale ||w|| / ||step||, falling back
        // to 1 whenever either norm vanishes (fresh weights or zero step).
        let w_norm = w_norm_sq.sqrt();
        let step_norm = step_norm_sq.sqrt();
        let trust_ratio = if w_norm > 0.0 && step_norm > 0.0 {
            w_norm / step_norm
        } else {
            1.0
        };
        let scaled_lr = self.lr * trust_ratio;

        // Pass 2: apply the trust-ratio-scaled update using the moments
        // already advanced in pass 1.
        lamb_pass2_inner(
            weights_data,
            first_moment,
            second_moment,
            bias_correction1_inv,
            bias_correction2_inv,
            scaled_lr,
            epsilon,
            weight_decay,
        );

        Ok(())
    }

    /// Applies one update to a trainable graph node by its `NodeId`.
    ///
    /// Nodes that do not require gradients are skipped silently.
    ///
    /// # Errors
    /// Returns [`OptimError::MissingGradient`] when the node requires a
    /// gradient but none has been computed; propagates graph access errors.
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }

        // Clone the gradient so the immutable borrow of `graph` ends
        // before `value_mut` takes a mutable one.
        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        // The node index doubles as the per-parameter state key.
        self.step(node.0 as u64, weights, &grad)
    }
}
199
200impl LearningRate for Lamb {
201    fn learning_rate(&self) -> f32 {
202        Lamb::learning_rate(self)
203    }
204
205    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
206        Lamb::set_learning_rate(self, lr)
207    }
208}
209
210// ── SIMD-accelerated LAMB pass 1: update moments + compute norms ────────
211
212/// Pass 1: update moments, return `(w_norm_sq, step_norm_sq)`.
213#[allow(clippy::too_many_arguments, unsafe_code)]
214fn lamb_pass1_inner(
215    weights: &mut [f32],
216    grad: &[f32],
217    first_moment: &mut [f32],
218    second_moment: &mut [f32],
219    beta1: f32,
220    beta2: f32,
221    one_minus_beta1: f32,
222    one_minus_beta2: f32,
223    bc1_inv: f32,
224    bc2_inv: f32,
225    epsilon: f32,
226    weight_decay: f32,
227) -> (f32, f32) {
228    #[cfg(target_arch = "aarch64")]
229    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
230        return unsafe {
231            lamb_pass1_neon(
232                weights,
233                grad,
234                first_moment,
235                second_moment,
236                beta1,
237                beta2,
238                one_minus_beta1,
239                one_minus_beta2,
240                bc1_inv,
241                bc2_inv,
242                epsilon,
243                weight_decay,
244            )
245        };
246    }
247
248    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
249    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
250        return unsafe {
251            lamb_pass1_avx(
252                weights,
253                grad,
254                first_moment,
255                second_moment,
256                beta1,
257                beta2,
258                one_minus_beta1,
259                one_minus_beta2,
260                bc1_inv,
261                bc2_inv,
262                epsilon,
263                weight_decay,
264            )
265        };
266    }
267
268    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
269    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
270        return unsafe {
271            lamb_pass1_sse(
272                weights,
273                grad,
274                first_moment,
275                second_moment,
276                beta1,
277                beta2,
278                one_minus_beta1,
279                one_minus_beta2,
280                bc1_inv,
281                bc2_inv,
282                epsilon,
283                weight_decay,
284            )
285        };
286    }
287
288    let len = weights.len();
289    let wp = weights.as_mut_ptr();
290    let gp = grad.as_ptr();
291    let mp = first_moment.as_mut_ptr();
292    let vp = second_moment.as_mut_ptr();
293    let mut w_norm_sq: f32 = 0.0;
294    let mut step_norm_sq: f32 = 0.0;
295    for i in 0..len {
296        unsafe {
297            let w = *wp.add(i);
298            let g = *gp.add(i);
299            let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
300            let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
301            *mp.add(i) = m;
302            *vp.add(i) = v;
303            let m_hat = m * bc1_inv;
304            let v_hat = v * bc2_inv;
305            let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * w;
306            w_norm_sq += w * w;
307            step_norm_sq += s * s;
308        }
309    }
310    (w_norm_sq, step_norm_sq)
311}
312
313/// Pass 2: apply trust-ratio-scaled update from already-updated moments.
314#[allow(clippy::too_many_arguments, unsafe_code)]
315fn lamb_pass2_inner(
316    weights: &mut [f32],
317    first_moment: &[f32],
318    second_moment: &[f32],
319    bc1_inv: f32,
320    bc2_inv: f32,
321    scaled_lr: f32,
322    epsilon: f32,
323    weight_decay: f32,
324) {
325    #[cfg(target_arch = "aarch64")]
326    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
327        unsafe {
328            lamb_pass2_neon(
329                weights,
330                first_moment,
331                second_moment,
332                bc1_inv,
333                bc2_inv,
334                scaled_lr,
335                epsilon,
336                weight_decay,
337            );
338        }
339        return;
340    }
341
342    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
343    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
344        unsafe {
345            lamb_pass2_avx(
346                weights,
347                first_moment,
348                second_moment,
349                bc1_inv,
350                bc2_inv,
351                scaled_lr,
352                epsilon,
353                weight_decay,
354            );
355        }
356        return;
357    }
358
359    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
360    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
361        unsafe {
362            lamb_pass2_sse(
363                weights,
364                first_moment,
365                second_moment,
366                bc1_inv,
367                bc2_inv,
368                scaled_lr,
369                epsilon,
370                weight_decay,
371            );
372        }
373        return;
374    }
375
376    let len = weights.len();
377    let wp = weights.as_mut_ptr();
378    let mp = first_moment.as_ptr();
379    let vp = second_moment.as_ptr();
380    for i in 0..len {
381        unsafe {
382            let m_hat = *mp.add(i) * bc1_inv;
383            let v_hat = *vp.add(i) * bc2_inv;
384            let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *wp.add(i);
385            *wp.add(i) -= scaled_lr * s;
386        }
387    }
388}
389
390// ── NEON implementations ────────────────────────────────────────────────
391
/// NEON kernel for pass 1: updates both moment buffers in place and
/// returns `(w_norm_sq, step_norm_sq)`, processing 4 lanes per iteration
/// with a scalar tail for the remainder.
///
/// # Safety
/// Caller must ensure NEON is available and that `grad`, `first_moment`
/// and `second_moment` are at least `weights.len()` long (all four are
/// indexed up to that length).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass1_neon(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    epsilon: f32,
    weight_decay: f32,
) -> (f32, f32) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast all scalar coefficients into 4-lane vectors once.
    let beta1_v = vdupq_n_f32(beta1);
    let beta2_v = vdupq_n_f32(beta2);
    let omb1_v = vdupq_n_f32(one_minus_beta1);
    let omb2_v = vdupq_n_f32(one_minus_beta2);
    let bc1_v = vdupq_n_f32(bc1_inv);
    let bc2_v = vdupq_n_f32(bc2_inv);
    let eps_v = vdupq_n_f32(epsilon);
    let wd_v = vdupq_n_f32(weight_decay);
    let mut w_norm_acc = vdupq_n_f32(0.0);
    let mut s_norm_acc = vdupq_n_f32(0.0);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let g = vld1q_f32(gp.add(i));
        let m_old = vld1q_f32(mp.add(i));
        let v_old = vld1q_f32(vp.add(i));
        // vfmaq_f32(a, b, c) computes a + b * c, so:
        // m_new = (1 - beta1) * g + m_old * beta1
        let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
        let grad_sq = vmulq_f32(g, g);
        // v_new = (1 - beta2) * g^2 + v_old * beta2
        let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
        vst1q_f32(mp.add(i), m_new);
        vst1q_f32(vp.add(i), v_new);
        let m_hat = vmulq_f32(m_new, bc1_v);
        let v_hat = vmulq_f32(v_new, bc2_v);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = vfmaq_f32(
            vdivq_f32(m_hat, vaddq_f32(vsqrtq_f32(v_hat), eps_v)),
            wd_v,
            w,
        );
        w_norm_acc = vfmaq_f32(w_norm_acc, w, w);
        s_norm_acc = vfmaq_f32(s_norm_acc, s, s);
        i += 4;
    }
    // Horizontal add collapses each 4-lane accumulator to a scalar.
    let mut w_norm_sq = vaddvq_f32(w_norm_acc);
    let mut step_norm_sq = vaddvq_f32(s_norm_acc);
    // Scalar tail handles the final `len % 4` elements.
    while i < len {
        let w = *wp.add(i);
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        let v_hat = v * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * w;
        w_norm_sq += w * w;
        step_norm_sq += s * s;
        i += 1;
    }
    (w_norm_sq, step_norm_sq)
}
465
/// NEON kernel for pass 2: applies the trust-ratio-scaled update in
/// place, 4 lanes per iteration with a scalar tail.
///
/// # Safety
/// Caller must ensure NEON is available and that `first_moment` and
/// `second_moment` are at least `weights.len()` long.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass2_neon(
    weights: &mut [f32],
    first_moment: &[f32],
    second_moment: &[f32],
    bc1_inv: f32,
    bc2_inv: f32,
    scaled_lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let mp = first_moment.as_ptr();
    let vp = second_moment.as_ptr();
    // Broadcast scalar coefficients into 4-lane vectors once.
    let bc1_v = vdupq_n_f32(bc1_inv);
    let bc2_v = vdupq_n_f32(bc2_inv);
    let lr_v = vdupq_n_f32(scaled_lr);
    let eps_v = vdupq_n_f32(epsilon);
    let wd_v = vdupq_n_f32(weight_decay);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let m_hat = vmulq_f32(vld1q_f32(mp.add(i)), bc1_v);
        let v_hat = vmulq_f32(vld1q_f32(vp.add(i)), bc2_v);
        // vfmaq_f32(a, b, c) computes a + b * c, so:
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = vfmaq_f32(
            vdivq_f32(m_hat, vaddq_f32(vsqrtq_f32(v_hat), eps_v)),
            wd_v,
            w,
        );
        // w -= scaled_lr * s
        vst1q_f32(wp.add(i), vsubq_f32(w, vmulq_f32(lr_v, s)));
        i += 4;
    }
    // Scalar tail handles the final `len % 4` elements.
    while i < len {
        let m_hat = *mp.add(i) * bc1_inv;
        let v_hat = *vp.add(i) * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *wp.add(i);
        *wp.add(i) -= scaled_lr * s;
        i += 1;
    }
}
510
511// ── AVX implementations ─────────────────────────────────────────────────
512
/// AVX kernel for pass 1: updates both moment buffers in place and
/// returns `(w_norm_sq, step_norm_sq)`, processing 8 lanes per iteration
/// with a scalar tail for the remainder.
///
/// # Safety
/// Caller must ensure AVX is available and that `grad`, `first_moment`
/// and `second_moment` are at least `weights.len()` long.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass1_avx(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    epsilon: f32,
    weight_decay: f32,
) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast all scalar coefficients into 8-lane vectors once.
    let beta1_v = _mm256_set1_ps(beta1);
    let beta2_v = _mm256_set1_ps(beta2);
    let omb1_v = _mm256_set1_ps(one_minus_beta1);
    let omb2_v = _mm256_set1_ps(one_minus_beta2);
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let bc2_v = _mm256_set1_ps(bc2_inv);
    let eps_v = _mm256_set1_ps(epsilon);
    let wd_v = _mm256_set1_ps(weight_decay);
    let mut w_norm_acc = _mm256_setzero_ps();
    let mut s_norm_acc = _mm256_setzero_ps();
    let mut i = 0usize;
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let g = _mm256_loadu_ps(gp.add(i));
        let m_old = _mm256_loadu_ps(mp.add(i));
        let v_old = _mm256_loadu_ps(vp.add(i));
        // m_new = beta1 * m_old + (1 - beta1) * g
        let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
        let grad_sq = _mm256_mul_ps(g, g);
        // v_new = beta2 * v_old + (1 - beta2) * g^2
        let v_new = _mm256_add_ps(
            _mm256_mul_ps(beta2_v, v_old),
            _mm256_mul_ps(omb2_v, grad_sq),
        );
        _mm256_storeu_ps(mp.add(i), m_new);
        _mm256_storeu_ps(vp.add(i), v_new);
        let m_hat = _mm256_mul_ps(m_new, bc1_v);
        let v_hat = _mm256_mul_ps(v_new, bc2_v);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = _mm256_add_ps(
            _mm256_div_ps(m_hat, _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v)),
            _mm256_mul_ps(wd_v, w),
        );
        w_norm_acc = _mm256_add_ps(w_norm_acc, _mm256_mul_ps(w, w));
        s_norm_acc = _mm256_add_ps(s_norm_acc, _mm256_mul_ps(s, s));
        i += 8;
    }
    // Horizontal sum of 8-wide accumulators: fold the high 128-bit half
    // onto the low half, then pairwise-reduce the 4 remaining lanes.
    // (`_mm_movehdup_ps` is SSE3, which the `avx` target feature implies.)
    let w_lo = _mm256_castps256_ps128(w_norm_acc);
    let w_hi = _mm256_extractf128_ps(w_norm_acc, 1);
    let w_sum4 = _mm_add_ps(w_lo, w_hi);
    let w_shuf = _mm_movehdup_ps(w_sum4);
    let w_sum2 = _mm_add_ps(w_sum4, w_shuf);
    let w_shuf2 = _mm_movehl_ps(w_sum2, w_sum2);
    let mut w_norm_sq = _mm_cvtss_f32(_mm_add_ss(w_sum2, w_shuf2));

    let s_lo = _mm256_castps256_ps128(s_norm_acc);
    let s_hi = _mm256_extractf128_ps(s_norm_acc, 1);
    let s_sum4 = _mm_add_ps(s_lo, s_hi);
    let s_shuf = _mm_movehdup_ps(s_sum4);
    let s_sum2 = _mm_add_ps(s_sum4, s_shuf);
    let s_shuf2 = _mm_movehl_ps(s_sum2, s_sum2);
    let mut step_norm_sq = _mm_cvtss_f32(_mm_add_ss(s_sum2, s_shuf2));

    // Scalar tail handles the final `len % 8` elements.
    while i < len {
        let w = *wp.add(i);
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        let v_hat = v * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * w;
        w_norm_sq += w * w;
        step_norm_sq += s * s;
        i += 1;
    }
    (w_norm_sq, step_norm_sq)
}
606
/// AVX kernel for pass 2: applies the trust-ratio-scaled update in
/// place, 8 lanes per iteration with a scalar tail.
///
/// # Safety
/// Caller must ensure AVX is available and that `first_moment` and
/// `second_moment` are at least `weights.len()` long.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass2_avx(
    weights: &mut [f32],
    first_moment: &[f32],
    second_moment: &[f32],
    bc1_inv: f32,
    bc2_inv: f32,
    scaled_lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let mp = first_moment.as_ptr();
    let vp = second_moment.as_ptr();
    // Broadcast scalar coefficients into 8-lane vectors once.
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let bc2_v = _mm256_set1_ps(bc2_inv);
    let lr_v = _mm256_set1_ps(scaled_lr);
    let eps_v = _mm256_set1_ps(epsilon);
    let wd_v = _mm256_set1_ps(weight_decay);
    let mut i = 0usize;
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let m_hat = _mm256_mul_ps(_mm256_loadu_ps(mp.add(i)), bc1_v);
        let v_hat = _mm256_mul_ps(_mm256_loadu_ps(vp.add(i)), bc2_v);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = _mm256_add_ps(
            _mm256_div_ps(m_hat, _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v)),
            _mm256_mul_ps(wd_v, w),
        );
        // w -= scaled_lr * s
        _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, _mm256_mul_ps(lr_v, s)));
        i += 8;
    }
    // Scalar tail handles the final `len % 8` elements.
    while i < len {
        let m_hat = *mp.add(i) * bc1_inv;
        let v_hat = *vp.add(i) * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *wp.add(i);
        *wp.add(i) -= scaled_lr * s;
        i += 1;
    }
}
653
654// ── SSE implementations ─────────────────────────────────────────────────
655
/// SSE kernel for pass 1: updates both moment buffers in place and
/// returns `(w_norm_sq, step_norm_sq)`, processing 4 lanes per iteration
/// with a scalar tail for the remainder.
///
/// Uses only SSE1 intrinsics, matching the `sse` runtime check in the
/// dispatcher. (The previous horizontal sum used `_mm_movehdup_ps`,
/// which is SSE3 — undefined behavior on a CPU with SSE but not SSE3.)
///
/// # Safety
/// Caller must ensure SSE is available and that `grad`, `first_moment`
/// and `second_moment` are at least `weights.len()` long.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass1_sse(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    epsilon: f32,
    weight_decay: f32,
) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast all scalar coefficients into 4-lane vectors once.
    let beta1_v = _mm_set1_ps(beta1);
    let beta2_v = _mm_set1_ps(beta2);
    let omb1_v = _mm_set1_ps(one_minus_beta1);
    let omb2_v = _mm_set1_ps(one_minus_beta2);
    let bc1_v = _mm_set1_ps(bc1_inv);
    let bc2_v = _mm_set1_ps(bc2_inv);
    let eps_v = _mm_set1_ps(epsilon);
    let wd_v = _mm_set1_ps(weight_decay);
    let mut w_norm_acc = _mm_setzero_ps();
    let mut s_norm_acc = _mm_setzero_ps();
    let mut i = 0usize;
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let g = _mm_loadu_ps(gp.add(i));
        let m_old = _mm_loadu_ps(mp.add(i));
        let v_old = _mm_loadu_ps(vp.add(i));
        // m_new = beta1 * m_old + (1 - beta1) * g
        let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
        let grad_sq = _mm_mul_ps(g, g);
        // v_new = beta2 * v_old + (1 - beta2) * g^2
        let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
        _mm_storeu_ps(mp.add(i), m_new);
        _mm_storeu_ps(vp.add(i), v_new);
        let m_hat = _mm_mul_ps(m_new, bc1_v);
        let v_hat = _mm_mul_ps(v_new, bc2_v);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = _mm_add_ps(
            _mm_div_ps(m_hat, _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v)),
            _mm_mul_ps(wd_v, w),
        );
        w_norm_acc = _mm_add_ps(w_norm_acc, _mm_mul_ps(w, w));
        s_norm_acc = _mm_add_ps(s_norm_acc, _mm_mul_ps(s, s));
        i += 4;
    }
    // Horizontal sums with SSE1-only shuffles: fold the high pair onto
    // the low pair via `movehl`, then add lane 1 onto lane 0.
    let w_hi = _mm_movehl_ps(w_norm_acc, w_norm_acc);
    let w_sum2 = _mm_add_ps(w_norm_acc, w_hi);
    let w_odd = _mm_shuffle_ps::<0b01>(w_sum2, w_sum2);
    let mut w_norm_sq = _mm_cvtss_f32(_mm_add_ss(w_sum2, w_odd));

    let s_hi = _mm_movehl_ps(s_norm_acc, s_norm_acc);
    let s_sum2 = _mm_add_ps(s_norm_acc, s_hi);
    let s_odd = _mm_shuffle_ps::<0b01>(s_sum2, s_sum2);
    let mut step_norm_sq = _mm_cvtss_f32(_mm_add_ss(s_sum2, s_odd));

    // Scalar tail handles the final `len % 4` elements.
    while i < len {
        let w = *wp.add(i);
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        let v_hat = v * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * w;
        w_norm_sq += w * w;
        step_norm_sq += s * s;
        i += 1;
    }
    (w_norm_sq, step_norm_sq)
}
740
/// SSE kernel for pass 2: applies the trust-ratio-scaled update in
/// place, 4 lanes per iteration with a scalar tail. Uses only SSE1
/// intrinsics, matching the `sse` runtime check in the dispatcher.
///
/// # Safety
/// Caller must ensure SSE is available and that `first_moment` and
/// `second_moment` are at least `weights.len()` long.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass2_sse(
    weights: &mut [f32],
    first_moment: &[f32],
    second_moment: &[f32],
    bc1_inv: f32,
    bc2_inv: f32,
    scaled_lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let mp = first_moment.as_ptr();
    let vp = second_moment.as_ptr();
    // Broadcast scalar coefficients into 4-lane vectors once.
    let bc1_v = _mm_set1_ps(bc1_inv);
    let bc2_v = _mm_set1_ps(bc2_inv);
    let lr_v = _mm_set1_ps(scaled_lr);
    let eps_v = _mm_set1_ps(epsilon);
    let wd_v = _mm_set1_ps(weight_decay);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let m_hat = _mm_mul_ps(_mm_loadu_ps(mp.add(i)), bc1_v);
        let v_hat = _mm_mul_ps(_mm_loadu_ps(vp.add(i)), bc2_v);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = _mm_add_ps(
            _mm_div_ps(m_hat, _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v)),
            _mm_mul_ps(wd_v, w),
        );
        // w -= scaled_lr * s
        _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, _mm_mul_ps(lr_v, s)));
        i += 4;
    }
    // Scalar tail handles the final `len % 4` elements.
    while i < len {
        let m_hat = *mp.add(i) * bc1_inv;
        let v_hat = *vp.add(i) * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *wp.add(i);
        *wp.add(i) -= scaled_lr * s;
        i += 1;
    }
}