Skip to main content

tract_linalg/x86_64_fma/
act.rs

// AVX-512 (zmm, 16-wide) element-wise activation kernels.
//
// hardswish and leaky_relu have no FMA predecessor on x86. They mirror the
// aarch64 NEON kernels (arm64simd_hardswish_f32_8n /
// arm64simd_leaky_relu_f32_8n) but use 512-bit zmm registers, processing
// 64 f32 lanes per iteration. silu and gelu are composed at the kernel level
// on top of the AVX-512 sigmoid and tanh kernels. All kernels are validated
// against the generic scalar reference via the *_frame_tests! macros.
6
7// hardswish(x) = x * relu6(x + 3) / 6
8//              = x * max(0, min(6, x + 3)) * (1/6)
9ew_impl_wrap!(
10    f32,
11    x86_64_avx512_hardswish_f32_64n,
12    64,
13    16,
14    (),
15    #[inline(never)]
16    fn run(buf: &mut [f32], _: ()) {
17        debug_assert!(buf.len() % Self::nr() == 0);
18        debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
19        if buf.is_empty() {
20            return;
21        }
22        unsafe { x86_64_avx512_hardswish_f32_64n_run(buf) }
23    }
24);
25
26#[target_feature(enable = "avx512f")]
27unsafe fn x86_64_avx512_hardswish_f32_64n_run(buf: &mut [f32]) {
28    unsafe {
29        let len = buf.len();
30        let ptr = buf.as_ptr();
31        std::arch::asm!("
32            vbroadcastss zmm0, xmm0          // 3.0
33            vbroadcastss zmm1, xmm1          // 6.0
34            vbroadcastss zmm2, xmm2          // 1/6
35            vpxord       zmm3, zmm3, zmm3    // 0.0
36            2:
37                vmovaps zmm4, [{ptr}]
38                vmovaps zmm5, [{ptr} + 64]
39                vmovaps zmm6, [{ptr} + 128]
40                vmovaps zmm7, [{ptr} + 192]
41
42                vaddps  zmm8,  zmm4, zmm0
43                vaddps  zmm9,  zmm5, zmm0
44                vaddps  zmm10, zmm6, zmm0
45                vaddps  zmm11, zmm7, zmm0
46
47                vminps  zmm8,  zmm8,  zmm1
48                vminps  zmm9,  zmm9,  zmm1
49                vminps  zmm10, zmm10, zmm1
50                vminps  zmm11, zmm11, zmm1
51
52                vmaxps  zmm8,  zmm8,  zmm3
53                vmaxps  zmm9,  zmm9,  zmm3
54                vmaxps  zmm10, zmm10, zmm3
55                vmaxps  zmm11, zmm11, zmm3
56
57                vmulps  zmm8,  zmm8,  zmm4
58                vmulps  zmm9,  zmm9,  zmm5
59                vmulps  zmm10, zmm10, zmm6
60                vmulps  zmm11, zmm11, zmm7
61
62                vmulps  zmm8,  zmm8,  zmm2
63                vmulps  zmm9,  zmm9,  zmm2
64                vmulps  zmm10, zmm10, zmm2
65                vmulps  zmm11, zmm11, zmm2
66
67                vmovaps [{ptr}],       zmm8
68                vmovaps [{ptr} + 64],  zmm9
69                vmovaps [{ptr} + 128], zmm10
70                vmovaps [{ptr} + 192], zmm11
71
72                add {ptr}, 256
73                sub {len}, 64
74                jnz 2b
75            ",
76        len = inout(reg) len => _,
77        ptr = inout(reg) ptr => _,
78        inout("xmm0") 3.0f32 => _,
79        inout("xmm1") 6.0f32 => _,
80        inout("xmm2") 1.0f32 / 6.0f32 => _,
81        out("zmm3") _,
82        out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
83        out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
84        );
85    }
86}
87
// Test-only module: instantiates `hardswish_frame_tests!` for the f32
// `x86_64_avx512_hardswish_f32_64n` kernel. The first macro argument is the
// runtime AVX-512F detection expression; presumably the macro uses it to skip
// the tests on CPUs without AVX-512F (macro is defined elsewhere; confirm).
88#[cfg(test)]
89pub mod test_x86_64_avx512_hardswish_f32_64n {
90    use super::*;
91    hardswish_frame_tests!(
92        is_x86_feature_detected!("avx512f"),
93        f32,
94        x86_64_avx512_hardswish_f32_64n
95    );
96}
97
98// leaky_relu(x) = x > 0 ? x : alpha * x
99ew_impl_wrap!(
100    f32,
101    x86_64_avx512_leaky_relu_f32_64n,
102    64,
103    16,
104    f32,
105    #[inline(never)]
106    fn run(buf: &mut [f32], alpha: f32) {
107        debug_assert!(buf.len() % Self::nr() == 0);
108        debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
109        if buf.is_empty() {
110            return;
111        }
112        unsafe { x86_64_avx512_leaky_relu_f32_64n_run(buf, alpha) }
113    }
114);
115
116#[target_feature(enable = "avx512f")]
117unsafe fn x86_64_avx512_leaky_relu_f32_64n_run(buf: &mut [f32], alpha: f32) {
118    unsafe {
119        let len = buf.len();
120        let ptr = buf.as_ptr();
121        std::arch::asm!("
122            vbroadcastss zmm0, xmm0          // alpha
123            vpxord       zmm1, zmm1, zmm1    // 0.0
124            2:
125                vmovaps zmm4, [{ptr}]
126                vmovaps zmm5, [{ptr} + 64]
127                vmovaps zmm6, [{ptr} + 128]
128                vmovaps zmm7, [{ptr} + 192]
129
130                // alpha * x in zmm8..11
131                vmulps  zmm8,  zmm4, zmm0
132                vmulps  zmm9,  zmm5, zmm0
133                vmulps  zmm10, zmm6, zmm0
134                vmulps  zmm11, zmm7, zmm0
135
136                // mask = x > 0
137                vcmpps  k1, zmm4, zmm1, 14
138                vcmpps  k2, zmm5, zmm1, 14
139                vcmpps  k3, zmm6, zmm1, 14
140                vcmpps  k4, zmm7, zmm1, 14
141
142                // where x > 0, overwrite alpha*x with x
143                vmovaps zmm8{{k1}},  zmm4
144                vmovaps zmm9{{k2}},  zmm5
145                vmovaps zmm10{{k3}}, zmm6
146                vmovaps zmm11{{k4}}, zmm7
147
148                vmovaps [{ptr}],       zmm8
149                vmovaps [{ptr} + 64],  zmm9
150                vmovaps [{ptr} + 128], zmm10
151                vmovaps [{ptr} + 192], zmm11
152
153                add {ptr}, 256
154                sub {len}, 64
155                jnz 2b
156            ",
157        len = inout(reg) len => _,
158        ptr = inout(reg) ptr => _,
159        inout("xmm0") alpha => _,
160        out("zmm1") _,
161        out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
162        out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
163        out("k1") _, out("k2") _, out("k3") _, out("k4") _,
164        );
165    }
166}
167
// Test-only module: instantiates `leaky_relu_frame_tests!` for the f32
// `x86_64_avx512_leaky_relu_f32_64n` kernel. The first macro argument is the
// runtime AVX-512F detection expression; presumably the macro uses it to skip
// the tests on CPUs without AVX-512F (macro is defined elsewhere; confirm).
168#[cfg(test)]
169pub mod test_x86_64_avx512_leaky_relu_f32_64n {
170    use super::*;
171    leaky_relu_frame_tests!(
172        is_x86_feature_detected!("avx512f"),
173        f32,
174        x86_64_avx512_leaky_relu_f32_64n
175    );
176}
177
178// SiLU(x) = x * sigmoid(x). Composed at the kernel level (mirrors arm64): save
179// the input chunk, run the AVX-512 sigmoid kernel in place, then multiply back
180// by the saved original. nr() and CHUNK (256) are multiples of 16 so the
181// sigmoid kernel always receives a 64-byte-aligned slice whose length is a
182// multiple of 16.
183ew_impl_wrap!(
184    f32,
185    x86_64_avx512_silu_f32_16n,
186    16,
187    16,
188    (),
189    #[inline(never)]
190    fn run(buf: &mut [f32], _: ()) {
191        debug_assert!(buf.len() % Self::nr() == 0);
192        debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
193        const CHUNK: usize = 256;
194        let mut scratch = [0f32; CHUNK];
195        let mut start = 0;
196        while start < buf.len() {
197            let end = (start + CHUNK).min(buf.len());
198            let chunk = &mut buf[start..end];
199            let n = chunk.len();
200            scratch[..n].copy_from_slice(chunk);
201            super::avx512_sigmoid_f32::run(chunk, ());
202            for i in 0..n {
203                chunk[i] *= scratch[i];
204            }
205            start = end;
206        }
207    }
208);
209
// Test-only module: instantiates `silu_frame_tests!` for the f32
// `x86_64_avx512_silu_f32_16n` kernel. The first macro argument is the
// runtime AVX-512F detection expression; presumably the macro uses it to skip
// the tests on CPUs without AVX-512F (macro is defined elsewhere; confirm).
210#[cfg(test)]
211pub mod test_x86_64_avx512_silu_f32_16n {
212    use super::*;
213    silu_frame_tests!(is_x86_feature_detected!("avx512f"), f32, x86_64_avx512_silu_f32_16n);
214}
215
216// Tanh-form GELU (pow=3) matching tract's GeluApproximate:
217//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
218// Composed at the kernel level (mirrors arm64): save the original x, compute
219// the tanh argument in place, run the AVX-512 tanh kernel, then finish with the
220// 0.5 * x * (1 + tanh) combine.
221ew_impl_wrap!(
222    f32,
223    x86_64_avx512_gelu_f32_16n,
224    16,
225    16,
226    (),
227    #[inline(never)]
228    fn run(buf: &mut [f32], _: ()) {
229        debug_assert!(buf.len() % Self::nr() == 0);
230        debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
231        const SQRT_2_OVER_PI: f32 = 0.7978845608028654;
232        const COEF: f32 = 0.044715;
233        const CHUNK: usize = 256;
234        let mut scratch = [0f32; CHUNK];
235        let mut start = 0;
236        while start < buf.len() {
237            let end = (start + CHUNK).min(buf.len());
238            let chunk = &mut buf[start..end];
239            let n = chunk.len();
240            for i in 0..n {
241                let x = chunk[i];
242                scratch[i] = x;
243                chunk[i] = SQRT_2_OVER_PI * (x + COEF * x * x * x);
244            }
245            super::avx512_tanh_f32::run(chunk, ());
246            for i in 0..n {
247                chunk[i] = 0.5 * scratch[i] * (1.0 + chunk[i]);
248            }
249            start = end;
250        }
251    }
252);
253
// Test-only module: instantiates `gelu_frame_tests!` for the f32
// `x86_64_avx512_gelu_f32_16n` kernel. The first macro argument is the
// runtime AVX-512F detection expression; presumably the macro uses it to skip
// the tests on CPUs without AVX-512F (macro is defined elsewhere; confirm).
254#[cfg(test)]
255pub mod test_x86_64_avx512_gelu_f32_16n {
256    use super::*;
257    gelu_frame_tests!(is_x86_feature_detected!("avx512f"), f32, x86_64_avx512_gelu_f32_16n);
258}