1ew_impl_wrap!(
10 f32,
11 x86_64_avx512_hardswish_f32_64n,
12 64,
13 16,
14 (),
15 #[inline(never)]
16 fn run(buf: &mut [f32], _: ()) {
17 debug_assert!(buf.len() % Self::nr() == 0);
18 debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
19 if buf.is_empty() {
20 return;
21 }
22 unsafe { x86_64_avx512_hardswish_f32_64n_run(buf) }
23 }
24);
25
26#[target_feature(enable = "avx512f")]
27unsafe fn x86_64_avx512_hardswish_f32_64n_run(buf: &mut [f32]) {
28 unsafe {
29 let len = buf.len();
30 let ptr = buf.as_ptr();
31 std::arch::asm!("
32 vbroadcastss zmm0, xmm0 // 3.0
33 vbroadcastss zmm1, xmm1 // 6.0
34 vbroadcastss zmm2, xmm2 // 1/6
35 vpxord zmm3, zmm3, zmm3 // 0.0
36 2:
37 vmovaps zmm4, [{ptr}]
38 vmovaps zmm5, [{ptr} + 64]
39 vmovaps zmm6, [{ptr} + 128]
40 vmovaps zmm7, [{ptr} + 192]
41
42 vaddps zmm8, zmm4, zmm0
43 vaddps zmm9, zmm5, zmm0
44 vaddps zmm10, zmm6, zmm0
45 vaddps zmm11, zmm7, zmm0
46
47 vminps zmm8, zmm8, zmm1
48 vminps zmm9, zmm9, zmm1
49 vminps zmm10, zmm10, zmm1
50 vminps zmm11, zmm11, zmm1
51
52 vmaxps zmm8, zmm8, zmm3
53 vmaxps zmm9, zmm9, zmm3
54 vmaxps zmm10, zmm10, zmm3
55 vmaxps zmm11, zmm11, zmm3
56
57 vmulps zmm8, zmm8, zmm4
58 vmulps zmm9, zmm9, zmm5
59 vmulps zmm10, zmm10, zmm6
60 vmulps zmm11, zmm11, zmm7
61
62 vmulps zmm8, zmm8, zmm2
63 vmulps zmm9, zmm9, zmm2
64 vmulps zmm10, zmm10, zmm2
65 vmulps zmm11, zmm11, zmm2
66
67 vmovaps [{ptr}], zmm8
68 vmovaps [{ptr} + 64], zmm9
69 vmovaps [{ptr} + 128], zmm10
70 vmovaps [{ptr} + 192], zmm11
71
72 add {ptr}, 256
73 sub {len}, 64
74 jnz 2b
75 ",
76 len = inout(reg) len => _,
77 ptr = inout(reg) ptr => _,
78 inout("xmm0") 3.0f32 => _,
79 inout("xmm1") 6.0f32 => _,
80 inout("xmm2") 1.0f32 / 6.0f32 => _,
81 out("zmm3") _,
82 out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
83 out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
84 );
85 }
86}
87
// Hooks the hardswish kernel into the shared frame test-suite. The first macro
// argument is a runtime AVX-512F check, presumably used to skip the tests on
// CPUs without it — confirm in `hardswish_frame_tests!`.
#[cfg(test)]
pub mod test_x86_64_avx512_hardswish_f32_64n {
    // Bring the kernel type and its siblings into scope for the macro.
    use super::*;

    hardswish_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f32,
        x86_64_avx512_hardswish_f32_64n
    );
}
97
98ew_impl_wrap!(
100 f32,
101 x86_64_avx512_leaky_relu_f32_64n,
102 64,
103 16,
104 f32,
105 #[inline(never)]
106 fn run(buf: &mut [f32], alpha: f32) {
107 debug_assert!(buf.len() % Self::nr() == 0);
108 debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
109 if buf.is_empty() {
110 return;
111 }
112 unsafe { x86_64_avx512_leaky_relu_f32_64n_run(buf, alpha) }
113 }
114);
115
116#[target_feature(enable = "avx512f")]
117unsafe fn x86_64_avx512_leaky_relu_f32_64n_run(buf: &mut [f32], alpha: f32) {
118 unsafe {
119 let len = buf.len();
120 let ptr = buf.as_ptr();
121 std::arch::asm!("
122 vbroadcastss zmm0, xmm0 // alpha
123 vpxord zmm1, zmm1, zmm1 // 0.0
124 2:
125 vmovaps zmm4, [{ptr}]
126 vmovaps zmm5, [{ptr} + 64]
127 vmovaps zmm6, [{ptr} + 128]
128 vmovaps zmm7, [{ptr} + 192]
129
130 // alpha * x in zmm8..11
131 vmulps zmm8, zmm4, zmm0
132 vmulps zmm9, zmm5, zmm0
133 vmulps zmm10, zmm6, zmm0
134 vmulps zmm11, zmm7, zmm0
135
136 // mask = x > 0
137 vcmpps k1, zmm4, zmm1, 14
138 vcmpps k2, zmm5, zmm1, 14
139 vcmpps k3, zmm6, zmm1, 14
140 vcmpps k4, zmm7, zmm1, 14
141
142 // where x > 0, overwrite alpha*x with x
143 vmovaps zmm8{{k1}}, zmm4
144 vmovaps zmm9{{k2}}, zmm5
145 vmovaps zmm10{{k3}}, zmm6
146 vmovaps zmm11{{k4}}, zmm7
147
148 vmovaps [{ptr}], zmm8
149 vmovaps [{ptr} + 64], zmm9
150 vmovaps [{ptr} + 128], zmm10
151 vmovaps [{ptr} + 192], zmm11
152
153 add {ptr}, 256
154 sub {len}, 64
155 jnz 2b
156 ",
157 len = inout(reg) len => _,
158 ptr = inout(reg) ptr => _,
159 inout("xmm0") alpha => _,
160 out("zmm1") _,
161 out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
162 out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
163 out("k1") _, out("k2") _, out("k3") _, out("k4") _,
164 );
165 }
166}
167
// Hooks the leaky-ReLU kernel into the shared frame test-suite. The first macro
// argument is a runtime AVX-512F check, presumably used to skip the tests on
// CPUs without it — confirm in `leaky_relu_frame_tests!`.
#[cfg(test)]
pub mod test_x86_64_avx512_leaky_relu_f32_64n {
    // Bring the kernel type and its siblings into scope for the macro.
    use super::*;

    leaky_relu_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f32,
        x86_64_avx512_leaky_relu_f32_64n
    );
}
177
178ew_impl_wrap!(
184 f32,
185 x86_64_avx512_silu_f32_16n,
186 16,
187 16,
188 (),
189 #[inline(never)]
190 fn run(buf: &mut [f32], _: ()) {
191 debug_assert!(buf.len() % Self::nr() == 0);
192 debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
193 const CHUNK: usize = 256;
194 let mut scratch = [0f32; CHUNK];
195 let mut start = 0;
196 while start < buf.len() {
197 let end = (start + CHUNK).min(buf.len());
198 let chunk = &mut buf[start..end];
199 let n = chunk.len();
200 scratch[..n].copy_from_slice(chunk);
201 super::avx512_sigmoid_f32::run(chunk, ());
202 for i in 0..n {
203 chunk[i] *= scratch[i];
204 }
205 start = end;
206 }
207 }
208);
209
// Hooks the SiLU kernel into the shared frame test-suite. The first macro
// argument is a runtime AVX-512F check, presumably used to skip the tests on
// CPUs without it — confirm in `silu_frame_tests!`.
#[cfg(test)]
pub mod test_x86_64_avx512_silu_f32_16n {
    // Bring the kernel type and its siblings into scope for the macro.
    use super::*;

    silu_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f32,
        x86_64_avx512_silu_f32_16n
    );
}
215
216ew_impl_wrap!(
222 f32,
223 x86_64_avx512_gelu_f32_16n,
224 16,
225 16,
226 (),
227 #[inline(never)]
228 fn run(buf: &mut [f32], _: ()) {
229 debug_assert!(buf.len() % Self::nr() == 0);
230 debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
231 const SQRT_2_OVER_PI: f32 = 0.7978845608028654;
232 const COEF: f32 = 0.044715;
233 const CHUNK: usize = 256;
234 let mut scratch = [0f32; CHUNK];
235 let mut start = 0;
236 while start < buf.len() {
237 let end = (start + CHUNK).min(buf.len());
238 let chunk = &mut buf[start..end];
239 let n = chunk.len();
240 for i in 0..n {
241 let x = chunk[i];
242 scratch[i] = x;
243 chunk[i] = SQRT_2_OVER_PI * (x + COEF * x * x * x);
244 }
245 super::avx512_tanh_f32::run(chunk, ());
246 for i in 0..n {
247 chunk[i] = 0.5 * scratch[i] * (1.0 + chunk[i]);
248 }
249 start = end;
250 }
251 }
252);
253
// Hooks the GELU kernel into the shared frame test-suite. The first macro
// argument is a runtime AVX-512F check, presumably used to skip the tests on
// CPUs without it — confirm in `gelu_frame_tests!`.
#[cfg(test)]
pub mod test_x86_64_avx512_gelu_f32_16n {
    // Bring the kernel type and its siblings into scope for the macro.
    use super::*;

    gelu_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f32,
        x86_64_avx512_gelu_f32_16n
    );
}