yscv_kernels/ops/simd.rs

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{
    float32x4_t, vaddq_f32, vdivq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vmaxq_f32, vminq_f32,
    vmulq_f32, vnegq_f32, vst1q_f32, vsubq_f32,
};
#[cfg(target_arch = "x86")]
use std::arch::x86::{
    __m128, __m256, _mm_add_ps, _mm_castsi128_ps, _mm_cvtepi32_ps, _mm_cvtps_epi32, _mm_loadu_ps,
    _mm_max_ps, _mm_min_ps, _mm_mul_ps, _mm_set1_epi32, _mm_set1_ps, _mm_setzero_ps, _mm_storeu_ps,
    _mm_sub_ps, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtepi32_ps, _mm256_cvtps_epi32,
    _mm256_loadu_ps, _mm256_max_ps, _mm256_min_ps, _mm256_mul_ps, _mm256_set1_epi32,
    _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{
    __m128, __m256, _mm_add_ps, _mm_castsi128_ps, _mm_cvtepi32_ps, _mm_cvtps_epi32, _mm_loadu_ps,
    _mm_max_ps, _mm_min_ps, _mm_mul_ps, _mm_set1_epi32, _mm_set1_ps, _mm_setzero_ps, _mm_storeu_ps,
    _mm_sub_ps, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtepi32_ps, _mm256_cvtps_epi32,
    _mm256_loadu_ps, _mm256_max_ps, _mm256_min_ps, _mm256_mul_ps, _mm256_set1_epi32,
    _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};

use super::config::BinaryKind;

#[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn vsAdd(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsSub(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsMul(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsDiv(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsExp(n: i32, a: *const f32, y: *mut f32);
    fn vsSqrt(n: i32, a: *const f32, y: *mut f32);
    fn vsLn(n: i32, a: *const f32, y: *mut f32);
}

#[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn armpl_svexp_f32(n: i32, x: *const f32, y: *mut f32);
    fn armpl_svadd_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svsub_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svmul_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svlog_f32(n: i32, x: *const f32, y: *mut f32);
    fn armpl_svsqrt_f32(n: i32, x: *const f32, y: *mut f32);
}

#[cfg(target_os = "macos")]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn vvexpf(result: *mut f32, input: *const f32, count: *const i32);
    // vDSP strides are `vDSP_Stride` (C `long`, i.e. i64 on LP64 macOS) and the
    // element count is `vDSP_Length` (C `unsigned long`, i.e. u64); declaring
    // them as 32-bit would leave garbage in the upper register halves.
    fn vDSP_vadd(
        __A: *const f32,
        __IA: i64,
        __B: *const f32,
        __IB: i64,
        __C: *mut f32,
        __IC: i64,
        __N: u64,
    );
    fn vDSP_vsub(
        __B: *const f32,
        __IB: i64,
        __A: *const f32,
        __IA: i64,
        __C: *mut f32,
        __IC: i64,
        __N: u64,
    );
    fn vDSP_vmul(
        __A: *const f32,
        __IA: i64,
        __B: *const f32,
        __IB: i64,
        __C: *mut f32,
        __IC: i64,
        __N: u64,
    );
}

// ===========================================================================
// ReLU dispatch
// ===========================================================================

#[allow(unsafe_code)]
#[inline]
pub fn relu_slice_dispatch(values: &mut [f32]) {
    if cfg!(miri) {
        // SAFETY: scalar path only reads/writes within `values` bounds.
        unsafe {
            relu_slice_scalar(values);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_slice_avx(values);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_slice_sse(values);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_slice_neon(values);
            }
            return;
        }
    }

    // SAFETY: scalar path only reads/writes within `values` bounds.
    unsafe {
        relu_slice_scalar(values);
    }
}
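
// Minimal usage sketch for the dispatcher above (illustrative test added for
// exposition; the module name is not part of the original file). Whichever
// SIMD path runtime detection selects, negative lanes clamp to zero and the
// rest pass through; nine elements exercise the 8-wide unroll plus the tail.
#[cfg(test)]
mod relu_dispatch_example {
    use super::relu_slice_dispatch;

    #[test]
    fn clamps_negatives_to_zero() {
        let mut v = vec![-2.0f32, -0.5, 0.0, 0.5, 2.0, -1.0, 3.0, -4.0, 7.0];
        relu_slice_dispatch(&mut v);
        assert_eq!(v, vec![0.0, 0.0, 0.0, 0.5, 2.0, 0.0, 3.0, 0.0, 7.0]);
    }
}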

/// Two-argument ReLU: `output[i] = max(0, input[i])`.
///
/// Avoids the clone-then-in-place pattern by reading from `input` and writing
/// to `output` in a single pass, halving memory traffic.
#[allow(unsafe_code)]
#[inline]
pub fn relu_to_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        // SAFETY: scalar path only reads/writes within bounds.
        unsafe {
            relu_to_slice_scalar(input, output);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_to_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_to_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_to_slice_neon(input, output);
            }
            return;
        }
    }

    // SAFETY: scalar path only reads/writes within bounds.
    unsafe {
        relu_to_slice_scalar(input, output);
    }
}

#[inline]
#[allow(dead_code)]
pub(crate) fn sigmoid_slice(values: &mut [f32]) {
    for value in values {
        *value = sigmoid_scalar(*value);
    }
}

/// Numerically stable scalar sigmoid: uses `1/(1+exp(-x))` for `x >= 0` and
/// the algebraically equivalent `exp(x)/(1+exp(x))` for `x < 0`, so `exp`
/// never sees a large positive argument and cannot overflow.
#[inline]
pub(crate) fn sigmoid_scalar(value: f32) -> f32 {
    if value >= 0.0 {
        let z = (-value).exp();
        1.0 / (1.0 + z)
    } else {
        let z = value.exp();
        z / (1.0 + z)
    }
}
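
// A small check of the stable-sigmoid identities (illustrative test added for
// exposition): sigmoid(-x) = 1 - sigmoid(x), with clean saturation at the
// extremes of f32's exp range instead of overflow or NaN.
#[cfg(test)]
mod sigmoid_scalar_example {
    use super::sigmoid_scalar;

    #[test]
    fn symmetric_and_stable_at_extremes() {
        for &x in &[0.0f32, 0.5, 4.0, 30.0, 100.0] {
            let s = sigmoid_scalar(x);
            assert!((sigmoid_scalar(-x) - (1.0 - s)).abs() < 1e-6);
        }
        assert_eq!(sigmoid_scalar(1000.0), 1.0);
        assert_eq!(sigmoid_scalar(-1000.0), 0.0);
    }
}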

// ===========================================================================
// Exp / Sigmoid / Tanh SIMD dispatch
// ===========================================================================

/// Fast exp approximation applied element-wise: `output[i] = exp(input[i])`.
///
/// Prefers vendor libraries (Accelerate, MKL, ARMPL) when available. The
/// fallback SIMD kernels use range reduction (`exp(x) = 2^n * exp(r)`) with a
/// degree-6 Taylor polynomial on `[-ln2/2, ln2/2]`, accurate to roughly 1e-6
/// relative error over the clamped range [-88, 88].
#[allow(unsafe_code, unreachable_code)]
#[inline]
pub fn exp_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        exp_slice_scalar(input, output);
        return;
    }

    // macOS aarch64: use Apple Accelerate vvexpf (heavily optimized).
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        let count = input.len() as i32;
        // SAFETY: vvexpf reads `count` floats from `input` and writes to `output`.
        // Both slices have equal length (debug_assert above).
        unsafe {
            vvexpf(output.as_mut_ptr(), input.as_ptr(), &count);
        }
        return;
    }

    // x86/x86_64 with MKL: use Intel VML vsExp (heavily optimized).
    #[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let count = input.len() as i32;
        // SAFETY: vsExp reads `count` floats from `input` and writes to `output`.
        unsafe { vsExp(count, input.as_ptr(), output.as_mut_ptr()) };
        return;
    }

    // aarch64 Linux with ARMPL: use ARM Performance Libraries vectorized exp.
    #[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
    {
        let count = input.len() as i32;
        // SAFETY: armpl_svexp_f32 reads `count` floats from `input` and writes to `output`.
        unsafe { armpl_svexp_f32(count, input.as_ptr(), output.as_mut_ptr()) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // The fast-exp kernels use integer add/shift intrinsics, which are
        // SSE2/AVX2 rather than plain SSE/AVX, so detect those features.
        if std::is_x86_feature_detected!("avx2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                exp_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                exp_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                exp_slice_neon(input, output);
            }
            return;
        }
    }

    exp_slice_scalar(input, output);
}
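
// Accuracy sketch for the exp dispatcher (illustrative test added for
// exposition): whichever backend is selected (vendor library or the
// polynomial kernels), results should track libm's exp well inside the
// ~1e-6 bound documented above; 1e-4 relative is used as a loose margin.
#[cfg(test)]
mod exp_dispatch_example {
    use super::exp_slice_dispatch;

    #[test]
    fn tracks_libm_exp() {
        let input: Vec<f32> = (-40..=40).map(|i| i as f32 * 0.1).collect();
        let mut output = vec![0.0f32; input.len()];
        exp_slice_dispatch(&input, &mut output);
        for (&x, &y) in input.iter().zip(&output) {
            let rel = (y - x.exp()).abs() / x.exp();
            assert!(rel < 1e-4, "x={x}: got {y}, want {}", x.exp());
        }
    }
}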

/// Fused subtract-and-exp: `output[i] = exp(input[i] - offset)`.
///
/// Combines the max-subtraction and exp steps of softmax into one pass,
/// avoiding an extra read/write of the output buffer.
#[allow(unsafe_code)]
#[inline]
pub fn sub_exp_slice_dispatch(input: &[f32], offset: f32, output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        sub_exp_slice_scalar(input, offset, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // The fast-exp kernels use SSE2/AVX2 integer intrinsics, so detect
        // those features rather than plain SSE/AVX.
        if std::is_x86_feature_detected!("avx2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sub_exp_slice_avx(input, offset, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sub_exp_slice_sse(input, offset, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sub_exp_slice_neon(input, offset, output);
            }
            return;
        }
    }

    sub_exp_slice_scalar(input, offset, output);
}
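
// How these kernels compose into a numerically stable softmax. A sketch added
// for exposition; `softmax_example` is a hypothetical helper, not part of the
// original file: subtract the running max, exponentiate in one fused pass,
// then normalize by the sum.
#[allow(dead_code)]
fn softmax_example(logits: &[f32], probs: &mut [f32]) {
    debug_assert_eq!(logits.len(), probs.len());
    let max = max_reduce_dispatch(logits); // stability: exp(x - max) <= 1
    sub_exp_slice_dispatch(logits, max, probs); // probs[i] = exp(logits[i] - max)
    let sum = add_reduce_dispatch(probs);
    mul_scalar_inplace_dispatch(probs, 1.0 / sum); // normalize in place
}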

/// Fast sigmoid applied element-wise: `output[i] = 1 / (1 + exp(-input[i]))`.
#[allow(unsafe_code, clippy::needless_return)]
#[inline]
pub fn sigmoid_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        sigmoid_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // The sigmoid kernels build on the fast-exp helpers, which use
        // SSE2/AVX2 integer intrinsics.
        if std::is_x86_feature_detected!("avx2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sigmoid_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sigmoid_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sigmoid_slice_neon(input, output);
            }
            return;
        }
    }

    sigmoid_slice_dispatch_scalar(input, output);
}

/// Fast exp for sigmoid: range reduction + 3-term Horner + IEEE bit trick.
/// WHY 3 terms: a 3rd-order polynomial suffices for sigmoid (1/(1+exp) dampens error); max error ~1e-4.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn fast_exp_sigmoid_neon(x: float32x4_t) -> float32x4_t {
    use std::arch::aarch64::{
        vaddq_s32, vcvtnq_s32_f32, vcvtq_f32_s32, vdupq_n_s32, vreinterpretq_f32_s32, vshlq_n_s32,
    };
    let x = vmaxq_f32(vdupq_n_f32(-88.0), vminq_f32(vdupq_n_f32(88.0), x));
    let n_f = vmulq_f32(x, vdupq_n_f32(std::f32::consts::LOG2_E));
    let n_i = vcvtnq_s32_f32(n_f);
    let r = vsubq_f32(
        x,
        vmulq_f32(vcvtq_f32_s32(n_i), vdupq_n_f32(std::f32::consts::LN_2)),
    );
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(n_i, vdupq_n_s32(127))));
    let p = vfmaq_f32(vdupq_n_f32(0.5), r, vdupq_n_f32(1.0 / 6.0));
    let p = vfmaq_f32(vdupq_n_f32(1.0), r, p);
    vmulq_f32(vfmaq_f32(vdupq_n_f32(1.0), r, p), pow2n)
}

/// Sigmoid via hand-scheduled NEON assembly.
///
/// The main loop is 4x unrolled (16 elements per iteration) with interleaved
/// load/compute/store; independent operations overlap to keep the vector
/// pipeline busy.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
unsafe fn sigmoid_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let mut inp = input.as_ptr();
    let mut out = output.as_mut_ptr();
    let mut remaining = len;

    // Load all constants ONCE before the loop, keep in NEON registers
    if remaining >= 4 {
        unsafe {
            // Constants on stack for ld1r broadcast
            let c_neg88: f32 = -88.0;
            let c_pos88: f32 = 88.0;
            // Schraudolph 1999 constants: exp(x) ≈ reinterpret(int(x * C) + B)
            // C = 2^23 / ln(2) ≈ 12102203.16, B = 127 * 2^23 = 1065353216
            // WHY: 2^23/ln(2) maps x/ln(2) into the IEEE 754 exponent field; 127*2^23 adds the exponent bias.
            let c_schr_c: f32 = 12102203.0; // 2^23 / ln(2)
            let c_schr_b: i32 = 127 << 23; // 1065353216 as integer
            let c_sixth: f32 = 1.0 / 6.0;
            let c_half: f32 = 0.5;
            let c_one: f32 = 1.0;
            let c_127: i32 = 127;

            // Load constants into NEON registers (stays there for entire loop)
            std::arch::asm!(
                "ld1r {{v16.4s}}, [{p_neg88}]",
                "ld1r {{v17.4s}}, [{p_pos88}]",
                "ld1r {{v18.4s}}, [{p_schr_c}]",   // Schraudolph C (float)
                "dup  v19.4s, {p_schr_b:w}",        // Schraudolph B (integer 127<<23)
                "ld1r {{v20.4s}}, [{p_sixth}]",
                "ld1r {{v21.4s}}, [{p_half}]",
                "ld1r {{v22.4s}}, [{p_one}]",
                "dup  v23.4s, {p_127:w}",
                p_neg88 = in(reg) &c_neg88,
                p_pos88 = in(reg) &c_pos88,
                p_schr_c = in(reg) &c_schr_c,
                p_schr_b = in(reg) c_schr_b,
                p_sixth = in(reg) &c_sixth,
                p_half = in(reg) &c_half,
                p_one = in(reg) &c_one,
                p_127 = in(reg) c_127,
                out("v16") _, out("v17") _, out("v18") _, out("v19") _,
                out("v20") _, out("v21") _, out("v22") _, out("v23") _,
            );

            // Schraudolph bit-trick: exp(x) ≈ reinterpret_f32(int(x * 2^23/ln2) + 127<<23)
            // Proper integer arithmetic: fcvtzs to get int, then add bias as int, then reinterpret
            // 4x unrolled, 16 elements per iteration
            while remaining >= 16 {
                std::arch::asm!(
                    "ldp q0, q1, [{inp}]",
                    "ldp q2, q3, [{inp}, #32]",
                    "add {inp}, {inp}, #64",
                    "fneg v0.4s, v0.4s",
                    "fneg v1.4s, v1.4s",
                    "fneg v2.4s, v2.4s",
                    "fneg v3.4s, v3.4s",
                    "fmax v0.4s, v0.4s, v16.4s",
                    "fmax v1.4s, v1.4s, v16.4s",
                    "fmax v2.4s, v2.4s, v16.4s",
                    "fmax v3.4s, v3.4s, v16.4s",
                    "fmin v0.4s, v0.4s, v17.4s",
                    "fmin v1.4s, v1.4s, v17.4s",
                    "fmin v2.4s, v2.4s, v17.4s",
                    "fmin v3.4s, v3.4s, v17.4s",
                    // x * (2^23/ln2) → convert to int
                    "fmul v0.4s, v0.4s, v18.4s",
                    "fmul v1.4s, v1.4s, v18.4s",
                    "fmul v2.4s, v2.4s, v18.4s",
                    "fmul v3.4s, v3.4s, v18.4s",
                    "fcvtzs v0.4s, v0.4s",
                    "fcvtzs v1.4s, v1.4s",
                    "fcvtzs v2.4s, v2.4s",
                    "fcvtzs v3.4s, v3.4s",
                    // + 127*2^23 (integer add)
                    "add v0.4s, v0.4s, v19.4s",
                    "add v1.4s, v1.4s, v19.4s",
                    "add v2.4s, v2.4s, v19.4s",
                    "add v3.4s, v3.4s, v19.4s",
                    // v0-v3 bits ARE exp(-x) when reinterpreted as float
                    // sigmoid = 1 / (1 + exp)
                    "fadd v0.4s, v22.4s, v0.4s",
                    "fadd v1.4s, v22.4s, v1.4s",
                    "fadd v2.4s, v22.4s, v2.4s",
                    "fadd v3.4s, v22.4s, v3.4s",
                    "fdiv v0.4s, v22.4s, v0.4s",
                    "fdiv v1.4s, v22.4s, v1.4s",
                    "fdiv v2.4s, v22.4s, v2.4s",
                    "fdiv v3.4s, v22.4s, v3.4s",
                    "stp q0, q1, [{out}]",
                    "stp q2, q3, [{out}, #32]",
                    "add {out}, {out}, #64",
                    inp = inout(reg) inp,
                    out = inout(reg) out,
                    out("v0") _, out("v1") _, out("v2") _, out("v3") _,
                );
                remaining -= 16;
            }
            // 4-element tail (Schraudolph)
            while remaining >= 4 {
                std::arch::asm!(
                    "ld1 {{v0.4s}}, [{inp}], #16",
                    "fneg v0.4s, v0.4s",
                    "fmax v0.4s, v0.4s, v16.4s",
                    "fmin v0.4s, v0.4s, v17.4s",
                    "fmul v0.4s, v0.4s, v18.4s",
                    "fcvtzs v0.4s, v0.4s",
                    "add v0.4s, v0.4s, v19.4s",
                    "fadd v0.4s, v22.4s, v0.4s",
                    "fdiv v0.4s, v22.4s, v0.4s",
                    "st1 {{v0.4s}}, [{out}], #16",
                    inp = inout(reg) inp,
                    out = inout(reg) out,
                    out("v0") _,
                );
                remaining -= 4;
            }
        }
    }

    // Scalar tail
    for i in 0..remaining {
        unsafe {
            let x = *inp.add(i);
            *out.add(i) = 1.0 / (1.0 + (-x).exp());
        }
    }
}
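
// Scalar reference for the Schraudolph bit trick used in the assembly above
// (helper added for exposition; not called by the kernels). It mirrors the
// fmul / fcvtzs / add sequence exactly: multiplying by 2^23/ln(2) places
// x/ln(2) in the IEEE 754 exponent field once the integer bias 127<<23 is
// added, so the resulting bit pattern reinterprets to approximately exp(x).
#[allow(dead_code)]
fn schraudolph_exp_reference(x: f32) -> f32 {
    let clamped = x.clamp(-88.0, 88.0);
    let scaled = clamped * 12102203.0; // 2^23 / ln(2)
    // `as i32` truncates toward zero, matching fcvtzs.
    f32::from_bits((scaled as i32).wrapping_add(127 << 23) as u32)
}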

// (sigmoid_vdsp and silu_vdsp removed: benchmarked slower than the NEON polynomial)

/// Fast tanh applied element-wise: `output[i] = tanh(input[i])`.
///
/// Computed as `2 * sigmoid(2x) - 1`.
#[allow(unsafe_code)]
#[inline]
pub fn tanh_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        tanh_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // The tanh kernels build on the fast-exp helpers, which use SSE2/AVX2
        // integer intrinsics.
        if std::is_x86_feature_detected!("avx2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                tanh_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                tanh_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                tanh_slice_neon(input, output);
            }
            return;
        }
    }

    tanh_slice_dispatch_scalar(input, output);
}
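
// The tanh kernels rely on the identity tanh(x) = 2 * sigmoid(2x) - 1. A
// scalar sketch of that identity (illustrative test added for exposition),
// checked against libm using the exact `sigmoid_scalar` above:
#[cfg(test)]
mod tanh_identity_example {
    use super::sigmoid_scalar;

    #[test]
    fn identity_matches_libm() {
        for i in -30..=30 {
            let x = i as f32 * 0.2;
            let via_sigmoid = 2.0 * sigmoid_scalar(2.0 * x) - 1.0;
            assert!((via_sigmoid - x.tanh()).abs() < 1e-5);
        }
    }
}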

/// Fused SiLU (Swish) applied element-wise: `output[i] = input[i] * sigmoid(input[i])`.
///
/// A single pass over the data avoids the 2x bandwidth penalty of separate sigmoid + multiply.
#[allow(unsafe_code)]
#[inline]
pub fn silu_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        silu_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                silu_slice_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // The SiLU kernels build on the fast-exp helpers, which use SSE2/AVX2
        // integer intrinsics.
        if std::is_x86_feature_detected!("avx2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { silu_slice_avx(input, output) };
            return;
        }
        if std::is_x86_feature_detected!("sse2") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { silu_slice_sse(input, output) };
            return;
        }
    }

    silu_slice_dispatch_scalar(input, output);
}

// ===========================================================================
// Reduction dispatchers: max_reduce, add_reduce
// ===========================================================================

/// Find the maximum value in `data`. Returns `f32::NEG_INFINITY` for empty slices.
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn max_reduce_dispatch(data: &[f32]) -> f32 {
    if data.is_empty() {
        return f32::NEG_INFINITY;
    }

    if cfg!(miri) {
        return max_reduce_scalar(data);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { max_reduce_avx(data) };
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { max_reduce_sse(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { max_reduce_neon(data) };
        }
    }

    max_reduce_scalar(data)
}

/// Sum all values in `data`. Returns `0.0` for empty slices.
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn add_reduce_dispatch(data: &[f32]) -> f32 {
    if data.is_empty() {
        return 0.0;
    }

    if cfg!(miri) {
        return add_reduce_scalar(data);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { add_reduce_avx(data) };
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { add_reduce_sse(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { add_reduce_neon(data) };
        }
    }

    add_reduce_scalar(data)
}
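
// Illustrative check for the two reductions (test added for exposition),
// including the documented empty-slice identities; 17 elements exercise the
// SIMD body plus the scalar tail, and the integer values keep f32 sums exact
// regardless of summation order.
#[cfg(test)]
mod reduce_dispatch_example {
    use super::{add_reduce_dispatch, max_reduce_dispatch};

    #[test]
    fn reduces_like_iterators() {
        let data: Vec<f32> = (1..=17).map(|i| i as f32).collect();
        assert_eq!(max_reduce_dispatch(&data), 17.0);
        assert_eq!(add_reduce_dispatch(&data), 153.0); // 17 * 18 / 2
        assert_eq!(max_reduce_dispatch(&[]), f32::NEG_INFINITY);
        assert_eq!(add_reduce_dispatch(&[]), 0.0);
    }
}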

// ===========================================================================
// Scalar-broadcast multiply in-place
// ===========================================================================

/// Multiply every element of `data` by `scalar` in-place.
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn mul_scalar_inplace_dispatch(data: &mut [f32], scalar: f32) {
    if cfg!(miri) || data.is_empty() {
        for v in data.iter_mut() {
            *v *= scalar;
        }
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { mul_scalar_inplace_neon(data, scalar) };
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { mul_scalar_inplace_avx(data, scalar) };
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { mul_scalar_inplace_sse(data, scalar) };
            return;
        }
    }

    for v in data.iter_mut() {
        *v *= scalar;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn mul_scalar_inplace_neon(data: &mut [f32], scalar: f32) {
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = vdupq_n_f32(scalar);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = vld1q_f32(ptr.add(i));
        vst1q_f32(ptr.add(i), vmulq_f32(v, vs));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn mul_scalar_inplace_avx(data: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = _mm256_set1_ps(scalar);
    let mut i = 0usize;
    while i + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(i));
        _mm256_storeu_ps(ptr.add(i), _mm256_mul_ps(v, vs));
        i += 8;
    }
    // SSE tail
    let vs4 = _mm_set1_ps(scalar);
    while i + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        _mm_storeu_ps(ptr.add(i), _mm_mul_ps(v, vs4));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn mul_scalar_inplace_sse(data: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = _mm_set1_ps(scalar);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        _mm_storeu_ps(ptr.add(i), _mm_mul_ps(v, vs));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

// ===========================================================================
// FMA dispatch (conv2d inner loop helper)
// ===========================================================================

/// Fused multiply-accumulate: `acc[i] += a[i] * b[i]`.
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn fma_slice_dispatch(a: &[f32], b: &[f32], acc: &mut [f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), acc.len());

    if cfg!(miri) {
        // SAFETY: scalar path only reads/writes within equal-sized slice bounds.
        unsafe {
            fma_slice_scalar(a, b, acc);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                fma_slice_avx(a, b, acc);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                fma_slice_sse(a, b, acc);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                fma_slice_neon(a, b, acc);
            }
            return;
        }
    }

    // SAFETY: scalar path only reads/writes within equal-sized slice bounds.
    unsafe {
        fma_slice_scalar(a, b, acc);
    }
}
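
// Illustrative check for the FMA helper (test added for exposition): the
// accumulator is updated in place, acc[i] += a[i] * b[i]; the small integer
// values make the expected results exact in f32.
#[cfg(test)]
mod fma_dispatch_example {
    use super::fma_slice_dispatch;

    #[test]
    fn accumulates_products() {
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = [10.0f32, 10.0, 10.0, 10.0, 10.0];
        let mut acc = [1.0f32; 5];
        fma_slice_dispatch(&a, &b, &mut acc);
        assert_eq!(acc, [11.0, 21.0, 31.0, 41.0, 51.0]);
    }
}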

// ===========================================================================
// Binary same-shape dispatch (existing)
// ===========================================================================

#[allow(unsafe_code, unreachable_code)]
#[inline]
pub fn binary_same_shape_dispatch(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    debug_assert_eq!(lhs.len(), rhs.len());
    debug_assert_eq!(lhs.len(), out.len());

    if cfg!(miri) {
        // SAFETY: scalar path only reads/writes within equal-sized slice bounds.
        unsafe {
            binary_same_shape_scalar(lhs, rhs, out, kind);
        }
        return;
    }

    // macOS: use vDSP for add/sub/mul (heavily optimized, zero loop overhead).
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        let n = lhs.len() as u64; // vDSP_Length
        // SAFETY: vDSP functions read/write `n` floats from contiguous slices.
        unsafe {
            match kind {
                BinaryKind::Add => {
                    vDSP_vadd(lhs.as_ptr(), 1, rhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
                // NOTE: vDSP_vsub takes the subtrahend B first and computes C = A - B,
                // so `rhs` is passed in the first slot to produce lhs - rhs.
                BinaryKind::Sub => {
                    vDSP_vsub(rhs.as_ptr(), 1, lhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
                BinaryKind::Mul => {
                    vDSP_vmul(lhs.as_ptr(), 1, rhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
            }
        }
        return;
    }

    // x86/x86_64 with MKL: use Intel VML for add/sub/mul (heavily optimized).
    #[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let n = lhs.len() as i32;
        // SAFETY: VML functions read `n` floats from contiguous slices and write to `out`.
        unsafe {
            match kind {
                BinaryKind::Add => vsAdd(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Sub => vsSub(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Mul => vsMul(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
            }
        }
        return;
    }

    // aarch64 Linux with ARMPL: use ARM Performance Libraries for add/sub/mul.
    #[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
    {
        let n = lhs.len() as i32;
        // SAFETY: ARMPL functions read `n` floats from contiguous slices and write to `out`.
        unsafe {
            match kind {
                BinaryKind::Add => armpl_svadd_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Sub => armpl_svsub_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Mul => armpl_svmul_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
            }
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                binary_same_shape_avx(lhs, rhs, out, kind);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                binary_same_shape_sse(lhs, rhs, out, kind);
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                binary_same_shape_neon(lhs, rhs, out, kind);
            }
            return;
        }
    }

    // SAFETY: scalar path only reads/writes within equal-sized slice bounds.
    unsafe {
        binary_same_shape_scalar(lhs, rhs, out, kind);
    }
}
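
// Illustrative check for the binary dispatcher (test added for exposition).
// Sub is the case worth pinning down, since the vDSP backend takes its
// subtrahend first: the result must always be lhs - rhs on every backend.
#[cfg(test)]
mod binary_dispatch_example {
    use super::{binary_same_shape_dispatch, BinaryKind};

    #[test]
    fn sub_is_lhs_minus_rhs() {
        let lhs = [10.0f32, 20.0, 30.0, 40.0, 50.0];
        let rhs = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mut out = [0.0f32; 5];
        binary_same_shape_dispatch(&lhs, &rhs, &mut out, BinaryKind::Sub);
        assert_eq!(out, [9.0, 18.0, 27.0, 36.0, 45.0]);
    }
}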

// ===========================================================================
// Scalar fallbacks
// ===========================================================================

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn relu_slice_scalar(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        let v0 = *ptr.add(index);
        let v1 = *ptr.add(index + 1);
        let v2 = *ptr.add(index + 2);
        let v3 = *ptr.add(index + 3);
        let v4 = *ptr.add(index + 4);
        let v5 = *ptr.add(index + 5);
        let v6 = *ptr.add(index + 6);
        let v7 = *ptr.add(index + 7);
        *ptr.add(index) = v0.max(0.0);
        *ptr.add(index + 1) = v1.max(0.0);
        *ptr.add(index + 2) = v2.max(0.0);
        *ptr.add(index + 3) = v3.max(0.0);
        *ptr.add(index + 4) = v4.max(0.0);
        *ptr.add(index + 5) = v5.max(0.0);
        *ptr.add(index + 6) = v6.max(0.0);
        *ptr.add(index + 7) = v7.max(0.0);
        index += 8;
    }

    while index < len {
        *ptr.add(index) = (*ptr.add(index)).max(0.0);
        index += 1;
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn relu_to_slice_scalar(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        *out_ptr.add(index) = (*in_ptr.add(index)).max(0.0);
        *out_ptr.add(index + 1) = (*in_ptr.add(index + 1)).max(0.0);
        *out_ptr.add(index + 2) = (*in_ptr.add(index + 2)).max(0.0);
        *out_ptr.add(index + 3) = (*in_ptr.add(index + 3)).max(0.0);
        *out_ptr.add(index + 4) = (*in_ptr.add(index + 4)).max(0.0);
        *out_ptr.add(index + 5) = (*in_ptr.add(index + 5)).max(0.0);
        *out_ptr.add(index + 6) = (*in_ptr.add(index + 6)).max(0.0);
        *out_ptr.add(index + 7) = (*in_ptr.add(index + 7)).max(0.0);
        index += 8;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).max(0.0);
        index += 1;
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn binary_same_shape_scalar(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    match kind {
        BinaryKind::Add => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) + *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) + *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) + *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) + *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) + *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) + *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) + *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) + *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) + *right_ptr.add(index);
                index += 1;
            }
        }
        BinaryKind::Sub => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) - *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) - *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) - *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) - *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) - *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) - *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) - *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) - *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) - *right_ptr.add(index);
                index += 1;
            }
        }
        BinaryKind::Mul => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) * *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) * *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) * *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) * *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) * *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) * *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) * *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) * *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) * *right_ptr.add(index);
                index += 1;
            }
        }
    }
}

fn exp_slice_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v.exp();
    }
}

fn sub_exp_slice_scalar(input: &[f32], offset: f32, output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = (v - offset).exp();
    }
}

fn sigmoid_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = sigmoid_scalar(v);
    }
}

fn tanh_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v.tanh();
    }
}

fn silu_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let s = 1.0 / (1.0 + (-v).exp());
        *o = v * s;
    }
}

#[allow(dead_code)]
fn max_reduce_scalar(data: &[f32]) -> f32 {
    let mut acc = f32::NEG_INFINITY;
    for &v in data {
        acc = acc.max(v);
    }
    acc
}

#[allow(dead_code)]
fn add_reduce_scalar(data: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for &v in data {
        acc += v;
    }
    acc
}

#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fma_slice_scalar(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        *acc_ptr.add(index) += *a_ptr.add(index) * *b_ptr.add(index);
        *acc_ptr.add(index + 1) += *a_ptr.add(index + 1) * *b_ptr.add(index + 1);
        *acc_ptr.add(index + 2) += *a_ptr.add(index + 2) * *b_ptr.add(index + 2);
        *acc_ptr.add(index + 3) += *a_ptr.add(index + 3) * *b_ptr.add(index + 3);
        index += 4;
    }
    while index < len {
        *acc_ptr.add(index) += *a_ptr.add(index) * *b_ptr.add(index);
        index += 1;
    }
}

// ===========================================================================
// SSE fast-exp helper (4-wide)
// ===========================================================================
//
// Uses the classic range-reduction approach:
//   exp(x) = 2^n * exp(r)  where  n = round(x / ln2), r = x - n*ln2
// Then exp(r) is approximated with a degree-6 polynomial on [-ln2/2, ln2/2].

/// Schraudolph 1999 bit-trick exp for SSE: exp(x) ≈ reinterpret(int(x * 2^23/ln2) + 127*2^23).
/// WHY: ~3x faster than the polynomial; ~1e-3 accuracy is sufficient for sigmoid/tanh where 1/(1+exp) dampens error.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn fast_exp_bittrick_sse(x: __m128) -> __m128 {
    // The integer intrinsics below are SSE2 (always present on x86_64; the
    // dispatchers detect sse2 at runtime on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{_mm_add_epi32, _mm_cvtps_epi32, _mm_set1_epi32};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{_mm_add_epi32, _mm_cvtps_epi32, _mm_set1_epi32};
    // exp(x) ≈ reinterpret(int(x * C) + B) where C = 2^23/ln2, B = 127*2^23
    let scale = _mm_set1_ps(12102203.0); // WHY: 2^23/ln(2) maps x/ln(2) into the IEEE 754 exponent field
    let offset = _mm_set1_epi32(1065353216); // WHY: 127*2^23 is the IEEE 754 exponent bias in integer form
    let clamp_lo = _mm_set1_ps(-87.0); // WHY: below this exp() underflows into denormals
    let clamp_hi = _mm_set1_ps(88.0); // WHY: above this exp() overflows f32 to infinity
    let x_clamped = _mm_max_ps(_mm_min_ps(x, clamp_hi), clamp_lo);
    let val = _mm_cvtps_epi32(_mm_mul_ps(x_clamped, scale));
    _mm_castsi128_ps(_mm_add_epi32(val, offset))
}

/// Polynomial exp for SSE: range reduction + degree-6 Taylor polynomial. Higher accuracy (~1e-6)
/// for standalone exp (softmax, etc.) where precision matters more.
/// WHY 6 terms: the Taylor series for exp(r) on [-ln2/2, ln2/2] through r^6/720 keeps the
/// max error well below 1e-6, a good accuracy/speed tradeoff.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse2")]
unsafe fn fast_exp_sse(x: __m128) -> __m128 {
    let ln2_inv = _mm_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm_set1_ps(0.693_359_4); // upper bits of ln(2)
    let ln2_lo = _mm_set1_ps(-2.121_944_4e-4); // lower bits of ln(2)

    // Polynomial coefficients (Taylor series for exp(r) on [-ln2/2, ln2/2])
    let c0 = _mm_set1_ps(1.0);
    let c1 = _mm_set1_ps(1.0);
    let c2 = _mm_set1_ps(0.5);
    let c3 = _mm_set1_ps(1.0 / 6.0);
    let c4 = _mm_set1_ps(1.0 / 24.0);
    let c5 = _mm_set1_ps(1.0 / 120.0);
    let c6 = _mm_set1_ps(1.0 / 720.0);

    // Clamp input to prevent overflow/underflow
    let x = _mm_max_ps(_mm_set1_ps(-88.0), _mm_min_ps(_mm_set1_ps(88.0), x));

    // n = round(x / ln2)
    let n_f = _mm_mul_ps(x, ln2_inv);
    // Round to nearest integer using convert (rounds to nearest by default)
    let n_i = _mm_cvtps_epi32(n_f);
    let n_f = _mm_cvtepi32_ps(n_i);

    // r = x - n * ln2  (two-step for accuracy)
    let r = _mm_sub_ps(
        _mm_sub_ps(x, _mm_mul_ps(n_f, ln2_hi)),
        _mm_mul_ps(n_f, ln2_lo),
    );

    // Polynomial: c0 + r*(c1 + r*(c2 + r*(c3 + r*(c4 + r*(c5 + r*c6)))))
    let mut poly = _mm_add_ps(c5, _mm_mul_ps(r, c6));
    poly = _mm_add_ps(c4, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c1, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c0, _mm_mul_ps(r, poly));

    // Multiply by 2^n using bit manipulation: reinterpret (n + 127) << 23 as f32.
    // _mm_add_epi32 and _mm_slli_epi32 are SSE2, covered by this function's target feature.
    let pow2n = {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::{_mm_add_epi32, _mm_slli_epi32};
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::{_mm_add_epi32, _mm_slli_epi32};
        let bias = _mm_set1_epi32(127);
        _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(n_i, bias), 23))
    };

    _mm_mul_ps(poly, pow2n)
}
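
// Scalar reference for the range-reduction scheme above (helper added for
// exposition; not called by the kernels). It mirrors the SSE/AVX/NEON
// versions step by step: split x = n*ln(2) + r, evaluate the degree-6 Taylor
// polynomial at r, then scale by 2^n via the IEEE 754 exponent field.
#[allow(dead_code)]
fn fast_exp_reference(x: f32) -> f32 {
    let x = x.clamp(-88.0, 88.0);
    // Round-to-nearest-even matches the SIMD converts (cvtps_epi32 / vcvtnq).
    let n = (x * std::f32::consts::LOG2_E).round_ties_even() as i32;
    let n_f = n as f32;
    // Two-step subtraction with a hi/lo split of ln(2) keeps r accurate.
    let r = (x - n_f * 0.693_359_4) - n_f * -2.121_944_4e-4;
    // Same Horner chain as the SIMD kernels: c6 down to c0.
    let poly = 1.0 / 720.0;
    let poly = 1.0 / 120.0 + r * poly;
    let poly = 1.0 / 24.0 + r * poly;
    let poly = 1.0 / 6.0 + r * poly;
    let poly = 0.5 + r * poly;
    let poly = 1.0 + r * poly;
    let poly = 1.0 + r * poly;
    // 2^n built directly in the exponent field: (n + 127) << 23.
    poly * f32::from_bits(((n + 127) << 23) as u32)
}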

// ===========================================================================
// AVX fast-exp helper (8-wide)
// ===========================================================================

/// Schraudolph 1999 bit-trick exp for AVX: exp(x) ≈ reinterpret(int(x * 2^23/ln2) + 127*2^23).
/// WHY: ~3x faster than the polynomial; ~1e-3 accuracy is sufficient for sigmoid/tanh where 1/(1+exp) dampens error.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn fast_exp_bittrick_avx(x: __m256) -> __m256 {
    // _mm256_add_epi32 is an AVX2 integer intrinsic, hence the avx2 target
    // feature (the 256-bit float ops are plain AVX, which avx2 implies).
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{_mm256_add_epi32, _mm256_cvtps_epi32, _mm256_set1_epi32};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{_mm256_add_epi32, _mm256_cvtps_epi32, _mm256_set1_epi32};
    let scale = _mm256_set1_ps(12102203.0); // WHY: 2^23/ln(2) maps x/ln(2) into the IEEE 754 exponent field
    let offset = _mm256_set1_epi32(1065353216); // WHY: 127*2^23 is the IEEE 754 exponent bias in integer form
    let clamp_lo = _mm256_set1_ps(-87.0); // WHY: below this exp() underflows into denormals
    let clamp_hi = _mm256_set1_ps(88.0); // WHY: above this exp() overflows f32 to infinity
    let x_clamped = _mm256_max_ps(_mm256_min_ps(x, clamp_hi), clamp_lo);
    let val = _mm256_cvtps_epi32(_mm256_mul_ps(x_clamped, scale));
    _mm256_castsi256_ps(_mm256_add_epi32(val, offset))
}

/// Polynomial exp for AVX: range reduction + degree-6 Taylor polynomial. Higher accuracy (~1e-6)
/// for standalone exp (softmax, etc.) where precision matters more.
/// WHY 6 terms: the Taylor series for exp(r) on [-ln2/2, ln2/2] through r^6/720 keeps the
/// max error well below 1e-6, a good accuracy/speed tradeoff.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx2")]
unsafe fn fast_exp_avx(x: __m256) -> __m256 {
    let ln2_inv = _mm256_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm256_set1_ps(0.693_359_4);
    let ln2_lo = _mm256_set1_ps(-2.121_944_4e-4);

    let c0 = _mm256_set1_ps(1.0);
    let c1 = _mm256_set1_ps(1.0);
    let c2 = _mm256_set1_ps(0.5);
    let c3 = _mm256_set1_ps(1.0 / 6.0);
    let c4 = _mm256_set1_ps(1.0 / 24.0);
    let c5 = _mm256_set1_ps(1.0 / 120.0);
    let c6 = _mm256_set1_ps(1.0 / 720.0);

    let x = _mm256_max_ps(
        _mm256_set1_ps(-88.0),
        _mm256_min_ps(_mm256_set1_ps(88.0), x),
    );

    let n_f = _mm256_mul_ps(x, ln2_inv);
    let n_i = _mm256_cvtps_epi32(n_f);
    let n_f = _mm256_cvtepi32_ps(n_i);

    let r = _mm256_sub_ps(
        _mm256_sub_ps(x, _mm256_mul_ps(n_f, ln2_hi)),
        _mm256_mul_ps(n_f, ln2_lo),
    );

    let mut poly = _mm256_add_ps(c5, _mm256_mul_ps(r, c6));
    poly = _mm256_add_ps(c4, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c3, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c2, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c1, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c0, _mm256_mul_ps(r, poly));

    let bias = _mm256_set1_epi32(127);
    let pow2n = {
        // _mm256_add_epi32 and _mm256_slli_epi32 are AVX2, covered by this
        // function's target feature.
        #[cfg(target_arch = "x86")]
        use std::arch::x86::{_mm256_add_epi32, _mm256_slli_epi32};
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::{_mm256_add_epi32, _mm256_slli_epi32};
        _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_add_epi32(n_i, bias), 23))
    };

    _mm256_mul_ps(poly, pow2n)
}

// ===========================================================================
// NEON fast-exp helper (4-wide)
// ===========================================================================

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn fast_exp_neon(x: float32x4_t) -> float32x4_t {
    use std::arch::aarch64::{
        vaddq_s32, vcvtnq_s32_f32, vcvtq_f32_s32, vdupq_n_s32, vreinterpretq_f32_s32, vshlq_n_s32,
    };

    let ln2_inv = vdupq_n_f32(std::f32::consts::LOG2_E);
    let ln2_hi = vdupq_n_f32(0.693_359_4);
    let ln2_lo = vdupq_n_f32(-2.121_944_4e-4);

    let c0 = vdupq_n_f32(1.0);
    let c1 = vdupq_n_f32(1.0);
    let c2 = vdupq_n_f32(0.5);
    let c3 = vdupq_n_f32(1.0 / 6.0);
    let c4 = vdupq_n_f32(1.0 / 24.0);
    let c5 = vdupq_n_f32(1.0 / 120.0);
    let c6 = vdupq_n_f32(1.0 / 720.0);

    let x = vmaxq_f32(vdupq_n_f32(-88.0), vminq_f32(vdupq_n_f32(88.0), x));

    let n_f = vmulq_f32(x, ln2_inv);
    let n_i = vcvtnq_s32_f32(n_f);
    let n_f = vcvtq_f32_s32(n_i);

    let r = vsubq_f32(vsubq_f32(x, vmulq_f32(n_f, ln2_hi)), vmulq_f32(n_f, ln2_lo));

    let mut poly = vfmaq_f32(c5, r, c6);
    poly = vfmaq_f32(c4, r, poly);
    poly = vfmaq_f32(c3, r, poly);
    poly = vfmaq_f32(c2, r, poly);
    poly = vfmaq_f32(c1, r, poly);
    poly = vfmaq_f32(c0, r, poly);

    let bias = vdupq_n_s32(127);
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(n_i, bias)));

    vmulq_f32(poly, pow2n)
}

// ===========================================================================
// Exp slice implementations
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse2")]
unsafe fn exp_slice_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let r = fast_exp_sse(v);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).exp();
        index += 1;
    }
}

// avx2 rather than avx: fast_exp_avx relies on AVX2 integer intrinsics for
// the 2^n scaling.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx2")]
unsafe fn exp_slice_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    // 2x unrolled: process 16 floats per iteration to hide FMA latency.
    while index + 16 <= len {
        // Prefetch next cacheline (64 bytes = 16 floats ahead)
        #[cfg(target_arch = "x86")]
        {
            use std::arch::x86::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        let v0 = _mm256_loadu_ps(in_ptr.add(index));
        let v1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let r0 = fast_exp_avx(v0);
        let r1 = fast_exp_avx(v1);
        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        index += 16;
    }

    // Handle remaining 8-float chunk
    while index + 8 <= len {
        let v = _mm256_loadu_ps(in_ptr.add(index));
        let r = fast_exp_avx(v);
        _mm256_storeu_ps(out_ptr.add(index), r);
        index += 8;
    }

    if index < len {
        exp_slice_sse(&input[index..], &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn exp_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = vld1q_f32(in_ptr.add(index));
        let r = fast_exp_neon(v);
        vst1q_f32(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).exp();
        index += 1;
    }
}
1489
1490// ===========================================================================
1491// Fused subtract-and-exp: output[i] = exp(input[i] - offset)
1492// ===========================================================================
1493
1494#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1495#[allow(unsafe_code)]
1496#[allow(unsafe_op_in_unsafe_fn)]
1497#[target_feature(enable = "sse")]
1498unsafe fn sub_exp_slice_sse(input: &[f32], offset: f32, output: &mut [f32]) {
1499    let len = input.len();
1500    let in_ptr = input.as_ptr();
1501    let out_ptr = output.as_mut_ptr();
1502    let off = _mm_set1_ps(offset);
1503    let mut index = 0usize;
1504
1505    while index + 4 <= len {
1506        let v = _mm_loadu_ps(in_ptr.add(index));
1507        let shifted = _mm_sub_ps(v, off);
1508        let r = fast_exp_sse(shifted);
1509        _mm_storeu_ps(out_ptr.add(index), r);
1510        index += 4;
1511    }
1512
1513    while index < len {
1514        *out_ptr.add(index) = (*in_ptr.add(index) - offset).exp();
1515        index += 1;
1516    }
1517}
1518
1519#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1520#[allow(unsafe_code)]
1521#[allow(unsafe_op_in_unsafe_fn)]
1522#[target_feature(enable = "avx")]
1523unsafe fn sub_exp_slice_avx(input: &[f32], offset: f32, output: &mut [f32]) {
1524    let len = input.len();
1525    let in_ptr = input.as_ptr();
1526    let out_ptr = output.as_mut_ptr();
1527    let off = _mm256_set1_ps(offset);
1528    let mut index = 0usize;
1529
1530    // 2x unrolled: process 16 floats per iteration to hide FMA latency.
1531    while index + 16 <= len {
1532        #[cfg(target_arch = "x86")]
1533        {
1534            use std::arch::x86::_mm_prefetch;
1535            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
1536        }
1537        #[cfg(target_arch = "x86_64")]
1538        {
1539            use std::arch::x86_64::_mm_prefetch;
1540            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
1541        }
1542        let v0 = _mm256_loadu_ps(in_ptr.add(index));
1543        let v1 = _mm256_loadu_ps(in_ptr.add(index + 8));
1544        let shifted0 = _mm256_sub_ps(v0, off);
1545        let shifted1 = _mm256_sub_ps(v1, off);
1546        let r0 = fast_exp_avx(shifted0);
1547        let r1 = fast_exp_avx(shifted1);
1548        _mm256_storeu_ps(out_ptr.add(index), r0);
1549        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
1550        index += 16;
1551    }
1552
1553    // Handle remaining 8-float chunk
1554    while index + 8 <= len {
1555        let v = _mm256_loadu_ps(in_ptr.add(index));
1556        let shifted = _mm256_sub_ps(v, off);
1557        let r = fast_exp_avx(shifted);
1558        _mm256_storeu_ps(out_ptr.add(index), r);
1559        index += 8;
1560    }
1561
1562    if index < len {
1563        sub_exp_slice_sse(&input[index..], offset, &mut output[index..]);
1564    }
1565}
1566
1567#[cfg(target_arch = "aarch64")]
1568#[allow(unsafe_code)]
1569#[allow(unsafe_op_in_unsafe_fn)]
1570#[target_feature(enable = "neon")]
1571unsafe fn sub_exp_slice_neon(input: &[f32], offset: f32, output: &mut [f32]) {
1572    let len = input.len();
1573    let in_ptr = input.as_ptr();
1574    let out_ptr = output.as_mut_ptr();
1575    let off = vdupq_n_f32(offset);
1576    let mut index = 0usize;
1577
1578    while index + 4 <= len {
1579        let v = vld1q_f32(in_ptr.add(index));
1580        let shifted = vsubq_f32(v, off);
1581        let r = fast_exp_neon(shifted);
1582        vst1q_f32(out_ptr.add(index), r);
1583        index += 4;
1584    }
1585
1586    while index < len {
1587        *out_ptr.add(index) = (*in_ptr.add(index) - offset).exp();
1588        index += 1;
1589    }
1590}
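
// Hedged sketch of why the fused subtract exists: with `offset` set to the
// row max (as softmax does), every exponent argument is <= 0, so the largest
// intermediate is exp(0) = 1 and f32 never overflows, even when the raw
// inputs would.
#[cfg(test)]
mod sub_exp_offset_sketch {
    #[test]
    fn offset_keeps_exp_finite() {
        let logits = [100.0f32, 99.0, 98.0];
        assert!(logits[0].exp().is_infinite()); // raw exp(100) overflows f32
        let max = logits.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v));
        assert!(logits.iter().all(|&v| (v - max).exp().is_finite()));
    }
}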

// ===========================================================================
// Sigmoid slice implementations: sigmoid(x) = 1 / (1 + exp(-x))
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn sigmoid_slice_sse(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    // Process 16 elements per iteration (4 SSE registers)
    while index + 16 <= len {
        let x0 = _mm_loadu_ps(in_ptr.add(index));
        let x1 = _mm_loadu_ps(in_ptr.add(index + 4));
        let x2 = _mm_loadu_ps(in_ptr.add(index + 8));
        let x3 = _mm_loadu_ps(in_ptr.add(index + 12));

        // Bit-trick exp is sufficient for sigmoid (output clamped 0-1, errors wash out)
        let e0 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x0));
        let e1 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x1));
        let e2 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x2));
        let e3 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x3));

        let r0 = _mm_div_ps(one, _mm_add_ps(one, e0));
        let r1 = _mm_div_ps(one, _mm_add_ps(one, e1));
        let r2 = _mm_div_ps(one, _mm_add_ps(one, e2));
        let r3 = _mm_div_ps(one, _mm_add_ps(one, e3));

        _mm_storeu_ps(out_ptr.add(index), r0);
        _mm_storeu_ps(out_ptr.add(index + 4), r1);
        _mm_storeu_ps(out_ptr.add(index + 8), r2);
        _mm_storeu_ps(out_ptr.add(index + 12), r3);

        index += 16;
    }

    // Remaining 4 at a time
    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let neg_x = _mm_sub_ps(zero, x);
        let exp_neg_x = fast_exp_bittrick_sse(neg_x);
        let denom = _mm_add_ps(one, exp_neg_x);
        let result = _mm_div_ps(one, denom);
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = sigmoid_scalar(*in_ptr.add(index));
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn sigmoid_slice_avx(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm256_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm256_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    // Process 32 elements per iteration (4 AVX registers)
    while index + 32 <= len {
        let x0 = _mm256_loadu_ps(in_ptr.add(index));
        let x1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let x2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let x3 = _mm256_loadu_ps(in_ptr.add(index + 24));

        // Use Schraudolph bit-trick exp for ~3x speedup over polynomial
        let e0 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x0));
        let e1 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x1));
        let e2 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x2));
        let e3 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x3));

        let r0 = _mm256_div_ps(one, _mm256_add_ps(one, e0));
        let r1 = _mm256_div_ps(one, _mm256_add_ps(one, e1));
        let r2 = _mm256_div_ps(one, _mm256_add_ps(one, e2));
        let r3 = _mm256_div_ps(one, _mm256_add_ps(one, e3));

        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        _mm256_storeu_ps(out_ptr.add(index + 16), r2);
        _mm256_storeu_ps(out_ptr.add(index + 24), r3);

        index += 32;
    }

    // Remaining 8 at a time
    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let neg_x = _mm256_sub_ps(zero, x);
        let exp_neg_x = fast_exp_bittrick_avx(neg_x);
        let denom = _mm256_add_ps(one, exp_neg_x);
        let result = _mm256_div_ps(one, denom);
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        sigmoid_slice_sse(&input[index..], &mut output[index..]);
    }
}
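
// Hedged sketch of the Schraudolph-style bit-trick the sigmoid kernels lean
// on: exp(x) ~= f32::from_bits((A*x + B) as u32) with A = 2^23 / ln(2) and
// B = 127 * 2^23 minus a small error-balancing offset. The constants below
// are the commonly published ones, not necessarily those inside
// fast_exp_bittrick_*; expect a few percent of relative error, which
// sigmoid's bounded output tolerates.
#[cfg(test)]
mod bittrick_exp_sketch {
    fn schraudolph_exp(x: f32) -> f32 {
        // A = 2^23 / ln(2); B = 127 * 2^23 - 486_411 (assumed constants).
        let bits = (12_102_203.0_f32 * x + 1_064_866_805.0) as i32;
        f32::from_bits(bits as u32)
    }

    #[test]
    fn bittrick_exp_is_roughly_right() {
        for i in -50..=50 {
            let x = i as f32 * 0.1;
            let approx = schraudolph_exp(x);
            let exact = x.exp();
            assert!((approx - exact).abs() <= 0.06 * exact);
        }
    }
}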

// (sigmoid_slice_neon defined above at line ~291)

// ===========================================================================
// Tanh slice implementations: tanh(x) = 2 * sigmoid(2x) - 1
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn tanh_slice_sse(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm_div_ps;
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = _mm_set1_ps(2.0);
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let two_x = _mm_mul_ps(two, x);
        // sigmoid(2x) = 1 / (1 + exp(-2x))
        let neg_two_x = _mm_sub_ps(zero, two_x);
        // Use polynomial exp (not bit-trick) for tanh — needs ~1e-4 accuracy
        let exp_neg = fast_exp_sse(neg_two_x);
        let sig = _mm_div_ps(one, _mm_add_ps(one, exp_neg));
        // tanh = 2 * sig - 1
        let result = _mm_sub_ps(_mm_mul_ps(two, sig), one);
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).tanh();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn tanh_slice_avx(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm256_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm256_div_ps;
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = _mm256_set1_ps(2.0);
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let two_x = _mm256_mul_ps(two, x);
        let neg_two_x = _mm256_sub_ps(zero, two_x);
        // Use polynomial exp (not bit-trick) for tanh — needs ~1e-4 accuracy
        let exp_neg = fast_exp_avx(neg_two_x);
        let sig = _mm256_div_ps(one, _mm256_add_ps(one, exp_neg));
        let result = _mm256_sub_ps(_mm256_mul_ps(two, sig), one);
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        tanh_slice_sse(&input[index..], &mut output[index..]);
    }
}

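// Hedged identity check behind the tanh kernels: tanh(x) = 2*sigmoid(2x) - 1
// holds exactly in real arithmetic, so a scalar reference built that way only
// differs from libm's tanh by rounding.
#[cfg(test)]
mod tanh_identity_sketch {
    #[test]
    fn sigmoid_identity_matches_tanh() {
        for i in -50..=50 {
            let x = i as f32 * 0.1;
            let via_sigmoid = 2.0 / (1.0 + (-2.0 * x).exp()) - 1.0;
            assert!((via_sigmoid - x.tanh()).abs() < 1e-5);
        }
    }
}
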
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn tanh_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = vdupq_n_f32(2.0);
    let one = vdupq_n_f32(1.0);
    let mut index = 0usize;

    // 8x unrolled: 32 elements per iteration, using fast 3-term exp polynomial
    while index + 32 <= len {
        let x0 = vld1q_f32(in_ptr.add(index));
        let x1 = vld1q_f32(in_ptr.add(index + 4));
        let x2 = vld1q_f32(in_ptr.add(index + 8));
        let x3 = vld1q_f32(in_ptr.add(index + 12));
        let x4 = vld1q_f32(in_ptr.add(index + 16));
        let x5 = vld1q_f32(in_ptr.add(index + 20));
        let x6 = vld1q_f32(in_ptr.add(index + 24));
        let x7 = vld1q_f32(in_ptr.add(index + 28));

        // exp(-2x) using fast 3-term polynomial (sufficient for tanh)
        let e0 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x0)));
        let e1 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x1)));
        let e2 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x2)));
        let e3 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x3)));
        let e4 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x4)));
        let e5 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x5)));
        let e6 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x6)));
        let e7 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x7)));

        // tanh(x) = 2 * sigmoid(2x) - 1 = 2/(1+exp(-2x)) - 1
        vst1q_f32(
            out_ptr.add(index),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e0)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 4),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e1)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 8),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e2)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 12),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e3)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 16),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e4)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 20),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e5)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 24),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e6)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 28),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e7)), one),
        );
        index += 32;
    }

    while index + 4 <= len {
        let x = vld1q_f32(in_ptr.add(index));
        let two_x = vmulq_f32(two, x);
        let neg_two_x = vnegq_f32(two_x);
        let exp_neg = fast_exp_sigmoid_neon(neg_two_x);
        let denom = vaddq_f32(one, exp_neg);
        let result = vsubq_f32(vdivq_f32(two, denom), one);
        vst1q_f32(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).tanh();
        index += 1;
    }
}

/// Fused SiLU: output[i] = x * sigmoid(x) in a single pass.
/// 8x unrolled with fast 3-term exp polynomial.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn silu_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = vdupq_n_f32(1.0);
    let mut index = 0usize;

    // 8x unrolled: 32 elements per iteration
    while index + 32 <= len {
        let x0 = vld1q_f32(in_ptr.add(index));
        let x1 = vld1q_f32(in_ptr.add(index + 4));
        let x2 = vld1q_f32(in_ptr.add(index + 8));
        let x3 = vld1q_f32(in_ptr.add(index + 12));
        let x4 = vld1q_f32(in_ptr.add(index + 16));
        let x5 = vld1q_f32(in_ptr.add(index + 20));
        let x6 = vld1q_f32(in_ptr.add(index + 24));
        let x7 = vld1q_f32(in_ptr.add(index + 28));

        // sigmoid(x) = 1 / (1 + exp(-x))
        let e0 = fast_exp_sigmoid_neon(vnegq_f32(x0));
        let e1 = fast_exp_sigmoid_neon(vnegq_f32(x1));
        let e2 = fast_exp_sigmoid_neon(vnegq_f32(x2));
        let e3 = fast_exp_sigmoid_neon(vnegq_f32(x3));
        let e4 = fast_exp_sigmoid_neon(vnegq_f32(x4));
        let e5 = fast_exp_sigmoid_neon(vnegq_f32(x5));
        let e6 = fast_exp_sigmoid_neon(vnegq_f32(x6));
        let e7 = fast_exp_sigmoid_neon(vnegq_f32(x7));

        // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
        vst1q_f32(
            out_ptr.add(index),
            vmulq_f32(x0, vdivq_f32(one, vaddq_f32(one, e0))),
        );
        vst1q_f32(
            out_ptr.add(index + 4),
            vmulq_f32(x1, vdivq_f32(one, vaddq_f32(one, e1))),
        );
        vst1q_f32(
            out_ptr.add(index + 8),
            vmulq_f32(x2, vdivq_f32(one, vaddq_f32(one, e2))),
        );
        vst1q_f32(
            out_ptr.add(index + 12),
            vmulq_f32(x3, vdivq_f32(one, vaddq_f32(one, e3))),
        );
        vst1q_f32(
            out_ptr.add(index + 16),
            vmulq_f32(x4, vdivq_f32(one, vaddq_f32(one, e4))),
        );
        vst1q_f32(
            out_ptr.add(index + 20),
            vmulq_f32(x5, vdivq_f32(one, vaddq_f32(one, e5))),
        );
        vst1q_f32(
            out_ptr.add(index + 24),
            vmulq_f32(x6, vdivq_f32(one, vaddq_f32(one, e6))),
        );
        vst1q_f32(
            out_ptr.add(index + 28),
            vmulq_f32(x7, vdivq_f32(one, vaddq_f32(one, e7))),
        );
        index += 32;
    }

    while index + 4 <= len {
        let x = vld1q_f32(in_ptr.add(index));
        let e = fast_exp_sigmoid_neon(vnegq_f32(x));
        let sig = vdivq_f32(one, vaddq_f32(one, e));
        vst1q_f32(out_ptr.add(index), vmulq_f32(x, sig));
        index += 4;
    }

    while index < len {
        let x = *in_ptr.add(index);
        let s = 1.0 / (1.0 + (-x).exp());
        *out_ptr.add(index) = x * s;
        index += 1;
    }
}

/// Fused SiLU (x * sigmoid(x)) using SSE.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn silu_slice_sse(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 16 <= len {
        let x0 = _mm_loadu_ps(in_ptr.add(index));
        let x1 = _mm_loadu_ps(in_ptr.add(index + 4));
        let x2 = _mm_loadu_ps(in_ptr.add(index + 8));
        let x3 = _mm_loadu_ps(in_ptr.add(index + 12));

        // exp(-x) via the polynomial fast exp (fast_exp_sse)
        let e0 = fast_exp_sse(_mm_sub_ps(zero, x0));
        let e1 = fast_exp_sse(_mm_sub_ps(zero, x1));
        let e2 = fast_exp_sse(_mm_sub_ps(zero, x2));
        let e3 = fast_exp_sse(_mm_sub_ps(zero, x3));

        // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
        _mm_storeu_ps(
            out_ptr.add(index),
            _mm_mul_ps(x0, _mm_div_ps(one, _mm_add_ps(one, e0))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 4),
            _mm_mul_ps(x1, _mm_div_ps(one, _mm_add_ps(one, e1))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 8),
            _mm_mul_ps(x2, _mm_div_ps(one, _mm_add_ps(one, e2))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 12),
            _mm_mul_ps(x3, _mm_div_ps(one, _mm_add_ps(one, e3))),
        );

        index += 16;
    }

    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let e = fast_exp_sse(_mm_sub_ps(zero, x));
        let sig = _mm_div_ps(one, _mm_add_ps(one, e));
        _mm_storeu_ps(out_ptr.add(index), _mm_mul_ps(x, sig));
        index += 4;
    }

    while index < len {
        let v = *in_ptr.add(index);
        let s = 1.0 / (1.0 + (-v).exp());
        *out_ptr.add(index) = v * s;
        index += 1;
    }
}

/// Fused SiLU (x * sigmoid(x)) using AVX.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn silu_slice_avx(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm256_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm256_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = _mm256_loadu_ps(in_ptr.add(index));
        let x1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let x2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let x3 = _mm256_loadu_ps(in_ptr.add(index + 24));

        // exp(-x) via the polynomial fast exp (fast_exp_avx)
        let e0 = fast_exp_avx(_mm256_sub_ps(zero, x0));
        let e1 = fast_exp_avx(_mm256_sub_ps(zero, x1));
        let e2 = fast_exp_avx(_mm256_sub_ps(zero, x2));
        let e3 = fast_exp_avx(_mm256_sub_ps(zero, x3));

        // silu(x) = x / (1 + exp(-x))
        _mm256_storeu_ps(
            out_ptr.add(index),
            _mm256_mul_ps(x0, _mm256_div_ps(one, _mm256_add_ps(one, e0))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 8),
            _mm256_mul_ps(x1, _mm256_div_ps(one, _mm256_add_ps(one, e1))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 16),
            _mm256_mul_ps(x2, _mm256_div_ps(one, _mm256_add_ps(one, e2))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 24),
            _mm256_mul_ps(x3, _mm256_div_ps(one, _mm256_add_ps(one, e3))),
        );

        index += 32;
    }

    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let e = fast_exp_avx(_mm256_sub_ps(zero, x));
        let sig = _mm256_div_ps(one, _mm256_add_ps(one, e));
        _mm256_storeu_ps(out_ptr.add(index), _mm256_mul_ps(x, sig));
        index += 8;
    }

    if index < len {
        silu_slice_sse(&input[index..], &mut output[index..]);
    }
}
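
// Hedged reference properties for SiLU = x * sigmoid(x): it passes through
// the origin exactly and approaches the identity for large positive inputs;
// useful when eyeballing tolerances for the fast-exp variants above.
#[cfg(test)]
mod silu_reference_sketch {
    #[test]
    fn silu_scalar_reference_properties() {
        let silu = |x: f32| x / (1.0 + (-x).exp());
        assert_eq!(silu(0.0), 0.0);
        assert!((silu(10.0) - 10.0).abs() < 1e-3); // saturates toward x
        assert!(silu(-10.0).abs() < 1e-3); // decays toward 0
    }
}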

// ===========================================================================
// Max-reduce implementations
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn max_reduce_sse(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm_set1_ps(f32::NEG_INFINITY);

    while index + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(index));
        acc = _mm_max_ps(acc, v);
        index += 4;
    }

    // Horizontal max of 4-lane accumulator
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0].max(buf[1]).max(buf[2]).max(buf[3]);

    while index < len {
        result = result.max(*ptr.add(index));
        index += 1;
    }
    result
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn max_reduce_avx(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);

    while index + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(index));
        acc = _mm256_max_ps(acc, v);
        index += 8;
    }

    // Horizontal max of 8-lane accumulator
    let mut buf = [0.0f32; 8];
    _mm256_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0];
    for i in 1..8 {
        result = result.max(buf[i]);
    }

    while index < len {
        result = result.max(*ptr.add(index));
        index += 1;
    }
    result
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn max_reduce_neon(data: &[f32]) -> f32 {
    use std::arch::aarch64::vmaxvq_f32;

    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = vdupq_n_f32(f32::NEG_INFINITY);

    while index + 4 <= len {
        let v = vld1q_f32(ptr.add(index));
        acc = vmaxq_f32(acc, v);
        index += 4;
    }

    let mut result = vmaxvq_f32(acc);
    while index < len {
        result = result.max(*ptr.add(index));
        index += 1;
    }
    result
}
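
// Hedged sanity sketch: for finite inputs the SSE reduction should agree
// bit-for-bit with a scalar fold, since max never rounds. NaN inputs are
// assumed absent; _mm_max_ps and f32::max disagree on NaN propagation.
#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))]
mod max_reduce_sketch {
    #[test]
    fn sse_max_matches_scalar_fold() {
        if !std::is_x86_feature_detected!("sse") {
            return;
        }
        let data: Vec<f32> = (0..23).map(|i| ((i * 37) % 23) as f32 - 11.0).collect();
        // SAFETY: guarded by runtime feature detection.
        let got = unsafe { super::max_reduce_sse(&data) };
        let expected = data.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v));
        assert_eq!(got, expected);
    }
}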

// ===========================================================================
// Add-reduce (sum) implementations
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn add_reduce_sse(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm_setzero_ps();

    while index + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(index));
        acc = _mm_add_ps(acc, v);
        index += 4;
    }

    // Horizontal sum of 4-lane accumulator
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0] + buf[1] + buf[2] + buf[3];

    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn add_reduce_avx(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm256_setzero_ps();

    while index + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(index));
        acc = _mm256_add_ps(acc, v);
        index += 8;
    }

    let mut buf = [0.0f32; 8];
    _mm256_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0] + buf[1] + buf[2] + buf[3] + buf[4] + buf[5] + buf[6] + buf[7];

    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn add_reduce_neon(data: &[f32]) -> f32 {
    use std::arch::aarch64::vaddvq_f32;

    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = vdupq_n_f32(0.0);

    while index + 4 <= len {
        let v = vld1q_f32(ptr.add(index));
        acc = vaddq_f32(acc, v);
        index += 4;
    }

    let mut result = vaddvq_f32(acc);
    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}
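
// Hedged note: the lane-wise accumulators reassociate the sum, so SIMD and
// sequential results may differ by rounding; comparisons belong under a
// tolerance, not bit equality.
#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))]
mod add_reduce_sketch {
    #[test]
    fn sse_sum_is_close_to_scalar_sum() {
        if !std::is_x86_feature_detected!("sse") {
            return;
        }
        let data: Vec<f32> = (0..101).map(|i| (i as f32).sin()).collect();
        // SAFETY: guarded by runtime feature detection.
        let got = unsafe { super::add_reduce_sse(&data) };
        let expected: f32 = data.iter().sum();
        assert!((got - expected).abs() < 1e-3);
    }
}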

// ===========================================================================
// FMA slice implementations: acc[i] += a[i] * b[i]
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn fma_slice_sse(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let av = _mm_loadu_ps(a_ptr.add(index));
        let bv = _mm_loadu_ps(b_ptr.add(index));
        let cv = _mm_loadu_ps(acc_ptr.add(index));
        let result = _mm_add_ps(cv, _mm_mul_ps(av, bv));
        _mm_storeu_ps(acc_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        fma_slice_scalar(&a[index..], &b[index..], &mut acc[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn fma_slice_avx(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        let av = _mm256_loadu_ps(a_ptr.add(index));
        let bv = _mm256_loadu_ps(b_ptr.add(index));
        let cv = _mm256_loadu_ps(acc_ptr.add(index));
        let result = _mm256_add_ps(cv, _mm256_mul_ps(av, bv));
        _mm256_storeu_ps(acc_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        fma_slice_sse(&a[index..], &b[index..], &mut acc[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn fma_slice_neon(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let av = vld1q_f32(a_ptr.add(index));
        let bv = vld1q_f32(b_ptr.add(index));
        let cv = vld1q_f32(acc_ptr.add(index));
        let result = vfmaq_f32(cv, av, bv);
        vst1q_f32(acc_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        fma_slice_scalar(&a[index..], &b[index..], &mut acc[index..]);
    }
}
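
// Hedged rounding note: vfmaq_f32 above is a true fused multiply-add (one
// rounding), while the SSE/AVX paths use mul + add (two roundings), so the
// architectures can disagree in the last bits. A minimal scalar
// demonstration of that difference, using std's correctly rounded mul_add:
#[cfg(test)]
mod fma_rounding_sketch {
    #[test]
    fn fused_and_unfused_can_differ() {
        let a = 1.0f32 + f32::EPSILON;
        let b = 1.0f32 - f32::EPSILON;
        let unfused = a * b - 1.0; // a*b rounds to 1.0, the tiny term is lost
        let fused = a.mul_add(b, -1.0); // keeps the exact -EPSILON^2 term
        assert_eq!(unfused, 0.0);
        assert!(fused != 0.0 && fused.abs() < 1e-13);
    }
}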

// ===========================================================================
// ReLU SIMD implementations (existing)
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn relu_slice_sse(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let input = _mm_loadu_ps(ptr.add(index));
        let out = _mm_max_ps(input, zero);
        _mm_storeu_ps(ptr.add(index), out);
        index += 4;
    }

    if index < len {
        relu_slice_scalar(&mut values[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn relu_slice_avx(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    // 4× unrolled: 32 elements per iteration
    while index + 32 <= len {
        let v0 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index)), zero);
        let v1 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 8)), zero);
        let v2 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 16)), zero);
        let v3 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 24)), zero);
        _mm256_storeu_ps(ptr.add(index), v0);
        _mm256_storeu_ps(ptr.add(index + 8), v1);
        _mm256_storeu_ps(ptr.add(index + 16), v2);
        _mm256_storeu_ps(ptr.add(index + 24), v3);
        index += 32;
    }

    while index + 8 <= len {
        _mm256_storeu_ps(
            ptr.add(index),
            _mm256_max_ps(_mm256_loadu_ps(ptr.add(index)), zero),
        );
        index += 8;
    }

    if index < len {
        relu_slice_sse(&mut values[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn binary_same_shape_sse(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let left = _mm_loadu_ps(left_ptr.add(index));
        let right = _mm_loadu_ps(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => _mm_add_ps(left, right),
            BinaryKind::Sub => _mm_sub_ps(left, right),
            BinaryKind::Mul => _mm_mul_ps(left, right),
        };
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        binary_same_shape_scalar(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn binary_same_shape_avx(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    // 4x unrolled: process 32 floats per iteration with software prefetch.
    // Matches vDSP throughput by keeping the OoO pipeline fully saturated.
    match kind {
        BinaryKind::Add => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_add_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_add_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_add_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_add_ps(a3, b3));
                index += 32;
            }
        }
        BinaryKind::Sub => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_sub_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_sub_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_sub_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_sub_ps(a3, b3));
                index += 32;
            }
        }
        BinaryKind::Mul => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_mul_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_mul_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_mul_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_mul_ps(a3, b3));
                index += 32;
            }
        }
    }

    // Handle remaining elements 8 at a time
    while index + 8 <= len {
        let left = _mm256_loadu_ps(left_ptr.add(index));
        let right = _mm256_loadu_ps(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => _mm256_add_ps(left, right),
            BinaryKind::Sub => _mm256_sub_ps(left, right),
            BinaryKind::Mul => _mm256_mul_ps(left, right),
        };
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        binary_same_shape_sse(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn relu_slice_neon(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = vdupq_n_f32(0.0);
    let mut index = 0usize;

    // 8× unrolled: 32 elements per iteration
    while index + 32 <= len {
        let v0 = vmaxq_f32(vld1q_f32(ptr.add(index)), zero);
        let v1 = vmaxq_f32(vld1q_f32(ptr.add(index + 4)), zero);
        let v2 = vmaxq_f32(vld1q_f32(ptr.add(index + 8)), zero);
        let v3 = vmaxq_f32(vld1q_f32(ptr.add(index + 12)), zero);
        let v4 = vmaxq_f32(vld1q_f32(ptr.add(index + 16)), zero);
        let v5 = vmaxq_f32(vld1q_f32(ptr.add(index + 20)), zero);
        let v6 = vmaxq_f32(vld1q_f32(ptr.add(index + 24)), zero);
        let v7 = vmaxq_f32(vld1q_f32(ptr.add(index + 28)), zero);
        vst1q_f32(ptr.add(index), v0);
        vst1q_f32(ptr.add(index + 4), v1);
        vst1q_f32(ptr.add(index + 8), v2);
        vst1q_f32(ptr.add(index + 12), v3);
        vst1q_f32(ptr.add(index + 16), v4);
        vst1q_f32(ptr.add(index + 20), v5);
        vst1q_f32(ptr.add(index + 24), v6);
        vst1q_f32(ptr.add(index + 28), v7);
        index += 32;
    }

    while index + 4 <= len {
        vst1q_f32(ptr.add(index), vmaxq_f32(vld1q_f32(ptr.add(index)), zero));
        index += 4;
    }

    if index < len {
        relu_slice_scalar(&mut values[index..]);
    }
}
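
// Hedged end-to-end check: ReLU is exact (a max against zero), so whichever
// SIMD path the dispatcher picks must match the scalar definition
// bit-for-bit; 37 elements exercises both the unrolled loop and the tail.
#[cfg(test)]
mod relu_dispatch_sketch {
    #[test]
    fn relu_matches_scalar_max() {
        let mut values: Vec<f32> = (0..37).map(|i| i as f32 - 18.0).collect();
        let expected: Vec<f32> = values.iter().map(|v| v.max(0.0)).collect();
        super::relu_slice_dispatch(&mut values);
        assert_eq!(values, expected);
    }
}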

// ===========================================================================
// Two-argument ReLU SIMD implementations (input -> output)
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn relu_to_slice_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let r = _mm_max_ps(v, zero);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    if index < len {
        relu_to_slice_scalar(&input[index..], &mut output[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn relu_to_slice_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    // 4× unrolled: 32 elements per iteration (matches NEON unrolling)
    while index + 32 <= len {
        let a0 = _mm256_loadu_ps(in_ptr.add(index));
        let a1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let a2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let a3 = _mm256_loadu_ps(in_ptr.add(index + 24));
        _mm256_storeu_ps(out_ptr.add(index), _mm256_max_ps(a0, zero));
        _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_max_ps(a1, zero));
        _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_max_ps(a2, zero));
        _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_max_ps(a3, zero));
        index += 32;
    }

    while index + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(index),
            _mm256_max_ps(_mm256_loadu_ps(in_ptr.add(index)), zero),
        );
        index += 8;
    }

    if index < len {
        relu_to_slice_sse(&input[index..], &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn relu_to_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = vdupq_n_f32(0.0);
    let mut index = 0usize;

    // 8× unrolled with interleaved load/compute/store for better OoO pipelining
    while index + 32 <= len {
        let a0 = vld1q_f32(in_ptr.add(index));
        let a1 = vld1q_f32(in_ptr.add(index + 4));
        let a2 = vld1q_f32(in_ptr.add(index + 8));
        let a3 = vld1q_f32(in_ptr.add(index + 12));
        vst1q_f32(out_ptr.add(index), vmaxq_f32(a0, zero));
        vst1q_f32(out_ptr.add(index + 4), vmaxq_f32(a1, zero));
        let a4 = vld1q_f32(in_ptr.add(index + 16));
        let a5 = vld1q_f32(in_ptr.add(index + 20));
        vst1q_f32(out_ptr.add(index + 8), vmaxq_f32(a2, zero));
        vst1q_f32(out_ptr.add(index + 12), vmaxq_f32(a3, zero));
        let a6 = vld1q_f32(in_ptr.add(index + 24));
        let a7 = vld1q_f32(in_ptr.add(index + 28));
        vst1q_f32(out_ptr.add(index + 16), vmaxq_f32(a4, zero));
        vst1q_f32(out_ptr.add(index + 20), vmaxq_f32(a5, zero));
        vst1q_f32(out_ptr.add(index + 24), vmaxq_f32(a6, zero));
        vst1q_f32(out_ptr.add(index + 28), vmaxq_f32(a7, zero));
        index += 32;
    }

    while index + 4 <= len {
        vst1q_f32(
            out_ptr.add(index),
            vmaxq_f32(vld1q_f32(in_ptr.add(index)), zero),
        );
        index += 4;
    }

    if index < len {
        relu_to_slice_scalar(&input[index..], &mut output[index..]);
    }
}

#[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn binary_same_shape_neon(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let left = vld1q_f32(left_ptr.add(index));
        let right = vld1q_f32(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => vaddq_f32(left, right),
            BinaryKind::Sub => vsubq_f32(left, right),
            BinaryKind::Mul => vmulq_f32(left, right),
        };
        vst1q_f32(out_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        binary_same_shape_scalar(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}
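
// Hedged sketch pairing each BinaryKind with its scalar operator; SIMD
// add/sub/mul are IEEE-exact, so results must match scalar bit-for-bit.
// 19 elements exercises both the vector loop and the scalar tail.
#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))]
mod binary_same_shape_sketch {
    #[test]
    fn sse_binary_matches_scalar_ops() {
        if !std::is_x86_feature_detected!("sse") {
            return;
        }
        let lhs: Vec<f32> = (0..19).map(|i| i as f32 * 0.75).collect();
        let rhs: Vec<f32> = (0..19).map(|i| 10.0 - i as f32).collect();
        let cases: [(super::BinaryKind, fn(f32, f32) -> f32); 3] = [
            (super::BinaryKind::Add, |a, b| a + b),
            (super::BinaryKind::Sub, |a, b| a - b),
            (super::BinaryKind::Mul, |a, b| a * b),
        ];
        for (kind, op) in cases {
            let mut out = vec![0.0f32; lhs.len()];
            // SAFETY: guarded by runtime feature detection.
            unsafe { super::binary_same_shape_sse(&lhs, &rhs, &mut out, kind) };
            for i in 0..lhs.len() {
                assert_eq!(out[i], op(lhs[i], rhs[i]));
            }
        }
    }
}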

// ---------------------------------------------------------------------------
// SIMD-accelerated matmul inner loop
// ---------------------------------------------------------------------------
//
// Computes one output row of C = A * B by iterating over the shared dimension k.
// For each k, broadcasts a[row*K + k] and multiplies by the contiguous B row
// b[k*N .. k*N + N], accumulating into the output row out[0..N].
//
// The "broadcast A, contiguous B row" access pattern is SIMD-friendly because
// all loads from B are contiguous.

/// Dispatch to the best available SIMD path for a single matmul output row.
///
/// # Safety
/// - `left_row` must point to at least `k` valid f32 elements.
/// - `right` must point to at least `k * n` valid f32 elements (row-major B).
/// - `out_row` must point to at least `n` valid f32 elements.
/// - The caller must ensure no aliasing between `out_row` and the input pointers.
#[inline]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn matmul_row_dispatch(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    if cfg!(miri) {
        matmul_row_scalar(left_row, right, out_row, k, n);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            matmul_row_avx(left_row, right, out_row, k, n);
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            matmul_row_sse(left_row, right, out_row, k, n);
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            matmul_row_neon(left_row, right, out_row, k, n);
            return;
        }
    }

    matmul_row_scalar(left_row, right, out_row, k, n);
}

/// Scalar fallback: broadcast-multiply-accumulate, unrolled by 4.
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn matmul_row_scalar(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val = *left_row.add(p);
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            *out_row.add(col) += a_val * *b_row.add(col);
            *out_row.add(col + 1) += a_val * *b_row.add(col + 1);
            *out_row.add(col + 2) += a_val * *b_row.add(col + 2);
            *out_row.add(col + 3) += a_val * *b_row.add(col + 3);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += a_val * *b_row.add(col);
            col += 1;
        }
    }
}
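
// Hedged usage sketch for matmul_row_dispatch: drive it once per row of a
// small row-major C = A * B and check against a scalar triple loop. The
// shapes are illustrative; `out_row` must start zeroed because the kernel
// accumulates.
#[cfg(test)]
mod matmul_row_sketch {
    #[test]
    fn row_dispatch_matches_triple_loop() {
        let (m, k, n) = (3usize, 4usize, 5usize);
        let a: Vec<f32> = (0..m * k).map(|i| i as f32 * 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i % 7) as f32 - 3.0).collect();
        let mut c = vec![0.0f32; m * n];
        for row in 0..m {
            // SAFETY: the pointers cover `k`, `k * n`, and `n` valid f32s
            // respectively, and `c` aliases neither `a` nor `b`.
            unsafe {
                super::matmul_row_dispatch(
                    a.as_ptr().add(row * k),
                    b.as_ptr(),
                    c.as_mut_ptr().add(row * n),
                    k,
                    n,
                );
            }
        }
        for row in 0..m {
            for col in 0..n {
                let mut expected = 0.0f32;
                for p in 0..k {
                    expected += a[row * k + p] * b[p * n + col];
                }
                assert!((c[row * n + col] - expected).abs() < 1e-4);
            }
        }
    }
}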

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn matmul_row_sse(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val = _mm_set1_ps(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            let b_vec = _mm_loadu_ps(b_row.add(col));
            let out_vec = _mm_loadu_ps(out_row.add(col));
            let result = _mm_add_ps(out_vec, _mm_mul_ps(a_val, b_vec));
            _mm_storeu_ps(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn matmul_row_avx(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val_avx = _mm256_set1_ps(*left_row.add(p));
        let a_val_sse = _mm_set1_ps(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 8 <= n {
            let b_vec = _mm256_loadu_ps(b_row.add(col));
            let out_vec = _mm256_loadu_ps(out_row.add(col));
            let result = _mm256_add_ps(out_vec, _mm256_mul_ps(a_val_avx, b_vec));
            _mm256_storeu_ps(out_row.add(col), result);
            col += 8;
        }
        // Handle 4-element remainder with SSE.
        while col + 4 <= n {
            let b_vec = _mm_loadu_ps(b_row.add(col));
            let out_vec = _mm_loadu_ps(out_row.add(col));
            let result = _mm_add_ps(out_vec, _mm_mul_ps(a_val_sse, b_vec));
            _mm_storeu_ps(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn matmul_row_neon(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val: float32x4_t = vdupq_n_f32(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            let b_vec = vld1q_f32(b_row.add(col));
            let out_vec = vld1q_f32(out_row.add(col));
            let result = vfmaq_f32(out_vec, a_val, b_vec);
            vst1q_f32(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

// ===========================================================================
// Fused softmax: max + sub-exp + sum + divide in one function
// ===========================================================================

/// Fused softmax row: `out[i] = exp(input[i] - max) / sum(exp(input - max))`.
///
/// Performs all four steps (max, subtract+exp, sum, divide) inside a single
/// function so that data stays in L1 cache and dispatcher overhead is eliminated.
#[allow(unsafe_code)]
#[inline]
pub fn softmax_row_fused_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) || input.is_empty() {
        softmax_row_fused_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                softmax_row_fused_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                softmax_row_fused_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                softmax_row_fused_sse(input, output);
            }
            return;
        }
    }

    softmax_row_fused_scalar(input, output);
}
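
// Hedged property check for the fused path: softmax output is a probability
// distribution, so it should sum to ~1 regardless of which SIMD branch ran;
// the 1e-3 slack allows for the fast-exp approximation.
#[cfg(test)]
mod softmax_fused_sketch {
    #[test]
    fn fused_softmax_sums_to_one() {
        let input: Vec<f32> = (0..33).map(|i| i as f32 * 0.25 - 4.0).collect();
        let mut output = vec![0.0f32; input.len()];
        super::softmax_row_fused_dispatch(&input, &mut output);
        let sum: f32 = output.iter().sum();
        assert!((sum - 1.0).abs() < 1e-3);
        assert!(output.iter().all(|&p| (0.0..=1.0).contains(&p)));
    }
}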
2968
fn softmax_row_fused_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    // 1. max
    let mut max_val = f32::NEG_INFINITY;
    for &v in input {
        max_val = max_val.max(v);
    }

    // 2. sub+exp + 3. accumulate sum
    let mut sum_exp = 0.0f32;
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let e = (v - max_val).exp();
        *o = e;
        sum_exp += e;
    }

    // 4. divide
    let inv = 1.0 / sum_exp;
    for o in output.iter_mut() {
        *o *= inv;
    }
}
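
// Worked example for the scalar path (values rounded; shown for reference,
// not asserted anywhere): for input [1.0, 2.0, 3.0] the max is 3.0, the
// shifted exps are [e^-2, e^-1, e^0] ~ [0.1353, 0.3679, 1.0] with sum
// ~ 1.5032, so the output is ~ [0.0900, 0.2447, 0.6652].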

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn softmax_row_fused_neon(input: &[f32], output: &mut [f32]) {
    use std::arch::aarch64::{vaddvq_f32, vmaxvq_f32};

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. Find max (NEON reduce)
    let mut acc_max = vdupq_n_f32(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 16 <= len {
        let v0 = vld1q_f32(in_ptr.add(i));
        let v1 = vld1q_f32(in_ptr.add(i + 4));
        let v2 = vld1q_f32(in_ptr.add(i + 8));
        let v3 = vld1q_f32(in_ptr.add(i + 12));
        acc_max = vmaxq_f32(acc_max, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        let v = vld1q_f32(in_ptr.add(i));
        acc_max = vmaxq_f32(acc_max, v);
        i += 4;
    }
    let mut max_val = vmaxvq_f32(acc_max);
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sub+exp (NEON fast_exp, writes output) + 3. accumulate sum
    let off = vdupq_n_f32(max_val);
    let mut acc_sum = vdupq_n_f32(0.0);
    i = 0;
    while i + 16 <= len {
        let v0 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        let v1 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), off));
        let v2 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), off));
        let v3 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), off));
        vst1q_f32(out_ptr.add(i), v0);
        vst1q_f32(out_ptr.add(i + 4), v1);
        vst1q_f32(out_ptr.add(i + 8), v2);
        vst1q_f32(out_ptr.add(i + 12), v3);
        acc_sum = vaddq_f32(acc_sum, vaddq_f32(vaddq_f32(v0, v1), vaddq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        let v = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        vst1q_f32(out_ptr.add(i), v);
        acc_sum = vaddq_f32(acc_sum, v);
        i += 4;
    }
    let mut sum_exp = vaddvq_f32(acc_sum);
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // 4. divide (NEON multiply by 1/sum)
    let inv = vdupq_n_f32(1.0 / sum_exp);
    i = 0;
    while i + 16 <= len {
        vst1q_f32(out_ptr.add(i), vmulq_f32(vld1q_f32(out_ptr.add(i)), inv));
        vst1q_f32(
            out_ptr.add(i + 4),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 4)), inv),
        );
        vst1q_f32(
            out_ptr.add(i + 8),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 8)), inv),
        );
        vst1q_f32(
            out_ptr.add(i + 12),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 12)), inv),
        );
        i += 16;
    }
    while i + 4 <= len {
        vst1q_f32(out_ptr.add(i), vmulq_f32(vld1q_f32(out_ptr.add(i)), inv));
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

/// SSE fused softmax fallback.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn softmax_row_fused_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. max
    let mut acc_max = _mm_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 4 <= len {
        acc_max = _mm_max_ps(acc_max, _mm_loadu_ps(in_ptr.add(i)));
        i += 4;
    }
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc_max);
    let mut max_val = buf[0].max(buf[1]).max(buf[2].max(buf[3]));
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sub+exp + 3. sum
    let off = _mm_set1_ps(max_val);
    let mut acc_sum = _mm_setzero_ps();
    i = 0;
    while i + 4 <= len {
        let v = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off));
        _mm_storeu_ps(out_ptr.add(i), v);
        acc_sum = _mm_add_ps(acc_sum, v);
        i += 4;
    }
    _mm_storeu_ps(buf.as_mut_ptr(), acc_sum);
    let mut sum_exp = buf[0] + buf[1] + buf[2] + buf[3];
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // 4. divide
    let inv = _mm_set1_ps(1.0 / sum_exp);
    i = 0;
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_mul_ps(_mm_loadu_ps(out_ptr.add(i)), inv),
        );
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

/// AVX fused softmax; the sub-8-element tail is handled with SSE.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn softmax_row_fused_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. max
    let mut acc_max = _mm256_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 8 <= len {
        acc_max = _mm256_max_ps(acc_max, _mm256_loadu_ps(in_ptr.add(i)));
        i += 8;
    }
    let mut buf8 = [0.0f32; 8];
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_max);
    let mut max_val = buf8[0];
    for &v in &buf8[1..] {
        max_val = max_val.max(v);
    }
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sub+exp + 3. sum
    let off = _mm256_set1_ps(max_val);
    let mut acc_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let v = fast_exp_avx(_mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), off));
        _mm256_storeu_ps(out_ptr.add(i), v);
        acc_sum = _mm256_add_ps(acc_sum, v);
        i += 8;
    }
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_sum);
    let mut sum_exp: f32 = buf8.iter().sum();
    // SSE tail for remaining < 8 elements
    let off4 = _mm_set1_ps(max_val);
    while i + 4 <= len {
        let v = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off4));
        _mm_storeu_ps(out_ptr.add(i), v);
        let mut b4 = [0.0f32; 4];
        _mm_storeu_ps(b4.as_mut_ptr(), v);
        sum_exp += b4[0] + b4[1] + b4[2] + b4[3];
        i += 4;
    }
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // 4. divide
    let inv8 = _mm256_set1_ps(1.0 / sum_exp);
    i = 0;
    while i + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(i),
            _mm256_mul_ps(_mm256_loadu_ps(out_ptr.add(i)), inv8),
        );
        i += 8;
    }
    let inv4 = _mm_set1_ps(1.0 / sum_exp);
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_mul_ps(_mm_loadu_ps(out_ptr.add(i)), inv4),
        );
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

// ===========================================================================
// Fused log-softmax: out[i] = x[i] - max - log(sum(exp(x - max)))
// ===========================================================================

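/// Fused log-softmax row: `out[i] = input[i] - max - log(sum(exp(input - max)))`.
///
/// Shifting by the row max keeps every `exp` argument at or below zero, so the
/// largest term is 1 and the sum cannot overflow; the identity
/// `log(sum(exp(x))) = max + log(sum(exp(x - max)))` keeps the shifted form exact.
///
/// A minimal usage sketch (illustrative values, not from the test suite):
///
/// ```ignore
/// let logits = [1.0f32, 2.0, 3.0];
/// let mut logp = [0.0f32; 3];
/// log_softmax_row_fused_dispatch(&logits, &mut logp);
/// // exp(logp[i]) recovers the softmax probabilities.
/// ```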
#[allow(unsafe_code)]
#[inline]
pub fn log_softmax_row_fused_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) || input.is_empty() {
        log_softmax_row_fused_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                log_softmax_row_fused_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                log_softmax_row_fused_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                log_softmax_row_fused_sse(input, output);
            }
            return;
        }
    }

    log_softmax_row_fused_scalar(input, output);
}

fn log_softmax_row_fused_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    // 1. max
    let mut max_val = f32::NEG_INFINITY;
    for &v in input {
        max_val = max_val.max(v);
    }

    // 2. sum(exp(x - max))
    let mut sum_exp = 0.0f32;
    for &v in input {
        sum_exp += (v - max_val).exp();
    }

    // 3. output[i] = x[i] - max - log(sum_exp)
    let log_denom = max_val + sum_exp.ln();
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v - log_denom;
    }
}
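
// Continuing the worked softmax example above: with sum ~ 1.5032,
// log_denom = 3.0 + ln(1.5032) ~ 3.4076, so the output is
// ~ [-2.4076, -1.4076, -0.4076], whose exps are the softmax values.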

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn log_softmax_row_fused_neon(input: &[f32], output: &mut [f32]) {
    use std::arch::aarch64::{vaddvq_f32, vmaxvq_f32};

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. Find max (NEON reduce)
    let mut acc_max = vdupq_n_f32(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 16 <= len {
        let v0 = vld1q_f32(in_ptr.add(i));
        let v1 = vld1q_f32(in_ptr.add(i + 4));
        let v2 = vld1q_f32(in_ptr.add(i + 8));
        let v3 = vld1q_f32(in_ptr.add(i + 12));
        acc_max = vmaxq_f32(acc_max, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        acc_max = vmaxq_f32(acc_max, vld1q_f32(in_ptr.add(i)));
        i += 4;
    }
    let mut max_val = vmaxvq_f32(acc_max);
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sum(exp(x - max))
    let off = vdupq_n_f32(max_val);
    let mut acc_sum = vdupq_n_f32(0.0);
    i = 0;
    while i + 16 <= len {
        let e0 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        let e1 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), off));
        let e2 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), off));
        let e3 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), off));
        acc_sum = vaddq_f32(acc_sum, vaddq_f32(vaddq_f32(e0, e1), vaddq_f32(e2, e3)));
        i += 16;
    }
    while i + 4 <= len {
        let e = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        acc_sum = vaddq_f32(acc_sum, e);
        i += 4;
    }
    let mut sum_exp = vaddvq_f32(acc_sum);
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // 3. output[i] = x[i] - (max + log(sum_exp))
    let log_denom = vdupq_n_f32(max_val + sum_exp.ln());
    i = 0;
    while i + 16 <= len {
        vst1q_f32(
            out_ptr.add(i),
            vsubq_f32(vld1q_f32(in_ptr.add(i)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 4),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 8),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 12),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), log_denom),
        );
        i += 16;
    }
    while i + 4 <= len {
        vst1q_f32(
            out_ptr.add(i),
            vsubq_f32(vld1q_f32(in_ptr.add(i)), log_denom),
        );
        i += 4;
    }
    let log_denom_s = max_val + sum_exp.ln();
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_s;
        i += 1;
    }
}

/// SSE fused log-softmax.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn log_softmax_row_fused_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. max
    let mut acc_max = _mm_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 4 <= len {
        acc_max = _mm_max_ps(acc_max, _mm_loadu_ps(in_ptr.add(i)));
        i += 4;
    }
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc_max);
    let mut max_val = buf[0].max(buf[1]).max(buf[2].max(buf[3]));
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sum(exp(x - max))
    let off = _mm_set1_ps(max_val);
    let mut acc_sum = _mm_setzero_ps();
    i = 0;
    while i + 4 <= len {
        let e = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off));
        acc_sum = _mm_add_ps(acc_sum, e);
        i += 4;
    }
    _mm_storeu_ps(buf.as_mut_ptr(), acc_sum);
    let mut sum_exp = buf[0] + buf[1] + buf[2] + buf[3];
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // 3. output[i] = x[i] - (max + log(sum_exp))
    let log_denom = _mm_set1_ps(max_val + sum_exp.ln());
    i = 0;
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), log_denom),
        );
        i += 4;
    }
    let log_denom_s = max_val + sum_exp.ln();
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_s;
        i += 1;
    }
}

/// AVX fused log-softmax; the sub-8-element tail is handled with SSE.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn log_softmax_row_fused_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. max
    let mut acc_max = _mm256_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 8 <= len {
        acc_max = _mm256_max_ps(acc_max, _mm256_loadu_ps(in_ptr.add(i)));
        i += 8;
    }
    let mut buf8 = [0.0f32; 8];
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_max);
    let mut max_val = buf8[0];
    for &v in &buf8[1..] {
        max_val = max_val.max(v);
    }
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sum(exp(x - max))
    let off = _mm256_set1_ps(max_val);
    let mut acc_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let e = fast_exp_avx(_mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), off));
        acc_sum = _mm256_add_ps(acc_sum, e);
        i += 8;
    }
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_sum);
    let mut sum_exp: f32 = buf8.iter().sum();
    // SSE tail for remaining < 8 elements
    let off4 = _mm_set1_ps(max_val);
    while i + 4 <= len {
        let e = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off4));
        let mut b4 = [0.0f32; 4];
        _mm_storeu_ps(b4.as_mut_ptr(), e);
        sum_exp += b4[0] + b4[1] + b4[2] + b4[3];
        i += 4;
    }
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // 3. output[i] = x[i] - (max + log(sum_exp))
    let log_denom_val = max_val + sum_exp.ln();
    let log_denom8 = _mm256_set1_ps(log_denom_val);
    i = 0;
    while i + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(i),
            _mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), log_denom8),
        );
        i += 8;
    }
    let log_denom4 = _mm_set1_ps(log_denom_val);
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), log_denom4),
        );
        i += 4;
    }
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_val;
        i += 1;
    }
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;

    fn assert_close(a: &[f32], b: &[f32], tol: f32) {
        assert_eq!(a.len(), b.len(), "length mismatch");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            let d = (x - y).abs();
            assert!(d <= tol, "index {i}: {x} vs {y}, diff={d}, tolerance={tol}");
        }
    }

    #[test]
    fn exp_matches_scalar() {
        let input: Vec<f32> = (-20..=20).map(|i| i as f32 * 0.5).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        exp_slice_dispatch(&input, &mut simd_out);
        exp_slice_scalar(&input, &mut scalar_out);

        // The degree-6 Taylor polynomial is accurate to roughly 1e-6 relative
        // error; the 1e-5 bound below leaves headroom across platforms.
        for (i, (&s, &r)) in simd_out.iter().zip(scalar_out.iter()).enumerate() {
            let rel = if r.abs() > 1e-10 {
                (s - r).abs() / r.abs()
            } else {
                (s - r).abs()
            };
            assert!(
                rel < 1e-5,
                "exp mismatch at index {i}: simd={s}, scalar={r}, rel_err={rel}"
            );
        }
    }

    #[test]
    fn sigmoid_dispatch_matches_scalar() {
        let input: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.3).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        sigmoid_slice_dispatch(&input, &mut simd_out);
        sigmoid_slice_dispatch_scalar(&input, &mut scalar_out);

        // Sigmoid uses the Schraudolph bit-trick exp (~4% max relative error
        // on exp); the sigmoid squashes that error near 0 and 1, so the
        // practical maximum error is ~0.03.
        assert_close(&simd_out, &scalar_out, 0.035);
    }

    #[test]
    fn tanh_dispatch_matches_scalar() {
        let input: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.3).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        tanh_slice_dispatch(&input, &mut simd_out);
        tanh_slice_dispatch_scalar(&input, &mut scalar_out);

        // Uses the fast 3-term exp polynomial on the sigmoid path
        // (~2e-3 max error vs scalar tanh).
        assert_close(&simd_out, &scalar_out, 2e-3);
    }

    #[test]
    fn max_reduce_matches_scalar() {
        let data: Vec<f32> = (0..37).map(|i| (i as f32 * 0.7 - 12.0).sin()).collect();
        let simd_result = max_reduce_dispatch(&data);
        let scalar_result = max_reduce_scalar(&data);
        assert!((simd_result - scalar_result).abs() < 1e-6);
    }

    #[test]
    fn max_reduce_empty() {
        assert_eq!(max_reduce_dispatch(&[]), f32::NEG_INFINITY);
    }

    #[test]
    fn add_reduce_matches_scalar() {
        let data: Vec<f32> = (0..37).map(|i| i as f32 * 0.1).collect();
        let simd_result = add_reduce_dispatch(&data);
        let scalar_result = add_reduce_scalar(&data);
        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    #[test]
    fn add_reduce_empty() {
        assert_eq!(add_reduce_dispatch(&[]), 0.0);
    }

    #[test]
    #[allow(unsafe_code)]
    fn fma_matches_scalar() {
        let a: Vec<f32> = (0..33).map(|i| i as f32 * 0.3).collect();
        let b: Vec<f32> = (0..33).map(|i| (i as f32 * 0.7).sin()).collect();
        let mut simd_acc = vec![1.0f32; 33];
        let mut scalar_acc = vec![1.0f32; 33];

        fma_slice_dispatch(&a, &b, &mut simd_acc);
        unsafe { fma_slice_scalar(&a, &b, &mut scalar_acc) };

        assert_close(&simd_acc, &scalar_acc, 1e-5);
    }

    #[test]
    fn sigmoid_dispatch_boundary_values() {
        // Verify sigmoid at key points
        let input = vec![-100.0, -10.0, 0.0, 10.0, 100.0];
        let mut output = vec![0.0f32; 5];
        sigmoid_slice_dispatch(&input, &mut output);

        // sigmoid(-100) ~ 0, sigmoid(0) = 0.5, sigmoid(100) ~ 1
        assert!(
            output[0] < 0.01,
            "sigmoid(-100) should be near 0: {}",
            output[0]
        );
        assert!(
            (output[2] - 0.5).abs() < 0.01,
            "sigmoid(0) should be near 0.5: {}",
            output[2]
        );
        assert!(
            output[4] > 0.99,
            "sigmoid(100) should be near 1: {}",
            output[4]
        );
    }

    #[test]
    fn tanh_dispatch_boundary_values() {
        let input = vec![-100.0, -1.0, 0.0, 1.0, 100.0];
        let mut output = vec![0.0f32; 5];
        tanh_slice_dispatch(&input, &mut output);

        assert!(
            output[0] < -0.99,
            "tanh(-100) should be near -1: {}",
            output[0]
        );
        assert!(
            (output[2]).abs() < 0.01,
            "tanh(0) should be near 0: {}",
            output[2]
        );
        assert!(
            output[4] > 0.99,
            "tanh(100) should be near 1: {}",
            output[4]
        );
    }
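
    // Hedged regression sketches for the fused kernels above. The 5e-3
    // tolerance is an assumption derived from the fast-exp error bounds
    // quoted in the tests above, not a measured guarantee.
    #[test]
    fn softmax_row_fused_matches_scalar() {
        let input: Vec<f32> = (0..37).map(|i| (i as f32 * 0.9 - 16.0).sin() * 4.0).collect();
        let mut fused = vec![0.0f32; input.len()];
        let mut scalar = vec![0.0f32; input.len()];

        softmax_row_fused_dispatch(&input, &mut fused);
        softmax_row_fused_scalar(&input, &mut scalar);

        assert_close(&fused, &scalar, 5e-3);
        // The fused divide step should leave a (near-)normalized row.
        let total: f32 = fused.iter().sum();
        assert!((total - 1.0).abs() < 1e-4, "softmax sum = {total}");
    }

    #[test]
    fn log_softmax_row_fused_matches_scalar() {
        let input: Vec<f32> = (0..37).map(|i| (i as f32 * 0.9 - 16.0).sin() * 4.0).collect();
        let mut fused = vec![0.0f32; input.len()];
        let mut scalar = vec![0.0f32; input.len()];

        log_softmax_row_fused_dispatch(&input, &mut fused);
        log_softmax_row_fused_scalar(&input, &mut scalar);

        assert_close(&fused, &scalar, 5e-3);
    }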
}