#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{
    float32x4_t, vaddq_f32, vdivq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vmaxq_f32, vminq_f32,
    vmulq_f32, vnegq_f32, vst1q_f32, vsubq_f32,
};
#[cfg(target_arch = "x86")]
use std::arch::x86::{
    __m128, __m256, _mm_add_ps, _mm_castsi128_ps, _mm_cvtepi32_ps, _mm_cvtps_epi32, _mm_loadu_ps,
    _mm_max_ps, _mm_min_ps, _mm_mul_ps, _mm_set1_epi32, _mm_set1_ps, _mm_setzero_ps, _mm_storeu_ps,
    _mm_sub_ps, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtepi32_ps, _mm256_cvtps_epi32,
    _mm256_loadu_ps, _mm256_max_ps, _mm256_min_ps, _mm256_mul_ps, _mm256_set1_epi32,
    _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{
    __m128, __m256, _mm_add_ps, _mm_castsi128_ps, _mm_cvtepi32_ps, _mm_cvtps_epi32, _mm_loadu_ps,
    _mm_max_ps, _mm_min_ps, _mm_mul_ps, _mm_set1_epi32, _mm_set1_ps, _mm_setzero_ps, _mm_storeu_ps,
    _mm_sub_ps, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtepi32_ps, _mm256_cvtps_epi32,
    _mm256_loadu_ps, _mm256_max_ps, _mm256_min_ps, _mm256_mul_ps, _mm256_set1_epi32,
    _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};

use super::config::BinaryKind;

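// Hand-declared FFI bindings to optional vendor math libraries: Intel MKL's
// VML (`vsAdd` and friends, taking 32-bit `MKL_INT` lengths), Arm
// Performance Libraries' vector routines, and Apple's Accelerate framework
// (`vvexpf` from vForce, `vDSP_*` from vDSP). Two vDSP details matter here:
// `vDSP_Stride` and `vDSP_Length` are C `long`/`unsigned long` (64-bit on
// macOS), hence the `isize`/`usize` parameters below, and
// `vDSP_vsub(B, IB, A, IA, C, IC, N)` computes `C = A - B`, so the
// subtrahend is passed first.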
#[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn vsAdd(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsSub(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsMul(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsDiv(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsExp(n: i32, a: *const f32, y: *mut f32);
    fn vsSqrt(n: i32, a: *const f32, y: *mut f32);
    fn vsLn(n: i32, a: *const f32, y: *mut f32);
}

#[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn armpl_svexp_f32(n: i32, x: *const f32, y: *mut f32);
    fn armpl_svadd_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svsub_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svmul_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svlog_f32(n: i32, x: *const f32, y: *mut f32);
    fn armpl_svsqrt_f32(n: i32, x: *const f32, y: *mut f32);
}

#[cfg(target_os = "macos")]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn vvexpf(result: *mut f32, input: *const f32, count: *const i32);
    fn vDSP_vadd(
        __A: *const f32,
        __IA: isize,
        __B: *const f32,
        __IB: isize,
        __C: *mut f32,
        __IC: isize,
        __N: usize,
    );
    fn vDSP_vsub(
        __B: *const f32,
        __IB: isize,
        __A: *const f32,
        __IA: isize,
        __C: *mut f32,
        __IC: isize,
        __N: usize,
    );
    fn vDSP_vmul(
        __A: *const f32,
        __IA: isize,
        __B: *const f32,
        __IB: isize,
        __C: *mut f32,
        __IC: isize,
        __N: usize,
    );
}

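/// Applies ReLU in place, dispatching to the widest SIMD backend available.
///
/// All of the `*_dispatch` functions in this module follow the same pattern:
/// under Miri the scalar path is taken unconditionally (the interpreter
/// cannot execute intrinsics or inline assembly); otherwise runtime feature
/// detection picks AVX, then SSE, on x86/x86_64 and NEON on aarch64, with
/// the unrolled scalar implementation as the final fallback. `std` caches
/// feature-detection results, so the checks are cheap on the hot path.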
#[allow(unsafe_code)]
#[inline]
pub fn relu_slice_dispatch(values: &mut [f32]) {
    if cfg!(miri) {
        unsafe {
            relu_slice_scalar(values);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                relu_slice_avx(values);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                relu_slice_sse(values);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                relu_slice_neon(values);
            }
            return;
        }
    }

    unsafe {
        relu_slice_scalar(values);
    }
}

#[allow(unsafe_code)]
#[inline]
pub fn relu_to_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        unsafe {
            relu_to_slice_scalar(input, output);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                relu_to_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                relu_to_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                relu_to_slice_neon(input, output);
            }
            return;
        }
    }

    unsafe {
        relu_to_slice_scalar(input, output);
    }
}

#[inline]
#[allow(dead_code)]
pub(crate) fn sigmoid_slice(values: &mut [f32]) {
    for value in values {
        *value = sigmoid_scalar(*value);
    }
}

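/// Numerically stable logistic function. Both branches only ever pass a
/// non-positive argument to `exp`, so the intermediate can underflow toward
/// zero but never overflow: the naive `1.0 / (1.0 + (-x).exp())` evaluates
/// `exp` of a large positive value for very negative `x`, producing an
/// infinite intermediate and losing the tiny true result.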
#[inline]
pub(crate) fn sigmoid_scalar(value: f32) -> f32 {
    if value >= 0.0 {
        let z = (-value).exp();
        1.0 / (1.0 + z)
    } else {
        let z = value.exp();
        z / (1.0 + z)
    }
}

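/// Elementwise `exp` with vendor-library fast paths.
///
/// When compiled in, the platform vector-math libraries (Accelerate's
/// `vvexpf` on macOS/aarch64, MKL's `vsExp`, Arm PL's `armpl_svexp_f32`)
/// are preferred over the hand-rolled SIMD kernels. `unreachable_code` is
/// allowed because whichever of those `cfg` branches is compiled in returns
/// unconditionally, leaving the code after it dead.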
#[allow(unsafe_code, unreachable_code)]
#[inline]
pub fn exp_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        exp_slice_scalar(input, output);
        return;
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        let count = input.len() as i32;
        unsafe {
            vvexpf(output.as_mut_ptr(), input.as_ptr(), &count);
        }
        return;
    }

    #[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let count = input.len() as i32;
        unsafe { vsExp(count, input.as_ptr(), output.as_mut_ptr()) };
        return;
    }

    #[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
    {
        let count = input.len() as i32;
        unsafe { armpl_svexp_f32(count, input.as_ptr(), output.as_mut_ptr()) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                exp_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                exp_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                exp_slice_neon(input, output);
            }
            return;
        }
    }

    exp_slice_scalar(input, output);
}

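/// Fused shift-then-exp: `output[i] = exp(input[i] - offset)`. Subtracting a
/// running maximum before exponentiating is the standard trick for keeping
/// softmax numerically stable, which is the usual reason to want this fused
/// kernel.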
#[allow(unsafe_code)]
#[inline]
pub fn sub_exp_slice_dispatch(input: &[f32], offset: f32, output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        sub_exp_slice_scalar(input, offset, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                sub_exp_slice_avx(input, offset, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                sub_exp_slice_sse(input, offset, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                sub_exp_slice_neon(input, offset, output);
            }
            return;
        }
    }

    sub_exp_slice_scalar(input, offset, output);
}

#[allow(unsafe_code, clippy::needless_return)]
#[inline]
pub fn sigmoid_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        sigmoid_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                sigmoid_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                sigmoid_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                sigmoid_slice_neon(input, output);
            }
            return;
        }
    }

    sigmoid_slice_dispatch_scalar(input, output);
}

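// Cheap vectorized exp used by the NEON sigmoid/tanh/SiLU kernels: clamp to
// [-88, 88], split x = n*ln(2) + r with n rounded to nearest, approximate
// e^r by the degree-3 Taylor polynomial 1 + r + r^2/2 + r^3/6 (built with
// fused multiply-adds), then scale by 2^n constructed directly in the
// exponent bits ((n + 127) << 23). Relative error is on the order of 1e-3,
// which the saturating sigmoid shape tolerates.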
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn fast_exp_sigmoid_neon(x: float32x4_t) -> float32x4_t {
    use std::arch::aarch64::{
        vaddq_s32, vcvtnq_s32_f32, vcvtq_f32_s32, vdupq_n_s32, vreinterpretq_f32_s32, vshlq_n_s32,
        vsubq_f32,
    };
    let x = vmaxq_f32(vdupq_n_f32(-88.0), vminq_f32(vdupq_n_f32(88.0), x));
    let n_f = vmulq_f32(x, vdupq_n_f32(std::f32::consts::LOG2_E));
    let n_i = vcvtnq_s32_f32(n_f);
    let r = vsubq_f32(
        x,
        vmulq_f32(vcvtq_f32_s32(n_i), vdupq_n_f32(std::f32::consts::LN_2)),
    );
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(n_i, vdupq_n_s32(127))));
    let p = vfmaq_f32(vdupq_n_f32(0.5), r, vdupq_n_f32(1.0 / 6.0));
    let p = vfmaq_f32(vdupq_n_f32(1.0), r, p);
    vmulq_f32(vfmaq_f32(vdupq_n_f32(1.0), r, p), pow2n)
}

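// Hand-scheduled NEON sigmoid. The constants broadcast into v16..v23 below
// drive a Schraudolph-style exp: 12102203 is roughly 2^23 / ln(2) and
// 127 << 23 is the IEEE-754 single-precision exponent bias as raw bits, so
// scale, convert-to-int, and add-bias yields the bit pattern of
// approximately e^x. v20/v21/v23 are loaded for a polynomial refinement the
// current loop bodies do not use. Note the constants are assumed to survive
// in v16..v23 between the separate `asm!` blocks, which holds only because
// the intervening Rust code touches no vector registers.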
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
unsafe fn sigmoid_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let mut inp = input.as_ptr();
    let mut out = output.as_mut_ptr();
    let mut remaining = len;

    if remaining >= 4 {
        unsafe {
            let c_neg88: f32 = -88.0;
            let c_pos88: f32 = 88.0;
            let c_schr_c: f32 = 12102203.0;
            let c_schr_b: i32 = 127 << 23;
            let c_sixth: f32 = 1.0 / 6.0;
            let c_half: f32 = 0.5;
            let c_one: f32 = 1.0;
            let c_127: i32 = 127;

            std::arch::asm!(
                "ld1r {{v16.4s}}, [{p_neg88}]",
                "ld1r {{v17.4s}}, [{p_pos88}]",
                "ld1r {{v18.4s}}, [{p_schr_c}]",
                "dup v19.4s, {p_schr_b:w}",
                "ld1r {{v20.4s}}, [{p_sixth}]",
                "ld1r {{v21.4s}}, [{p_half}]",
                "ld1r {{v22.4s}}, [{p_one}]",
                "dup v23.4s, {p_127:w}",
                p_neg88 = in(reg) &c_neg88,
                p_pos88 = in(reg) &c_pos88,
                p_schr_c = in(reg) &c_schr_c,
                p_schr_b = in(reg) c_schr_b,
                p_sixth = in(reg) &c_sixth,
                p_half = in(reg) &c_half,
                p_one = in(reg) &c_one,
                p_127 = in(reg) c_127,
                out("v16") _, out("v17") _, out("v18") _, out("v19") _,
                out("v20") _, out("v21") _, out("v22") _, out("v23") _,
            );

            while remaining >= 16 {
                std::arch::asm!(
                    "ldp q0, q1, [{inp}]",
                    "ldp q2, q3, [{inp}, #32]",
                    "add {inp}, {inp}, #64",
                    "fneg v0.4s, v0.4s",
                    "fneg v1.4s, v1.4s",
                    "fneg v2.4s, v2.4s",
                    "fneg v3.4s, v3.4s",
                    "fmax v0.4s, v0.4s, v16.4s",
                    "fmax v1.4s, v1.4s, v16.4s",
                    "fmax v2.4s, v2.4s, v16.4s",
                    "fmax v3.4s, v3.4s, v16.4s",
                    "fmin v0.4s, v0.4s, v17.4s",
                    "fmin v1.4s, v1.4s, v17.4s",
                    "fmin v2.4s, v2.4s, v17.4s",
                    "fmin v3.4s, v3.4s, v17.4s",
                    "fmul v0.4s, v0.4s, v18.4s",
                    "fmul v1.4s, v1.4s, v18.4s",
                    "fmul v2.4s, v2.4s, v18.4s",
                    "fmul v3.4s, v3.4s, v18.4s",
                    "fcvtzs v0.4s, v0.4s",
                    "fcvtzs v1.4s, v1.4s",
                    "fcvtzs v2.4s, v2.4s",
                    "fcvtzs v3.4s, v3.4s",
                    "add v0.4s, v0.4s, v19.4s",
                    "add v1.4s, v1.4s, v19.4s",
                    "add v2.4s, v2.4s, v19.4s",
                    "add v3.4s, v3.4s, v19.4s",
                    "fadd v0.4s, v22.4s, v0.4s",
                    "fadd v1.4s, v22.4s, v1.4s",
                    "fadd v2.4s, v22.4s, v2.4s",
                    "fadd v3.4s, v22.4s, v3.4s",
                    "fdiv v0.4s, v22.4s, v0.4s",
                    "fdiv v1.4s, v22.4s, v1.4s",
                    "fdiv v2.4s, v22.4s, v2.4s",
                    "fdiv v3.4s, v22.4s, v3.4s",
                    "stp q0, q1, [{out}]",
                    "stp q2, q3, [{out}, #32]",
                    "add {out}, {out}, #64",
                    inp = inout(reg) inp,
                    out = inout(reg) out,
                    out("v0") _, out("v1") _, out("v2") _, out("v3") _,
                );
                remaining -= 16;
            }
            while remaining >= 4 {
                std::arch::asm!(
                    "ld1 {{v0.4s}}, [{inp}], #16",
                    "fneg v0.4s, v0.4s",
                    "fmax v0.4s, v0.4s, v16.4s",
                    "fmin v0.4s, v0.4s, v17.4s",
                    "fmul v0.4s, v0.4s, v18.4s",
                    "fcvtzs v0.4s, v0.4s",
                    "add v0.4s, v0.4s, v19.4s",
                    "fadd v0.4s, v22.4s, v0.4s",
                    "fdiv v0.4s, v22.4s, v0.4s",
                    "st1 {{v0.4s}}, [{out}], #16",
                    inp = inout(reg) inp,
                    out = inout(reg) out,
                    out("v0") _,
                );
                remaining -= 4;
            }
        }
    }

    for i in 0..remaining {
        unsafe {
            let x = *inp.add(i);
            *out.add(i) = 1.0 / (1.0 + (-x).exp());
        }
    }
}

#[allow(unsafe_code)]
#[inline]
pub fn tanh_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        tanh_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                tanh_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                tanh_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                tanh_slice_neon(input, output);
            }
            return;
        }
    }

    tanh_slice_dispatch_scalar(input, output);
}

#[allow(unsafe_code)]
#[inline]
pub fn silu_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        silu_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                silu_slice_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe { silu_slice_avx(input, output) };
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe { silu_slice_sse(input, output) };
            return;
        }
    }

    silu_slice_dispatch_scalar(input, output);
}

#[allow(unsafe_code, dead_code)]
#[inline]
pub fn max_reduce_dispatch(data: &[f32]) -> f32 {
    if data.is_empty() {
        return f32::NEG_INFINITY;
    }

    if cfg!(miri) {
        return max_reduce_scalar(data);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { max_reduce_avx(data) };
        }
        if std::is_x86_feature_detected!("sse") {
            return unsafe { max_reduce_sse(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { max_reduce_neon(data) };
        }
    }

    max_reduce_scalar(data)
}

#[allow(unsafe_code, dead_code)]
#[inline]
pub fn add_reduce_dispatch(data: &[f32]) -> f32 {
    if data.is_empty() {
        return 0.0;
    }

    if cfg!(miri) {
        return add_reduce_scalar(data);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { add_reduce_avx(data) };
        }
        if std::is_x86_feature_detected!("sse") {
            return unsafe { add_reduce_sse(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { add_reduce_neon(data) };
        }
    }

    add_reduce_scalar(data)
}

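/// Scales every element of `data` by `scalar` in place. The Miri and
/// empty-slice cases share the plain loop; each SIMD kernel below handles
/// its own scalar tail.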
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn mul_scalar_inplace_dispatch(data: &mut [f32], scalar: f32) {
    if cfg!(miri) || data.is_empty() {
        for v in data.iter_mut() {
            *v *= scalar;
        }
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe { mul_scalar_inplace_neon(data, scalar) };
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe { mul_scalar_inplace_avx(data, scalar) };
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe { mul_scalar_inplace_sse(data, scalar) };
            return;
        }
    }

    for v in data.iter_mut() {
        *v *= scalar;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn mul_scalar_inplace_neon(data: &mut [f32], scalar: f32) {
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = vdupq_n_f32(scalar);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = vld1q_f32(ptr.add(i));
        vst1q_f32(ptr.add(i), vmulq_f32(v, vs));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn mul_scalar_inplace_avx(data: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = _mm256_set1_ps(scalar);
    let mut i = 0usize;
    while i + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(i));
        _mm256_storeu_ps(ptr.add(i), _mm256_mul_ps(v, vs));
        i += 8;
    }
    let vs4 = _mm_set1_ps(scalar);
    while i + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        _mm_storeu_ps(ptr.add(i), _mm_mul_ps(v, vs4));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn mul_scalar_inplace_sse(data: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = _mm_set1_ps(scalar);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        _mm_storeu_ps(ptr.add(i), _mm_mul_ps(v, vs));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

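/// Accumulates `acc[i] += a[i] * b[i]`. Note that the scalar fallback
/// rounds the multiply and the add separately, while a NEON backend built
/// on `vfmaq_f32` fuses them into a single rounding, so the low-order bits
/// of the result can differ across dispatch paths.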
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn fma_slice_dispatch(a: &[f32], b: &[f32], acc: &mut [f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), acc.len());

    if cfg!(miri) {
        unsafe {
            fma_slice_scalar(a, b, acc);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                fma_slice_avx(a, b, acc);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                fma_slice_sse(a, b, acc);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                fma_slice_neon(a, b, acc);
            }
            return;
        }
    }

    unsafe {
        fma_slice_scalar(a, b, acc);
    }
}

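/// Elementwise add/sub/mul for equal-length slices, with the same
/// vendor-library-first dispatch as `exp_slice_dispatch`. On the Accelerate
/// path, `rhs` is deliberately passed before `lhs` for `Sub` because
/// `vDSP_vsub(B, .., A, ..)` computes `A - B`.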
#[allow(unsafe_code, unreachable_code)]
#[inline]
pub fn binary_same_shape_dispatch(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    debug_assert_eq!(lhs.len(), rhs.len());
    debug_assert_eq!(lhs.len(), out.len());

    if cfg!(miri) {
        unsafe {
            binary_same_shape_scalar(lhs, rhs, out, kind);
        }
        return;
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        let n = lhs.len();
        unsafe {
            match kind {
                BinaryKind::Add => {
                    vDSP_vadd(lhs.as_ptr(), 1, rhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
                BinaryKind::Sub => {
                    vDSP_vsub(rhs.as_ptr(), 1, lhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
                BinaryKind::Mul => {
                    vDSP_vmul(lhs.as_ptr(), 1, rhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
            }
        }
        return;
    }

    #[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let n = lhs.len() as i32;
        unsafe {
            match kind {
                BinaryKind::Add => vsAdd(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Sub => vsSub(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Mul => vsMul(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
            }
        }
        return;
    }

    #[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
    {
        let n = lhs.len() as i32;
        unsafe {
            match kind {
                BinaryKind::Add => armpl_svadd_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Sub => armpl_svsub_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Mul => armpl_svmul_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
            }
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                binary_same_shape_avx(lhs, rhs, out, kind);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                binary_same_shape_sse(lhs, rhs, out, kind);
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                binary_same_shape_neon(lhs, rhs, out, kind);
            }
            return;
        }
    }

    unsafe {
        binary_same_shape_scalar(lhs, rhs, out, kind);
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn relu_slice_scalar(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        let v0 = *ptr.add(index);
        let v1 = *ptr.add(index + 1);
        let v2 = *ptr.add(index + 2);
        let v3 = *ptr.add(index + 3);
        let v4 = *ptr.add(index + 4);
        let v5 = *ptr.add(index + 5);
        let v6 = *ptr.add(index + 6);
        let v7 = *ptr.add(index + 7);
        *ptr.add(index) = v0.max(0.0);
        *ptr.add(index + 1) = v1.max(0.0);
        *ptr.add(index + 2) = v2.max(0.0);
        *ptr.add(index + 3) = v3.max(0.0);
        *ptr.add(index + 4) = v4.max(0.0);
        *ptr.add(index + 5) = v5.max(0.0);
        *ptr.add(index + 6) = v6.max(0.0);
        *ptr.add(index + 7) = v7.max(0.0);
        index += 8;
    }

    while index < len {
        *ptr.add(index) = (*ptr.add(index)).max(0.0);
        index += 1;
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn relu_to_slice_scalar(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        *out_ptr.add(index) = (*in_ptr.add(index)).max(0.0);
        *out_ptr.add(index + 1) = (*in_ptr.add(index + 1)).max(0.0);
        *out_ptr.add(index + 2) = (*in_ptr.add(index + 2)).max(0.0);
        *out_ptr.add(index + 3) = (*in_ptr.add(index + 3)).max(0.0);
        *out_ptr.add(index + 4) = (*in_ptr.add(index + 4)).max(0.0);
        *out_ptr.add(index + 5) = (*in_ptr.add(index + 5)).max(0.0);
        *out_ptr.add(index + 6) = (*in_ptr.add(index + 6)).max(0.0);
        *out_ptr.add(index + 7) = (*in_ptr.add(index + 7)).max(0.0);
        index += 8;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).max(0.0);
        index += 1;
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn binary_same_shape_scalar(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    match kind {
        BinaryKind::Add => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) + *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) + *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) + *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) + *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) + *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) + *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) + *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) + *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) + *right_ptr.add(index);
                index += 1;
            }
        }
        BinaryKind::Sub => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) - *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) - *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) - *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) - *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) - *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) - *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) - *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) - *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) - *right_ptr.add(index);
                index += 1;
            }
        }
        BinaryKind::Mul => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) * *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) * *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) * *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) * *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) * *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) * *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) * *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) * *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) * *right_ptr.add(index);
                index += 1;
            }
        }
    }
}

fn exp_slice_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v.exp();
    }
}

fn sub_exp_slice_scalar(input: &[f32], offset: f32, output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = (v - offset).exp();
    }
}

fn sigmoid_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = sigmoid_scalar(v);
    }
}

fn tanh_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v.tanh();
    }
}

fn silu_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let s = 1.0 / (1.0 + (-v).exp());
        *o = v * s;
    }
}

#[allow(dead_code)]
fn max_reduce_scalar(data: &[f32]) -> f32 {
    let mut acc = f32::NEG_INFINITY;
    for &v in data {
        acc = acc.max(v);
    }
    acc
}

#[allow(dead_code)]
fn add_reduce_scalar(data: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for &v in data {
        acc += v;
    }
    acc
}

#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fma_slice_scalar(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        *acc_ptr.add(index) += *a_ptr.add(index) * *b_ptr.add(index);
        *acc_ptr.add(index + 1) += *a_ptr.add(index + 1) * *b_ptr.add(index + 1);
        *acc_ptr.add(index + 2) += *a_ptr.add(index + 2) * *b_ptr.add(index + 2);
        *acc_ptr.add(index + 3) += *a_ptr.add(index + 3) * *b_ptr.add(index + 3);
        index += 4;
    }
    while index < len {
        *acc_ptr.add(index) += *a_ptr.add(index) * *b_ptr.add(index);
        index += 1;
    }
}

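// Schraudolph bit trick for a very fast, low-accuracy exp (a sketch of the
// classic formulation): in IEEE-754 single precision,
//     exp(x) ~= f32::from_bits((12102203.0 * x) as i32 + 1065353216)
// because 12102203 ~ 2^23 / ln(2) moves x into the exponent field and
// 1065353216 = 127 << 23 is the exponent bias. Relative error is a few
// percent, which is acceptable inside sigmoid/tanh where the curve
// saturates.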
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn fast_exp_bittrick_sse(x: __m128) -> __m128 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{_mm_add_epi32, _mm_cvtps_epi32, _mm_set1_epi32};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{_mm_add_epi32, _mm_cvtps_epi32, _mm_set1_epi32};
    let scale = _mm_set1_ps(12102203.0);
    let offset = _mm_set1_epi32(1065353216);
    let clamp_lo = _mm_set1_ps(-87.0);
    let clamp_hi = _mm_set1_ps(88.0);
    let x_clamped = _mm_max_ps(_mm_min_ps(x, clamp_hi), clamp_lo);
    let val = _mm_cvtps_epi32(_mm_mul_ps(x_clamped, scale));
    _mm_castsi128_ps(_mm_add_epi32(val, offset))
}

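// Higher-accuracy SSE exp (Cephes-style): clamp, range-reduce with
// x = n*ln(2) + r using a two-piece ln(2) (ln2_hi + ln2_lo) so the
// subtraction stays accurate, evaluate the degree-6 Taylor polynomial of
// e^r by Horner's rule, and reassemble with 2^n constructed in the exponent
// bits. Used by `exp_slice_*`, where precision matters more than in the
// activation kernels.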
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn fast_exp_sse(x: __m128) -> __m128 {
    let ln2_inv = _mm_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm_set1_ps(0.693_359_4);
    let ln2_lo = _mm_set1_ps(-2.121_944_4e-4);

    let c0 = _mm_set1_ps(1.0);
    let c1 = _mm_set1_ps(1.0);
    let c2 = _mm_set1_ps(0.5);
    let c3 = _mm_set1_ps(1.0 / 6.0);
    let c4 = _mm_set1_ps(1.0 / 24.0);
    let c5 = _mm_set1_ps(1.0 / 120.0);
    let c6 = _mm_set1_ps(1.0 / 720.0);

    let x = _mm_max_ps(_mm_set1_ps(-88.0), _mm_min_ps(_mm_set1_ps(88.0), x));

    let n_f = _mm_mul_ps(x, ln2_inv);
    let n_i = _mm_cvtps_epi32(n_f);
    let n_f = _mm_cvtepi32_ps(n_i);

    let r = _mm_sub_ps(
        _mm_sub_ps(x, _mm_mul_ps(n_f, ln2_hi)),
        _mm_mul_ps(n_f, ln2_lo),
    );

    let mut poly = _mm_add_ps(c5, _mm_mul_ps(r, c6));
    poly = _mm_add_ps(c4, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c1, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c0, _mm_mul_ps(r, poly));

    let pow2n = {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::{_mm_add_epi32, _mm_slli_epi32};
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::{_mm_add_epi32, _mm_slli_epi32};
        let bias = _mm_set1_epi32(127);
        _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(n_i, bias), 23))
    };

    _mm_mul_ps(poly, pow2n)
}

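// AVX twin of `fast_exp_bittrick_sse`: identical Schraudolph constants,
// eight lanes at a time.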
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
#[inline]
unsafe fn fast_exp_bittrick_avx(x: __m256) -> __m256 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{_mm256_add_epi32, _mm256_cvtps_epi32, _mm256_set1_epi32};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{_mm256_add_epi32, _mm256_cvtps_epi32, _mm256_set1_epi32};
    let scale = _mm256_set1_ps(12102203.0);
    let offset = _mm256_set1_epi32(1065353216);
    let clamp_lo = _mm256_set1_ps(-87.0);
    let clamp_hi = _mm256_set1_ps(88.0);
    let x_clamped = _mm256_max_ps(_mm256_min_ps(x, clamp_hi), clamp_lo);
    let val = _mm256_cvtps_epi32(_mm256_mul_ps(x_clamped, scale));
    _mm256_castsi256_ps(_mm256_add_epi32(val, offset))
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn fast_exp_avx(x: __m256) -> __m256 {
    let ln2_inv = _mm256_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm256_set1_ps(0.693_359_4);
    let ln2_lo = _mm256_set1_ps(-2.121_944_4e-4);

    let c0 = _mm256_set1_ps(1.0);
    let c1 = _mm256_set1_ps(1.0);
    let c2 = _mm256_set1_ps(0.5);
    let c3 = _mm256_set1_ps(1.0 / 6.0);
    let c4 = _mm256_set1_ps(1.0 / 24.0);
    let c5 = _mm256_set1_ps(1.0 / 120.0);
    let c6 = _mm256_set1_ps(1.0 / 720.0);

    let x = _mm256_max_ps(
        _mm256_set1_ps(-88.0),
        _mm256_min_ps(_mm256_set1_ps(88.0), x),
    );

    let n_f = _mm256_mul_ps(x, ln2_inv);
    let n_i = _mm256_cvtps_epi32(n_f);
    let n_f = _mm256_cvtepi32_ps(n_i);

    let r = _mm256_sub_ps(
        _mm256_sub_ps(x, _mm256_mul_ps(n_f, ln2_hi)),
        _mm256_mul_ps(n_f, ln2_lo),
    );

    let mut poly = _mm256_add_ps(c5, _mm256_mul_ps(r, c6));
    poly = _mm256_add_ps(c4, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c3, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c2, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c1, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c0, _mm256_mul_ps(r, poly));

    let bias = _mm256_set1_epi32(127);
    let pow2n = {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::{_mm256_add_epi32, _mm256_slli_epi32};
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::{_mm256_add_epi32, _mm256_slli_epi32};
        _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_add_epi32(n_i, bias), 23))
    };

    _mm256_mul_ps(poly, pow2n)
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn fast_exp_neon(x: float32x4_t) -> float32x4_t {
    use std::arch::aarch64::{
        vaddq_s32, vcvtnq_s32_f32, vcvtq_f32_s32, vdupq_n_s32, vreinterpretq_f32_s32, vshlq_n_s32,
    };

    let ln2_inv = vdupq_n_f32(std::f32::consts::LOG2_E);
    let ln2_hi = vdupq_n_f32(0.693_359_4);
    let ln2_lo = vdupq_n_f32(-2.121_944_4e-4);

    let c0 = vdupq_n_f32(1.0);
    let c1 = vdupq_n_f32(1.0);
    let c2 = vdupq_n_f32(0.5);
    let c3 = vdupq_n_f32(1.0 / 6.0);
    let c4 = vdupq_n_f32(1.0 / 24.0);
    let c5 = vdupq_n_f32(1.0 / 120.0);
    let c6 = vdupq_n_f32(1.0 / 720.0);

    let x = vmaxq_f32(vdupq_n_f32(-88.0), vminq_f32(vdupq_n_f32(88.0), x));

    let n_f = vmulq_f32(x, ln2_inv);
    let n_i = vcvtnq_s32_f32(n_f);
    let n_f = vcvtq_f32_s32(n_i);

    let r = vsubq_f32(vsubq_f32(x, vmulq_f32(n_f, ln2_hi)), vmulq_f32(n_f, ln2_lo));

    let mut poly = vfmaq_f32(c5, r, c6);
    poly = vfmaq_f32(c4, r, poly);
    poly = vfmaq_f32(c3, r, poly);
    poly = vfmaq_f32(c2, r, poly);
    poly = vfmaq_f32(c1, r, poly);
    poly = vfmaq_f32(c0, r, poly);

    let bias = vdupq_n_s32(127);
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(n_i, bias)));

    vmulq_f32(poly, pow2n)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn exp_slice_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let r = fast_exp_sse(v);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).exp();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn exp_slice_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 16 <= len {
        #[cfg(target_arch = "x86")]
        {
            use std::arch::x86::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        let v0 = _mm256_loadu_ps(in_ptr.add(index));
        let v1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let r0 = fast_exp_avx(v0);
        let r1 = fast_exp_avx(v1);
        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        index += 16;
    }

    while index + 8 <= len {
        let v = _mm256_loadu_ps(in_ptr.add(index));
        let r = fast_exp_avx(v);
        _mm256_storeu_ps(out_ptr.add(index), r);
        index += 8;
    }

    if index < len {
        exp_slice_sse(&input[index..], &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn exp_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = vld1q_f32(in_ptr.add(index));
        let r = fast_exp_neon(v);
        vst1q_f32(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).exp();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn sub_exp_slice_sse(input: &[f32], offset: f32, output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let off = _mm_set1_ps(offset);
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let shifted = _mm_sub_ps(v, off);
        let r = fast_exp_sse(shifted);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index) - offset).exp();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn sub_exp_slice_avx(input: &[f32], offset: f32, output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let off = _mm256_set1_ps(offset);
    let mut index = 0usize;

    while index + 16 <= len {
        #[cfg(target_arch = "x86")]
        {
            use std::arch::x86::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        let v0 = _mm256_loadu_ps(in_ptr.add(index));
        let v1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let shifted0 = _mm256_sub_ps(v0, off);
        let shifted1 = _mm256_sub_ps(v1, off);
        let r0 = fast_exp_avx(shifted0);
        let r1 = fast_exp_avx(shifted1);
        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        index += 16;
    }

    while index + 8 <= len {
        let v = _mm256_loadu_ps(in_ptr.add(index));
        let shifted = _mm256_sub_ps(v, off);
        let r = fast_exp_avx(shifted);
        _mm256_storeu_ps(out_ptr.add(index), r);
        index += 8;
    }

    if index < len {
        sub_exp_slice_sse(&input[index..], offset, &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn sub_exp_slice_neon(input: &[f32], offset: f32, output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let off = vdupq_n_f32(offset);
    let mut index = 0usize;

    while index + 4 <= len {
        let v = vld1q_f32(in_ptr.add(index));
        let shifted = vsubq_f32(v, off);
        let r = fast_exp_neon(shifted);
        vst1q_f32(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index) - offset).exp();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn sigmoid_slice_sse(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 16 <= len {
        let x0 = _mm_loadu_ps(in_ptr.add(index));
        let x1 = _mm_loadu_ps(in_ptr.add(index + 4));
        let x2 = _mm_loadu_ps(in_ptr.add(index + 8));
        let x3 = _mm_loadu_ps(in_ptr.add(index + 12));

        let e0 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x0));
        let e1 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x1));
        let e2 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x2));
        let e3 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x3));

        let r0 = _mm_div_ps(one, _mm_add_ps(one, e0));
        let r1 = _mm_div_ps(one, _mm_add_ps(one, e1));
        let r2 = _mm_div_ps(one, _mm_add_ps(one, e2));
        let r3 = _mm_div_ps(one, _mm_add_ps(one, e3));

        _mm_storeu_ps(out_ptr.add(index), r0);
        _mm_storeu_ps(out_ptr.add(index + 4), r1);
        _mm_storeu_ps(out_ptr.add(index + 8), r2);
        _mm_storeu_ps(out_ptr.add(index + 12), r3);

        index += 16;
    }

    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let neg_x = _mm_sub_ps(zero, x);
        let exp_neg_x = fast_exp_bittrick_sse(neg_x);
        let denom = _mm_add_ps(one, exp_neg_x);
        let result = _mm_div_ps(one, denom);
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = sigmoid_scalar(*in_ptr.add(index));
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn sigmoid_slice_avx(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm256_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm256_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = _mm256_loadu_ps(in_ptr.add(index));
        let x1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let x2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let x3 = _mm256_loadu_ps(in_ptr.add(index + 24));

        let e0 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x0));
        let e1 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x1));
        let e2 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x2));
        let e3 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x3));

        let r0 = _mm256_div_ps(one, _mm256_add_ps(one, e0));
        let r1 = _mm256_div_ps(one, _mm256_add_ps(one, e1));
        let r2 = _mm256_div_ps(one, _mm256_add_ps(one, e2));
        let r3 = _mm256_div_ps(one, _mm256_add_ps(one, e3));

        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        _mm256_storeu_ps(out_ptr.add(index + 16), r2);
        _mm256_storeu_ps(out_ptr.add(index + 24), r3);

        index += 32;
    }

    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let neg_x = _mm256_sub_ps(zero, x);
        let exp_neg_x = fast_exp_bittrick_avx(neg_x);
        let denom = _mm256_add_ps(one, exp_neg_x);
        let result = _mm256_div_ps(one, denom);
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        sigmoid_slice_sse(&input[index..], &mut output[index..]);
    }
}

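// tanh via the logistic identity tanh(x) = 2*sigmoid(2x) - 1
//                                        = 2 / (1 + e^(-2x)) - 1.
// The division is replaced by `_mm_rcp_ps` (a ~12-bit reciprocal estimate)
// refined with one Newton-Raphson step r' = r * (2 - d*r), which roughly
// doubles the number of correct bits.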
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn tanh_slice_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = _mm_set1_ps(2.0);
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let two_x = _mm_mul_ps(two, x);
        let neg_two_x = _mm_sub_ps(zero, two_x);
        let exp_neg = fast_exp_bittrick_sse(neg_two_x);
        let denom = _mm_add_ps(one, exp_neg);
        let sig = {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::_mm_rcp_ps;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::_mm_rcp_ps;
            let rcp = _mm_rcp_ps(denom);
            _mm_mul_ps(rcp, _mm_sub_ps(two, _mm_mul_ps(denom, rcp)))
        };
        let result = _mm_sub_ps(_mm_mul_ps(two, sig), one);
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).tanh();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn tanh_slice_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = _mm256_set1_ps(2.0);
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let two_x = _mm256_mul_ps(two, x);
        let neg_two_x = _mm256_sub_ps(zero, two_x);
        let exp_neg = fast_exp_bittrick_avx(neg_two_x);
        let denom = _mm256_add_ps(one, exp_neg);
        let sig = {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::_mm256_rcp_ps;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::_mm256_rcp_ps;
            let rcp = _mm256_rcp_ps(denom);
            _mm256_mul_ps(rcp, _mm256_sub_ps(two, _mm256_mul_ps(denom, rcp)))
        };
        let result = _mm256_sub_ps(_mm256_mul_ps(two, sig), one);
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        tanh_slice_sse(&input[index..], &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn tanh_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = vdupq_n_f32(2.0);
    let one = vdupq_n_f32(1.0);
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = vld1q_f32(in_ptr.add(index));
        let x1 = vld1q_f32(in_ptr.add(index + 4));
        let x2 = vld1q_f32(in_ptr.add(index + 8));
        let x3 = vld1q_f32(in_ptr.add(index + 12));
        let x4 = vld1q_f32(in_ptr.add(index + 16));
        let x5 = vld1q_f32(in_ptr.add(index + 20));
        let x6 = vld1q_f32(in_ptr.add(index + 24));
        let x7 = vld1q_f32(in_ptr.add(index + 28));

        let e0 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x0)));
        let e1 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x1)));
        let e2 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x2)));
        let e3 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x3)));
        let e4 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x4)));
        let e5 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x5)));
        let e6 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x6)));
        let e7 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x7)));

        vst1q_f32(
            out_ptr.add(index),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e0)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 4),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e1)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 8),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e2)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 12),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e3)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 16),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e4)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 20),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e5)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 24),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e6)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 28),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e7)), one),
        );
        index += 32;
    }

    while index + 4 <= len {
        let x = vld1q_f32(in_ptr.add(index));
        let two_x = vmulq_f32(two, x);
        let neg_two_x = vnegq_f32(two_x);
        let exp_neg = fast_exp_sigmoid_neon(neg_two_x);
        let denom = vaddq_f32(one, exp_neg);
        let result = vsubq_f32(vdivq_f32(two, denom), one);
        vst1q_f32(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).tanh();
        index += 1;
    }
}

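// SiLU (a.k.a. swish): silu(x) = x * sigmoid(x) = x / (1 + e^(-x)), computed
// with the fast NEON exp, eight vectors per iteration to keep the pipeline
// busy, plus a one-vector loop and a scalar tail.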
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn silu_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = vdupq_n_f32(1.0);
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = vld1q_f32(in_ptr.add(index));
        let x1 = vld1q_f32(in_ptr.add(index + 4));
        let x2 = vld1q_f32(in_ptr.add(index + 8));
        let x3 = vld1q_f32(in_ptr.add(index + 12));
        let x4 = vld1q_f32(in_ptr.add(index + 16));
        let x5 = vld1q_f32(in_ptr.add(index + 20));
        let x6 = vld1q_f32(in_ptr.add(index + 24));
        let x7 = vld1q_f32(in_ptr.add(index + 28));

        let e0 = fast_exp_sigmoid_neon(vnegq_f32(x0));
        let e1 = fast_exp_sigmoid_neon(vnegq_f32(x1));
        let e2 = fast_exp_sigmoid_neon(vnegq_f32(x2));
        let e3 = fast_exp_sigmoid_neon(vnegq_f32(x3));
        let e4 = fast_exp_sigmoid_neon(vnegq_f32(x4));
        let e5 = fast_exp_sigmoid_neon(vnegq_f32(x5));
        let e6 = fast_exp_sigmoid_neon(vnegq_f32(x6));
        let e7 = fast_exp_sigmoid_neon(vnegq_f32(x7));

        vst1q_f32(
            out_ptr.add(index),
            vmulq_f32(x0, vdivq_f32(one, vaddq_f32(one, e0))),
        );
        vst1q_f32(
            out_ptr.add(index + 4),
            vmulq_f32(x1, vdivq_f32(one, vaddq_f32(one, e1))),
        );
        vst1q_f32(
            out_ptr.add(index + 8),
            vmulq_f32(x2, vdivq_f32(one, vaddq_f32(one, e2))),
        );
        vst1q_f32(
            out_ptr.add(index + 12),
            vmulq_f32(x3, vdivq_f32(one, vaddq_f32(one, e3))),
        );
        vst1q_f32(
            out_ptr.add(index + 16),
            vmulq_f32(x4, vdivq_f32(one, vaddq_f32(one, e4))),
        );
        vst1q_f32(
            out_ptr.add(index + 20),
            vmulq_f32(x5, vdivq_f32(one, vaddq_f32(one, e5))),
        );
        vst1q_f32(
            out_ptr.add(index + 24),
            vmulq_f32(x6, vdivq_f32(one, vaddq_f32(one, e6))),
        );
        vst1q_f32(
            out_ptr.add(index + 28),
            vmulq_f32(x7, vdivq_f32(one, vaddq_f32(one, e7))),
        );
        index += 32;
    }

    while index + 4 <= len {
        let x = vld1q_f32(in_ptr.add(index));
        let e = fast_exp_sigmoid_neon(vnegq_f32(x));
        let sig = vdivq_f32(one, vaddq_f32(one, e));
        vst1q_f32(out_ptr.add(index), vmulq_f32(x, sig));
        index += 4;
    }

    while index < len {
        let x = *in_ptr.add(index);
        let s = 1.0 / (1.0 + (-x).exp());
        *out_ptr.add(index) = x * s;
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn silu_slice_sse(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 16 <= len {
        let x0 = _mm_loadu_ps(in_ptr.add(index));
        let x1 = _mm_loadu_ps(in_ptr.add(index + 4));
        let x2 = _mm_loadu_ps(in_ptr.add(index + 8));
        let x3 = _mm_loadu_ps(in_ptr.add(index + 12));

        let e0 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x0));
        let e1 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x1));
        let e2 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x2));
        let e3 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x3));

        _mm_storeu_ps(
            out_ptr.add(index),
            _mm_mul_ps(x0, _mm_div_ps(one, _mm_add_ps(one, e0))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 4),
            _mm_mul_ps(x1, _mm_div_ps(one, _mm_add_ps(one, e1))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 8),
            _mm_mul_ps(x2, _mm_div_ps(one, _mm_add_ps(one, e2))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 12),
            _mm_mul_ps(x3, _mm_div_ps(one, _mm_add_ps(one, e3))),
        );

        index += 16;
    }

    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let e = fast_exp_bittrick_sse(_mm_sub_ps(zero, x));
        let sig = _mm_div_ps(one, _mm_add_ps(one, e));
        _mm_storeu_ps(out_ptr.add(index), _mm_mul_ps(x, sig));
        index += 4;
    }

    while index < len {
        let v = *in_ptr.add(index);
        let s = 1.0 / (1.0 + (-v).exp());
        *out_ptr.add(index) = v * s;
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn silu_slice_avx(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm256_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm256_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = _mm256_loadu_ps(in_ptr.add(index));
        let x1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let x2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let x3 = _mm256_loadu_ps(in_ptr.add(index + 24));

        let e0 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x0));
        let e1 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x1));
        let e2 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x2));
        let e3 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x3));

        _mm256_storeu_ps(
            out_ptr.add(index),
            _mm256_mul_ps(x0, _mm256_div_ps(one, _mm256_add_ps(one, e0))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 8),
            _mm256_mul_ps(x1, _mm256_div_ps(one, _mm256_add_ps(one, e1))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 16),
            _mm256_mul_ps(x2, _mm256_div_ps(one, _mm256_add_ps(one, e2))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 24),
            _mm256_mul_ps(x3, _mm256_div_ps(one, _mm256_add_ps(one, e3))),
        );

        index += 32;
    }

    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let e = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x));
        let sig = _mm256_div_ps(one, _mm256_add_ps(one, e));
        _mm256_storeu_ps(out_ptr.add(index), _mm256_mul_ps(x, sig));
        index += 8;
    }

    if index < len {
        silu_slice_sse(&input[index..], &mut output[index..]);
    }
}

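// Horizontal max: accumulate lane-wise maxima in a vector, spill it to a
// small stack buffer, reduce the lanes in scalar code, then fold in any
// tail elements. NEG_INFINITY is the identity for `max`, so it is a safe
// initial accumulator even for short slices.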
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn max_reduce_sse(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm_set1_ps(f32::NEG_INFINITY);

    while index + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(index));
        acc = _mm_max_ps(acc, v);
        index += 4;
    }

    // Horizontal reduce of the 4-lane accumulator, then fold in the tail.
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0].max(buf[1]).max(buf[2]).max(buf[3]);

    while index < len {
        result = result.max(*ptr.add(index));
        index += 1;
    }
    result
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn max_reduce_avx(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);

    while index + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(index));
        acc = _mm256_max_ps(acc, v);
        index += 8;
    }

    // Horizontal reduce of the 8-lane accumulator, then fold in the tail.
    let mut buf = [0.0f32; 8];
    _mm256_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0];
    for i in 1..8 {
        result = result.max(buf[i]);
    }

    while index < len {
        result = result.max(*ptr.add(index));
        index += 1;
    }
    result
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn max_reduce_neon(data: &[f32]) -> f32 {
    use std::arch::aarch64::vmaxvq_f32;

    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = vdupq_n_f32(f32::NEG_INFINITY);

    while index + 4 <= len {
        let v = vld1q_f32(ptr.add(index));
        acc = vmaxq_f32(acc, v);
        index += 4;
    }

    let mut result = vmaxvq_f32(acc);
    while index < len {
        result = result.max(*ptr.add(index));
        index += 1;
    }
    result
}

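/// Horizontal sum of `data` with SSE. Lane-wise accumulation reassociates the
/// additions, so the result can differ from a strict left-to-right scalar sum
/// by ordinary f32 rounding.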
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn add_reduce_sse(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm_setzero_ps();

    while index + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(index));
        acc = _mm_add_ps(acc, v);
        index += 4;
    }

    // Horizontal reduce, then fold in the scalar tail.
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0] + buf[1] + buf[2] + buf[3];

    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn add_reduce_avx(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm256_setzero_ps();

    while index + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(index));
        acc = _mm256_add_ps(acc, v);
        index += 8;
    }

    let mut buf = [0.0f32; 8];
    _mm256_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0] + buf[1] + buf[2] + buf[3] + buf[4] + buf[5] + buf[6] + buf[7];

    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn add_reduce_neon(data: &[f32]) -> f32 {
    use std::arch::aarch64::vaddvq_f32;

    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = vdupq_n_f32(0.0);

    while index + 4 <= len {
        let v = vld1q_f32(ptr.add(index));
        acc = vaddq_f32(acc, v);
        index += 4;
    }

    let mut result = vaddvq_f32(acc);
    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}

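/// Fused multiply-add `acc[i] += a[i] * b[i]` with SSE. Plain SSE has no FMA
/// instruction, so this is a separate multiply and add (one extra rounding
/// step compared with the NEON `vfmaq_f32` path below).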
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn fma_slice_sse(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let av = _mm_loadu_ps(a_ptr.add(index));
        let bv = _mm_loadu_ps(b_ptr.add(index));
        let cv = _mm_loadu_ps(acc_ptr.add(index));
        let result = _mm_add_ps(cv, _mm_mul_ps(av, bv));
        _mm_storeu_ps(acc_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        fma_slice_scalar(&a[index..], &b[index..], &mut acc[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn fma_slice_avx(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        let av = _mm256_loadu_ps(a_ptr.add(index));
        let bv = _mm256_loadu_ps(b_ptr.add(index));
        let cv = _mm256_loadu_ps(acc_ptr.add(index));
        let result = _mm256_add_ps(cv, _mm256_mul_ps(av, bv));
        _mm256_storeu_ps(acc_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        fma_slice_sse(&a[index..], &b[index..], &mut acc[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn fma_slice_neon(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let av = vld1q_f32(a_ptr.add(index));
        let bv = vld1q_f32(b_ptr.add(index));
        let cv = vld1q_f32(acc_ptr.add(index));
        let result = vfmaq_f32(cv, av, bv);
        vst1q_f32(acc_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        fma_slice_scalar(&a[index..], &b[index..], &mut acc[index..]);
    }
}

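/// In-place ReLU (`max(x, 0)`) over `values` with SSE, four lanes at a time.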
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn relu_slice_sse(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let input = _mm_loadu_ps(ptr.add(index));
        let out = _mm_max_ps(input, zero);
        _mm_storeu_ps(ptr.add(index), out);
        index += 4;
    }

    if index < len {
        relu_slice_scalar(&mut values[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn relu_slice_avx(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    // 4x-unrolled main loop: 32 floats per iteration.
    while index + 32 <= len {
        let v0 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index)), zero);
        let v1 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 8)), zero);
        let v2 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 16)), zero);
        let v3 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 24)), zero);
        _mm256_storeu_ps(ptr.add(index), v0);
        _mm256_storeu_ps(ptr.add(index + 8), v1);
        _mm256_storeu_ps(ptr.add(index + 16), v2);
        _mm256_storeu_ps(ptr.add(index + 24), v3);
        index += 32;
    }

    while index + 8 <= len {
        _mm256_storeu_ps(
            ptr.add(index),
            _mm256_max_ps(_mm256_loadu_ps(ptr.add(index)), zero),
        );
        index += 8;
    }

    if index < len {
        relu_slice_sse(&mut values[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn binary_same_shape_sse(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let left = _mm_loadu_ps(left_ptr.add(index));
        let right = _mm_loadu_ps(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => _mm_add_ps(left, right),
            BinaryKind::Sub => _mm_sub_ps(left, right),
            BinaryKind::Mul => _mm_mul_ps(left, right),
        };
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        binary_same_shape_scalar(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn binary_same_shape_avx(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    // The `kind` match is hoisted out of the hot loop so each arm runs a
    // 4x-unrolled body with software prefetch of the next iteration's inputs.
    match kind {
        BinaryKind::Add => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_add_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_add_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_add_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_add_ps(a3, b3));
                index += 32;
            }
        }
        BinaryKind::Sub => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_sub_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_sub_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_sub_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_sub_ps(a3, b3));
                index += 32;
            }
        }
        BinaryKind::Mul => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_mul_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_mul_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_mul_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_mul_ps(a3, b3));
                index += 32;
            }
        }
    }

    // 8-wide cleanup loop shared by all three arms.
    while index + 8 <= len {
        let left = _mm256_loadu_ps(left_ptr.add(index));
        let right = _mm256_loadu_ps(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => _mm256_add_ps(left, right),
            BinaryKind::Sub => _mm256_sub_ps(left, right),
            BinaryKind::Mul => _mm256_mul_ps(left, right),
        };
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        binary_same_shape_sse(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn relu_slice_neon(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = vdupq_n_f32(0.0);
    let mut index = 0usize;

    // 8x-unrolled main loop: 32 floats per iteration.
    while index + 32 <= len {
        let v0 = vmaxq_f32(vld1q_f32(ptr.add(index)), zero);
        let v1 = vmaxq_f32(vld1q_f32(ptr.add(index + 4)), zero);
        let v2 = vmaxq_f32(vld1q_f32(ptr.add(index + 8)), zero);
        let v3 = vmaxq_f32(vld1q_f32(ptr.add(index + 12)), zero);
        let v4 = vmaxq_f32(vld1q_f32(ptr.add(index + 16)), zero);
        let v5 = vmaxq_f32(vld1q_f32(ptr.add(index + 20)), zero);
        let v6 = vmaxq_f32(vld1q_f32(ptr.add(index + 24)), zero);
        let v7 = vmaxq_f32(vld1q_f32(ptr.add(index + 28)), zero);
        vst1q_f32(ptr.add(index), v0);
        vst1q_f32(ptr.add(index + 4), v1);
        vst1q_f32(ptr.add(index + 8), v2);
        vst1q_f32(ptr.add(index + 12), v3);
        vst1q_f32(ptr.add(index + 16), v4);
        vst1q_f32(ptr.add(index + 20), v5);
        vst1q_f32(ptr.add(index + 24), v6);
        vst1q_f32(ptr.add(index + 28), v7);
        index += 32;
    }

    while index + 4 <= len {
        vst1q_f32(ptr.add(index), vmaxq_f32(vld1q_f32(ptr.add(index)), zero));
        index += 4;
    }

    if index < len {
        relu_slice_scalar(&mut values[index..]);
    }
}

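/// Out-of-place ReLU: writes `max(input[i], 0)` into `output` with SSE.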
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn relu_to_slice_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let r = _mm_max_ps(v, zero);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    if index < len {
        relu_to_slice_scalar(&input[index..], &mut output[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn relu_to_slice_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    // 4x-unrolled main loop: 32 floats per iteration.
    while index + 32 <= len {
        let a0 = _mm256_loadu_ps(in_ptr.add(index));
        let a1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let a2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let a3 = _mm256_loadu_ps(in_ptr.add(index + 24));
        _mm256_storeu_ps(out_ptr.add(index), _mm256_max_ps(a0, zero));
        _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_max_ps(a1, zero));
        _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_max_ps(a2, zero));
        _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_max_ps(a3, zero));
        index += 32;
    }

    while index + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(index),
            _mm256_max_ps(_mm256_loadu_ps(in_ptr.add(index)), zero),
        );
        index += 8;
    }

    if index < len {
        relu_to_slice_sse(&input[index..], &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn relu_to_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = vdupq_n_f32(0.0);
    let mut index = 0usize;

    // 8x-unrolled main loop, with loads and stores interleaved to keep more
    // independent operations in flight.
    while index + 32 <= len {
        let a0 = vld1q_f32(in_ptr.add(index));
        let a1 = vld1q_f32(in_ptr.add(index + 4));
        let a2 = vld1q_f32(in_ptr.add(index + 8));
        let a3 = vld1q_f32(in_ptr.add(index + 12));
        vst1q_f32(out_ptr.add(index), vmaxq_f32(a0, zero));
        vst1q_f32(out_ptr.add(index + 4), vmaxq_f32(a1, zero));
        let a4 = vld1q_f32(in_ptr.add(index + 16));
        let a5 = vld1q_f32(in_ptr.add(index + 20));
        vst1q_f32(out_ptr.add(index + 8), vmaxq_f32(a2, zero));
        vst1q_f32(out_ptr.add(index + 12), vmaxq_f32(a3, zero));
        let a6 = vld1q_f32(in_ptr.add(index + 24));
        let a7 = vld1q_f32(in_ptr.add(index + 28));
        vst1q_f32(out_ptr.add(index + 16), vmaxq_f32(a4, zero));
        vst1q_f32(out_ptr.add(index + 20), vmaxq_f32(a5, zero));
        vst1q_f32(out_ptr.add(index + 24), vmaxq_f32(a6, zero));
        vst1q_f32(out_ptr.add(index + 28), vmaxq_f32(a7, zero));
        index += 32;
    }

    while index + 4 <= len {
        vst1q_f32(
            out_ptr.add(index),
            vmaxq_f32(vld1q_f32(in_ptr.add(index)), zero),
        );
        index += 4;
    }

    if index < len {
        relu_to_slice_scalar(&input[index..], &mut output[index..]);
    }
}

// Not compiled on macOS, where same-shape add/sub/mul are routed through the
// vDSP bindings declared at the top of this module.
#[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn binary_same_shape_neon(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let left = vld1q_f32(left_ptr.add(index));
        let right = vld1q_f32(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => vaddq_f32(left, right),
            BinaryKind::Sub => vsubq_f32(left, right),
            BinaryKind::Mul => vmulq_f32(left, right),
        };
        vst1q_f32(out_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        binary_same_shape_scalar(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}

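/// Accumulates one GEMM output row: `out_row[0..n] += left_row[0..k] * right`,
/// where `right` is a row-major `k x n` matrix. Dispatches to AVX, SSE, or
/// NEON at runtime and falls back to the scalar kernel (always under Miri).
///
/// # Safety
///
/// The caller must guarantee that `left_row` is valid for `k` reads, `right`
/// for `k * n` reads, and `out_row` for `n` reads and writes, and that
/// `out_row` does not alias either input.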
#[inline]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn matmul_row_dispatch(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    if cfg!(miri) {
        matmul_row_scalar(left_row, right, out_row, k, n);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            matmul_row_avx(left_row, right, out_row, k, n);
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            matmul_row_sse(left_row, right, out_row, k, n);
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            matmul_row_neon(left_row, right, out_row, k, n);
            return;
        }
    }

    matmul_row_scalar(left_row, right, out_row, k, n);
}

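/// Scalar reference kernel for one output row, manually unrolled four columns
/// at a time so the compiler can still auto-vectorize where possible.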
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn matmul_row_scalar(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val = *left_row.add(p);
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            *out_row.add(col) += a_val * *b_row.add(col);
            *out_row.add(col + 1) += a_val * *b_row.add(col + 1);
            *out_row.add(col + 2) += a_val * *b_row.add(col + 2);
            *out_row.add(col + 3) += a_val * *b_row.add(col + 3);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += a_val * *b_row.add(col);
            col += 1;
        }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn matmul_row_sse(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val = _mm_set1_ps(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            let b_vec = _mm_loadu_ps(b_row.add(col));
            let out_vec = _mm_loadu_ps(out_row.add(col));
            let result = _mm_add_ps(out_vec, _mm_mul_ps(a_val, b_vec));
            _mm_storeu_ps(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn matmul_row_avx(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val_avx = _mm256_set1_ps(*left_row.add(p));
        let a_val_sse = _mm_set1_ps(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 8 <= n {
            let b_vec = _mm256_loadu_ps(b_row.add(col));
            let out_vec = _mm256_loadu_ps(out_row.add(col));
            let result = _mm256_add_ps(out_vec, _mm256_mul_ps(a_val_avx, b_vec));
            _mm256_storeu_ps(out_row.add(col), result);
            col += 8;
        }
        // 4-wide SSE cleanup before the scalar tail.
        while col + 4 <= n {
            let b_vec = _mm_loadu_ps(b_row.add(col));
            let out_vec = _mm_loadu_ps(out_row.add(col));
            let result = _mm_add_ps(out_vec, _mm_mul_ps(a_val_sse, b_vec));
            _mm_storeu_ps(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn matmul_row_neon(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val: float32x4_t = vdupq_n_f32(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            let b_vec = vld1q_f32(b_row.add(col));
            let out_vec = vld1q_f32(out_row.add(col));
            let result = vfmaq_f32(out_vec, a_val, b_vec);
            vst1q_f32(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

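/// Fused three-pass softmax over one row: max-reduce, exponentiate-and-sum,
/// then normalize. Falls back to the scalar path under Miri or when no SIMD
/// feature is detected.
///
/// Illustrative sketch (marked `ignore`; invocation depends on how this
/// module is exposed):
///
/// ```ignore
/// let logits = [1.0f32, 2.0, 3.0];
/// let mut probs = [0.0f32; 3];
/// softmax_row_fused_dispatch(&logits, &mut probs);
/// assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
/// ```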
#[allow(unsafe_code)]
#[inline]
pub fn softmax_row_fused_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) || input.is_empty() {
        softmax_row_fused_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: NEON support was just verified at runtime.
            unsafe {
                softmax_row_fused_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: AVX support was just verified at runtime.
            unsafe {
                softmax_row_fused_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: SSE support was just verified at runtime.
            unsafe {
                softmax_row_fused_sse(input, output);
            }
            return;
        }
    }

    softmax_row_fused_scalar(input, output);
}

fn softmax_row_fused_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    // Pass 1: row maximum for numerical stability.
    let mut max_val = f32::NEG_INFINITY;
    for &v in input {
        max_val = max_val.max(v);
    }

    // Pass 2: shifted exponentials and their sum.
    let mut sum_exp = 0.0f32;
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let e = (v - max_val).exp();
        *o = e;
        sum_exp += e;
    }

    // Pass 3: normalize.
    let inv = 1.0 / sum_exp;
    for o in output.iter_mut() {
        *o *= inv;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn softmax_row_fused_neon(input: &[f32], output: &mut [f32]) {
    use std::arch::aarch64::{vaddvq_f32, vmaxvq_f32};

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = vdupq_n_f32(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 16 <= len {
        let v0 = vld1q_f32(in_ptr.add(i));
        let v1 = vld1q_f32(in_ptr.add(i + 4));
        let v2 = vld1q_f32(in_ptr.add(i + 8));
        let v3 = vld1q_f32(in_ptr.add(i + 12));
        acc_max = vmaxq_f32(acc_max, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        let v = vld1q_f32(in_ptr.add(i));
        acc_max = vmaxq_f32(acc_max, v);
        i += 4;
    }
    let mut max_val = vmaxvq_f32(acc_max);
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: shifted exponentials, stored to `output`, accumulating the sum.
    let off = vdupq_n_f32(max_val);
    let mut acc_sum = vdupq_n_f32(0.0);
    i = 0;
    while i + 16 <= len {
        let v0 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        let v1 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), off));
        let v2 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), off));
        let v3 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), off));
        vst1q_f32(out_ptr.add(i), v0);
        vst1q_f32(out_ptr.add(i + 4), v1);
        vst1q_f32(out_ptr.add(i + 8), v2);
        vst1q_f32(out_ptr.add(i + 12), v3);
        acc_sum = vaddq_f32(acc_sum, vaddq_f32(vaddq_f32(v0, v1), vaddq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        let v = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        vst1q_f32(out_ptr.add(i), v);
        acc_sum = vaddq_f32(acc_sum, v);
        i += 4;
    }
    let mut sum_exp = vaddvq_f32(acc_sum);
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // Pass 3: normalize in place.
    let inv = vdupq_n_f32(1.0 / sum_exp);
    i = 0;
    while i + 16 <= len {
        vst1q_f32(out_ptr.add(i), vmulq_f32(vld1q_f32(out_ptr.add(i)), inv));
        vst1q_f32(
            out_ptr.add(i + 4),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 4)), inv),
        );
        vst1q_f32(
            out_ptr.add(i + 8),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 8)), inv),
        );
        vst1q_f32(
            out_ptr.add(i + 12),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 12)), inv),
        );
        i += 16;
    }
    while i + 4 <= len {
        vst1q_f32(out_ptr.add(i), vmulq_f32(vld1q_f32(out_ptr.add(i)), inv));
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn softmax_row_fused_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = _mm_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 4 <= len {
        acc_max = _mm_max_ps(acc_max, _mm_loadu_ps(in_ptr.add(i)));
        i += 4;
    }
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc_max);
    let mut max_val = buf[0].max(buf[1]).max(buf[2].max(buf[3]));
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: shifted exponentials, stored to `output`, accumulating the sum.
    let off = _mm_set1_ps(max_val);
    let mut acc_sum = _mm_setzero_ps();
    i = 0;
    while i + 4 <= len {
        let v = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off));
        _mm_storeu_ps(out_ptr.add(i), v);
        acc_sum = _mm_add_ps(acc_sum, v);
        i += 4;
    }
    _mm_storeu_ps(buf.as_mut_ptr(), acc_sum);
    let mut sum_exp = buf[0] + buf[1] + buf[2] + buf[3];
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // Pass 3: normalize in place.
    let inv = _mm_set1_ps(1.0 / sum_exp);
    i = 0;
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_mul_ps(_mm_loadu_ps(out_ptr.add(i)), inv),
        );
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn softmax_row_fused_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = _mm256_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 8 <= len {
        acc_max = _mm256_max_ps(acc_max, _mm256_loadu_ps(in_ptr.add(i)));
        i += 8;
    }
    let mut buf8 = [0.0f32; 8];
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_max);
    let mut max_val = buf8[0];
    for &v in &buf8[1..] {
        max_val = max_val.max(v);
    }
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: shifted exponentials, stored to `output`, accumulating the sum.
    let off = _mm256_set1_ps(max_val);
    let mut acc_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let v = fast_exp_avx(_mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), off));
        _mm256_storeu_ps(out_ptr.add(i), v);
        acc_sum = _mm256_add_ps(acc_sum, v);
        i += 8;
    }
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_sum);
    let mut sum_exp: f32 = buf8.iter().sum();
    // 4-wide SSE tail of pass 2.
    let off4 = _mm_set1_ps(max_val);
    while i + 4 <= len {
        let v = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off4));
        _mm_storeu_ps(out_ptr.add(i), v);
        let mut b4 = [0.0f32; 4];
        _mm_storeu_ps(b4.as_mut_ptr(), v);
        sum_exp += b4[0] + b4[1] + b4[2] + b4[3];
        i += 4;
    }
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // Pass 3: normalize in place.
    let inv8 = _mm256_set1_ps(1.0 / sum_exp);
    i = 0;
    while i + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(i),
            _mm256_mul_ps(_mm256_loadu_ps(out_ptr.add(i)), inv8),
        );
        i += 8;
    }
    let inv4 = _mm_set1_ps(1.0 / sum_exp);
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_mul_ps(_mm_loadu_ps(out_ptr.add(i)), inv4),
        );
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

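/// Fused log-softmax over one row:
/// `out[i] = x[i] - (max + ln(sum(exp(x - max))))`, with the same dispatch
/// strategy as `softmax_row_fused_dispatch`.
///
/// Illustrative sketch (marked `ignore`; invocation depends on how this
/// module is exposed):
///
/// ```ignore
/// let logits = [0.0f32, 1.0, 2.0];
/// let mut out = [0.0f32; 3];
/// log_softmax_row_fused_dispatch(&logits, &mut out);
/// assert!(out.iter().all(|&v| v <= 0.0));
/// ```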
#[allow(unsafe_code)]
#[inline]
pub fn log_softmax_row_fused_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) || input.is_empty() {
        log_softmax_row_fused_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: NEON support was just verified at runtime.
            unsafe {
                log_softmax_row_fused_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: AVX support was just verified at runtime.
            unsafe {
                log_softmax_row_fused_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: SSE support was just verified at runtime.
            unsafe {
                log_softmax_row_fused_sse(input, output);
            }
            return;
        }
    }

    log_softmax_row_fused_scalar(input, output);
}

fn log_softmax_row_fused_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    // Pass 1: row maximum for numerical stability.
    let mut max_val = f32::NEG_INFINITY;
    for &v in input {
        max_val = max_val.max(v);
    }

    // Pass 2: sum of shifted exponentials.
    let mut sum_exp = 0.0f32;
    for &v in input {
        sum_exp += (v - max_val).exp();
    }

    // Pass 3: subtract the log-denominator from every element.
    let log_denom = max_val + sum_exp.ln();
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v - log_denom;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn log_softmax_row_fused_neon(input: &[f32], output: &mut [f32]) {
    use std::arch::aarch64::{vaddvq_f32, vmaxvq_f32};

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = vdupq_n_f32(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 16 <= len {
        let v0 = vld1q_f32(in_ptr.add(i));
        let v1 = vld1q_f32(in_ptr.add(i + 4));
        let v2 = vld1q_f32(in_ptr.add(i + 8));
        let v3 = vld1q_f32(in_ptr.add(i + 12));
        acc_max = vmaxq_f32(acc_max, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        acc_max = vmaxq_f32(acc_max, vld1q_f32(in_ptr.add(i)));
        i += 4;
    }
    let mut max_val = vmaxvq_f32(acc_max);
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: sum of shifted exponentials.
    let off = vdupq_n_f32(max_val);
    let mut acc_sum = vdupq_n_f32(0.0);
    i = 0;
    while i + 16 <= len {
        let e0 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        let e1 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), off));
        let e2 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), off));
        let e3 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), off));
        acc_sum = vaddq_f32(acc_sum, vaddq_f32(vaddq_f32(e0, e1), vaddq_f32(e2, e3)));
        i += 16;
    }
    while i + 4 <= len {
        let e = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        acc_sum = vaddq_f32(acc_sum, e);
        i += 4;
    }
    let mut sum_exp = vaddvq_f32(acc_sum);
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // Pass 3: subtract the log-denominator from every element.
    let log_denom = vdupq_n_f32(max_val + sum_exp.ln());
    i = 0;
    while i + 16 <= len {
        vst1q_f32(
            out_ptr.add(i),
            vsubq_f32(vld1q_f32(in_ptr.add(i)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 4),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 8),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 12),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), log_denom),
        );
        i += 16;
    }
    while i + 4 <= len {
        vst1q_f32(
            out_ptr.add(i),
            vsubq_f32(vld1q_f32(in_ptr.add(i)), log_denom),
        );
        i += 4;
    }
    let log_denom_s = max_val + sum_exp.ln();
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_s;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn log_softmax_row_fused_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = _mm_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 4 <= len {
        acc_max = _mm_max_ps(acc_max, _mm_loadu_ps(in_ptr.add(i)));
        i += 4;
    }
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc_max);
    let mut max_val = buf[0].max(buf[1]).max(buf[2].max(buf[3]));
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: sum of shifted exponentials.
    let off = _mm_set1_ps(max_val);
    let mut acc_sum = _mm_setzero_ps();
    i = 0;
    while i + 4 <= len {
        let e = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off));
        acc_sum = _mm_add_ps(acc_sum, e);
        i += 4;
    }
    _mm_storeu_ps(buf.as_mut_ptr(), acc_sum);
    let mut sum_exp = buf[0] + buf[1] + buf[2] + buf[3];
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // Pass 3: subtract the log-denominator.
    let log_denom = _mm_set1_ps(max_val + sum_exp.ln());
    i = 0;
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), log_denom),
        );
        i += 4;
    }
    let log_denom_s = max_val + sum_exp.ln();
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_s;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn log_softmax_row_fused_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = _mm256_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 8 <= len {
        acc_max = _mm256_max_ps(acc_max, _mm256_loadu_ps(in_ptr.add(i)));
        i += 8;
    }
    let mut buf8 = [0.0f32; 8];
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_max);
    let mut max_val = buf8[0];
    for &v in &buf8[1..] {
        max_val = max_val.max(v);
    }
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: sum of shifted exponentials.
    let off = _mm256_set1_ps(max_val);
    let mut acc_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let e = fast_exp_avx(_mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), off));
        acc_sum = _mm256_add_ps(acc_sum, e);
        i += 8;
    }
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_sum);
    let mut sum_exp: f32 = buf8.iter().sum();
    // 4-wide SSE tail of pass 2.
    let off4 = _mm_set1_ps(max_val);
    while i + 4 <= len {
        let e = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off4));
        let mut b4 = [0.0f32; 4];
        _mm_storeu_ps(b4.as_mut_ptr(), e);
        sum_exp += b4[0] + b4[1] + b4[2] + b4[3];
        i += 4;
    }
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // Pass 3: subtract the log-denominator.
    let log_denom_val = max_val + sum_exp.ln();
    let log_denom8 = _mm256_set1_ps(log_denom_val);
    i = 0;
    while i + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(i),
            _mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), log_denom8),
        );
        i += 8;
    }
    let log_denom4 = _mm_set1_ps(log_denom_val);
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), log_denom4),
        );
        i += 4;
    }
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_val;
        i += 1;
    }
}

#[cfg(test)]
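/// Tests compare each SIMD dispatch path against the scalar reference
/// implementations on the host's detected feature set.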
mod tests {
    use super::*;

    fn assert_close(a: &[f32], b: &[f32], tol: f32) {
        assert_eq!(a.len(), b.len(), "length mismatch");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            let d = (x - y).abs();
            assert!(d <= tol, "index {i}: {x} vs {y}, diff={d}, tolerance={tol}");
        }
    }

    #[test]
    fn exp_matches_scalar() {
        let input: Vec<f32> = (-20..=20).map(|i| i as f32 * 0.5).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        exp_slice_dispatch(&input, &mut simd_out);
        exp_slice_scalar(&input, &mut scalar_out);

        for (i, (&s, &r)) in simd_out.iter().zip(scalar_out.iter()).enumerate() {
            // Relative error, falling back to absolute error near zero,
            // since exp spans many orders of magnitude over this range.
            let rel = if r.abs() > 1e-10 {
                (s - r).abs() / r.abs()
            } else {
                (s - r).abs()
            };
            assert!(
                rel < 1e-5,
                "exp mismatch at index {i}: simd={s}, scalar={r}, rel_err={rel}"
            );
        }
    }

    #[test]
    fn sigmoid_dispatch_matches_scalar() {
        let input: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.3).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        sigmoid_slice_dispatch(&input, &mut simd_out);
        sigmoid_slice_dispatch_scalar(&input, &mut scalar_out);

        // Loose tolerance: the SIMD path uses an approximate exp.
        assert_close(&simd_out, &scalar_out, 0.035);
    }

    #[test]
    fn tanh_dispatch_matches_scalar() {
        let input: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.3).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        tanh_slice_dispatch(&input, &mut simd_out);
        tanh_slice_dispatch_scalar(&input, &mut scalar_out);

        assert_close(&simd_out, &scalar_out, 2e-3);
    }

    #[test]
    fn max_reduce_matches_scalar() {
        let data: Vec<f32> = (0..37).map(|i| (i as f32 * 0.7 - 12.0).sin()).collect();
        let simd_result = max_reduce_dispatch(&data);
        let scalar_result = max_reduce_scalar(&data);
        assert!((simd_result - scalar_result).abs() < 1e-6);
    }

    #[test]
    fn max_reduce_empty() {
        assert_eq!(max_reduce_dispatch(&[]), f32::NEG_INFINITY);
    }

    #[test]
    fn add_reduce_matches_scalar() {
        let data: Vec<f32> = (0..37).map(|i| i as f32 * 0.1).collect();
        let simd_result = add_reduce_dispatch(&data);
        let scalar_result = add_reduce_scalar(&data);
        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    #[test]
    fn add_reduce_empty() {
        assert_eq!(add_reduce_dispatch(&[]), 0.0);
    }

    #[test]
    #[allow(unsafe_code)]
    fn fma_matches_scalar() {
        let a: Vec<f32> = (0..33).map(|i| i as f32 * 0.3).collect();
        let b: Vec<f32> = (0..33).map(|i| (i as f32 * 0.7).sin()).collect();
        let mut simd_acc = vec![1.0f32; 33];
        let mut scalar_acc = vec![1.0f32; 33];

        fma_slice_dispatch(&a, &b, &mut simd_acc);
        unsafe { fma_slice_scalar(&a, &b, &mut scalar_acc) };

        assert_close(&simd_acc, &scalar_acc, 1e-5);
    }

    #[test]
    fn sigmoid_dispatch_boundary_values() {
        let input = vec![-100.0, -10.0, 0.0, 10.0, 100.0];
        let mut output = vec![0.0f32; 5];
        sigmoid_slice_dispatch(&input, &mut output);

        // Saturation at the extremes, midpoint at zero.
        assert!(
            output[0] < 0.01,
            "sigmoid(-100) should be near 0: {}",
            output[0]
        );
        assert!(
            (output[2] - 0.5).abs() < 0.01,
            "sigmoid(0) should be near 0.5: {}",
            output[2]
        );
        assert!(
            output[4] > 0.99,
            "sigmoid(100) should be near 1: {}",
            output[4]
        );
    }

    #[test]
    fn tanh_dispatch_boundary_values() {
        let input = vec![-100.0, -1.0, 0.0, 1.0, 100.0];
        let mut output = vec![0.0f32; 5];
        tanh_slice_dispatch(&input, &mut output);

        assert!(
            output[0] < -0.99,
            "tanh(-100) should be near -1: {}",
            output[0]
        );
        assert!(
            (output[2]).abs() < 0.01,
            "tanh(0) should be near 0: {}",
            output[2]
        );
        assert!(
            output[4] > 0.99,
            "tanh(100) should be near 1: {}",
            output[4]
        );
    }
}