Skip to main content

tract_linalg/
x86_64_fma.rs

1use crate::Ops;
2use crate::frame::element_wise::ElementWiseKer;
3use crate::frame::reduce::{MapReduceKer, ReduceKer};
4use crate::x86_64_fma::softmax::x86_64_avx512_softmax2_fastcompact_f16_64n;
5use crate::x86_64_fma::softmax::x86_64_fma_softmax2_fastcompact_f32_32n;
6
7pub mod mmm;
8
9pub mod act;
10pub mod act_f16;
11pub mod act_f16_fp16;
12pub mod by_scalar;
13pub mod erf;
14mod intel;
15pub mod max;
16pub mod panel_extract;
17pub mod rms_norm;
18pub mod softmax;
19
// Runtime CPU-feature probes, one per instruction-set extension. Each is a
// plain `fn() -> bool` pointer wrapping `is_x86_feature_detected!`, so the
// check can be handed around as a value and evaluated later.
// NOTE(review): none of these are referenced in the visible part of this
// file — presumably they are consumed by the submodules; confirm before
// renaming or removing.
const AVX2: fn() -> bool = || is_x86_feature_detected!("avx2");
const FMA: fn() -> bool = || is_x86_feature_detected!("fma");
const AVX512F: fn() -> bool = || is_x86_feature_detected!("avx512f");
// Only compiled in when the build sets the custom `tract_avx512vnni` cfg.
#[cfg(tract_avx512vnni)]
const AVX512VNNI: fn() -> bool = || is_x86_feature_detected!("avx512vnni");
25
// FMA variants of the f32 tanh and sigmoid kernels, declared through the
// `tanh_impl!` / `sigmoid_impl!` macros (defined outside this file). The last
// argument hands the macro the matching runtime feature check.
// NOTE(review): the two `8`s look like nr and alignment (in items), by analogy
// with the nr()=16 remark on the AVX-512 variants just below — confirm against
// the macro definitions.
tanh_impl!(f32, fma_tanh_f32, 8, 8, is_x86_feature_detected!("fma"));
sigmoid_impl!(f32, fma_sigmoid_f32, 8, 8, is_x86_feature_detected!("fma"));

// AVX-512 (zmm, 16-wide) variants. The assembly lives in x86_64/avx512/; the
// main loop handles 64 lanes (4 zmm) per iteration with a 16-lane tail, so
// nr()=16 (any multiple of 16 is safe).
tanh_impl!(f32, avx512_tanh_f32, 16, 16, is_x86_feature_detected!("avx512f"));
sigmoid_impl!(f32, avx512_sigmoid_f32, 16, 16, is_x86_feature_detected!("avx512f"));
34
35fn plug_avx2(_ops: &mut Ops) {}
36
37fn plug_fma(ops: &mut Ops) {
38    panel_extract::plug(ops);
39
40    ops.sigmoid_f32 = Box::new(|| fma_sigmoid_f32::ew());
41    ops.tanh_f32 = Box::new(|| fma_tanh_f32::ew());
42
43    ops.mul_by_scalar_f32 = Box::new(|| by_scalar::x86_64_avx_f32_mul_by_scalar_32n::ew());
44    ops.max_f32 = Box::new(|| max::x86_64_fma_max_f32_32n::red());
45    ops.softmax2_fastcompact_f32 = Box::new(|| x86_64_fma_softmax2_fastcompact_f32_32n::red());
46
47    log::info!("sigmoid_f32, tanh_f32: x86_64/fma activated");
48}
49
50/// On hosts that also support AVX-512_FP16 (Sapphire Rapids / Granite Rapids /
51/// later, and recent Xeon-D / consumer parts), upgrade the f16 element-wise
52/// kernels from the f32-roundtrip implementations in `act_f16.rs` to the
53/// native f16 implementations in `act_f16_fp16.rs` where the native path is
54/// actually faster on this uarch. We benched each op against its f32-roundtrip
55/// equivalent on Sapphire Rapids and only plug in the ones that win:
56///
57///   hardswish_f16:  8.71 → 31.6 Gelem/s  (3.62× native) — plug in
58///   leaky_relu_f16: 9.44 →  5.85 Gelem/s (0.62× native — regression) — keep
59///                   the f32-roundtrip version from act_f16.rs. The native
60///                   kernel exists in act_f16_fp16.rs for future revisits but
61///                   is not wired here.
62fn plug_avx512fp16(ops: &mut Ops) {
63    ops.hardswish_f16 = Box::new(|| act_f16_fp16::x86_64_avx512fp16_hardswish_f16_128n::ew());
64
65    log::info!("hardswish_f16: x86_64/avx512fp16 native activated");
66}
67
68fn plug_avx512f(ops: &mut Ops) {
69    ops.sigmoid_f32 = Box::new(|| avx512_sigmoid_f32::ew());
70    ops.tanh_f32 = Box::new(|| avx512_tanh_f32::ew());
71    ops.hardswish_f32 = Box::new(|| act::x86_64_avx512_hardswish_f32_64n::ew());
72    ops.leaky_relu_f32 = Box::new(|| act::x86_64_avx512_leaky_relu_f32_64n::ew());
73    ops.silu_f32 = Box::new(|| act::x86_64_avx512_silu_f32_16n::ew());
74    ops.gelu_f32 = Box::new(|| act::x86_64_avx512_gelu_f32_16n::ew());
75
76    ops.sigmoid_f16 = Box::new(|| act_f16::x86_64_avx512_sigmoid_f16_16n::ew());
77    ops.tanh_f16 = Box::new(|| act_f16::x86_64_avx512_tanh_f16_16n::ew());
78    ops.hardswish_f16 = Box::new(|| act_f16::x86_64_avx512_hardswish_f16_64n::ew());
79    ops.leaky_relu_f16 = Box::new(|| act_f16::x86_64_avx512_leaky_relu_f16_64n::ew());
80    ops.silu_f16 = Box::new(|| act_f16::x86_64_avx512_silu_f16_16n::ew());
81    ops.gelu_f16 = Box::new(|| act_f16::x86_64_avx512_gelu_f16_16n::ew());
82
83    ops.max_f32 = Box::new(|| max::x86_64_avx512_max_f32_64n::red());
84    ops.softmax2_fastcompact_f32 =
85        Box::new(|| softmax::x86_64_avx512_softmax2_fastcompact_f32_64n::red());
86    ops.softmax2_fastcompact_f16 = Box::new(|| x86_64_avx512_softmax2_fastcompact_f16_64n::red());
87
88    ops.erf_f32 = Box::new(|| erf::x86_64_avx512_erf_f32_64n::ew());
89
90    ops.rms_norm_f32 = Box::new(rms_norm::rms_norm_f32);
91
92    log::info!(
93        "sigmoid_f32, tanh_f32, hardswish_f32, leaky_relu_f32, \
94         silu_f32, gelu_f32, \
95         sigmoid_f16, tanh_f16, hardswish_f16, leaky_relu_f16, \
96         silu_f16, gelu_f16, \
97         max_f32, softmax2_fastcompact_f32, softmax2_fastcompact_f16, erf_f32, \
98         rms_norm_f32: x86_64/avx512f activated"
99    );
100}
101
102pub fn plug(ops: &mut Ops) {
103    mmm::plug(ops);
104    if is_x86_feature_detected!("avx2") {
105        plug_avx2(ops);
106        if is_x86_feature_detected!("fma") {
107            plug_fma(ops);
108            if is_x86_feature_detected!("avx512f") {
109                plug_avx512f(ops);
110                if is_x86_feature_detected!("avx512fp16") {
111                    plug_avx512fp16(ops);
112                }
113            }
114        }
115    }
116}