1use crate::num_traits::Zero;
2use tract_data::internal::f16;
3
4map_reduce_impl_wrap!(
5 f32,
6 x86_64_fma_softmax2_fastcompact_f32_32n,
7 32,
8 8,
9 f32,
10 f32::MIN,
11 0f32,
12 #[inline(never)]
13 fn run(buf: &mut [f32], max: f32) -> f32 {
14 assert!(buf.len() % 32 == 0);
15 assert!(buf.len() > 0);
16 unsafe { x86_64_fma_softmax2_fastcompact_f32_32n_run(buf, max) }
17 },
18 #[inline(never)]
19 fn reduce_two(a: f32, b: f32) -> f32 {
20 a + b
21 }
22);
23
/// Fast-compact softmax numerator pass for f32 buffers, AVX + FMA version.
///
/// Overwrites every element `x` of `buf` in place with an approximation of `exp(x - max)`.
/// Returns the sum of the written values.
///
/// The approximation is a bit trick:
/// 1. Compute `OFFSET + SLOPE * (x - max)` in f32, using one fused multiply-add.
/// 2. Clamp the result below at 0.
/// 3. Truncate it to an i32.
/// 4. Reinterpret that integer's bit pattern as an f32.
///
/// `SLOPE = 2^23 / ln(2)` and `OFFSET = 0x3f800000 - C`. `0x3f800000` is the bit pattern
/// of `1.0f32`. `C` is a correction constant, presumably tuned to reduce the
/// approximation error — NOTE(review): confirm the origin of `C`.
///
/// Processes 32 lanes (four 256-bit registers) per loop iteration. Only AVX1 and FMA
/// instructions are used (no AVX2), so the kernel is valid on every CPU reporting
/// AVX + FMA.
///
/// # Safety
///
/// - The CPU must support AVX and FMA.
/// - `buf.len()` must be a non-zero multiple of 32: the loop only exits when the remaining
///   count reaches exactly zero.
/// - `buf` must be 32-byte aligned: the aligned `vmovaps` loads/stores fault otherwise.
#[target_feature(enable = "avx,fma")]
unsafe fn x86_64_fma_softmax2_fastcompact_f32_32n_run(buf: &mut [f32], max: f32) -> f32 {
    // SAFETY: the caller upholds the CPU-feature, length and alignment contract above.
    unsafe {
        let len = buf.len();
        // The asm stores results back through this pointer, so it must come from
        // `as_mut_ptr`. Writing through a pointer obtained from `as_ptr` is not allowed.
        let ptr = buf.as_mut_ptr();
        let mut acc = 0f32;
        const MLN2: f32 = 0.6931471805f32;
        const A: f32 = 8388608.0f32;
        const B: f32 = 1065353216.0f32;
        const C: f32 = 60801.0f32;
        const SLOPE: f32 = A / MLN2;
        const OFFSET: f32 = B - C;
        // Register plan:
        // - ymm0..ymm3: lane-wise running sums, seeded from `acc`.
        // - ymm12 = 0.0, ymm13 = max, ymm14 = SLOPE, ymm15 = OFFSET (all broadcast).
        // Scalar → ymm broadcasts use `vpermilps` + `vinsertf128` rather than the
        // register-source `vbroadcastss`, which is an AVX2-only encoding. The zero uses
        // `vxorps` rather than `vpxor ymm`, which is also AVX2-only.
        std::arch::asm!("
            vpermilps xmm0, xmm0, 0
            vinsertf128 ymm0, ymm0, xmm0, 1
            vmovaps ymm1, ymm0
            vmovaps ymm2, ymm0
            vmovaps ymm3, ymm0

            vxorps ymm12, ymm12, ymm12
            vpermilps xmm13, xmm13, 0
            vinsertf128 ymm13, ymm13, xmm13, 1
            vpermilps xmm14, xmm14, 0
            vinsertf128 ymm14, ymm14, xmm14, 1
            vpermilps xmm15, xmm15, 0
            vinsertf128 ymm15, ymm15, xmm15, 1
        2:
            vmovaps ymm4, [{ptr}]
            vmovaps ymm5, [{ptr} + 32]
            vmovaps ymm6, [{ptr} + 64]
            vmovaps ymm7, [{ptr} + 96]

            vsubps ymm4, ymm4, ymm13
            vsubps ymm5, ymm5, ymm13
            vsubps ymm6, ymm6, ymm13
            vsubps ymm7, ymm7, ymm13

            vmovaps ymm8, ymm15
            vmovaps ymm9, ymm15
            vmovaps ymm10, ymm15
            vmovaps ymm11, ymm15

            vfmadd231ps ymm8, ymm4, ymm14
            vfmadd231ps ymm9, ymm5, ymm14
            vfmadd231ps ymm10, ymm6, ymm14
            vfmadd231ps ymm11, ymm7, ymm14

            vmaxps ymm8, ymm8, ymm12
            vmaxps ymm9, ymm9, ymm12
            vmaxps ymm10, ymm10, ymm12
            vmaxps ymm11, ymm11, ymm12

            vcvttps2dq ymm8, ymm8
            vcvttps2dq ymm9, ymm9
            vcvttps2dq ymm10, ymm10
            vcvttps2dq ymm11, ymm11

            vmovaps [{ptr}] , ymm8
            vmovaps [{ptr} + 32], ymm9
            vmovaps [{ptr} + 64], ymm10
            vmovaps [{ptr} + 96], ymm11

            vaddps ymm0, ymm0, ymm8
            vaddps ymm1, ymm1, ymm9
            vaddps ymm2, ymm2, ymm10
            vaddps ymm3, ymm3, ymm11

            add {ptr}, 128
            sub {len}, 32
            jnz 2b

            vaddps ymm0, ymm0, ymm1
            vaddps ymm2, ymm2, ymm3
            vaddps ymm0, ymm0, ymm2
            vperm2f128 ymm1, ymm0, ymm0, 1
            vaddps xmm0, xmm0, xmm1
            vpermilps xmm1, xmm0, 2 + (3 << 2)
            vaddps xmm0, xmm0, xmm1
            vpermilps xmm1, xmm0, 1
            vaddps xmm0, xmm0, xmm1
            ",
            len = inout(reg) len => _,
            ptr = inout(reg) ptr => _,
            inout("ymm0") acc,
            out("ymm1") _, out("ymm2") _, out("ymm3") _,
            out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
            out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
            out("ymm12") _,
            inout("ymm13") max => _,
            inout("ymm14") SLOPE => _,
            inout("ymm15") OFFSET => _,
        );
        acc
    }
}
115
#[cfg(test)]
mod test_x86_64_fma_softmax2_fastcompact_f32_32n {
    // Instantiates the shared softmax frame test-suite for the AVX+FMA f32 kernel.
    // The first macro argument is the runtime CPU-feature condition handed to the suite.
    use super::*;

    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("fma"),
        f32,
        x86_64_fma_softmax2_fastcompact_f32_32n
    );
}
125
126map_reduce_impl_wrap!(
133 f32,
134 x86_64_avx512_softmax2_fastcompact_f32_64n,
135 64,
136 16,
137 f32,
138 f32::MIN,
139 0f32,
140 #[inline(never)]
141 fn run(buf: &mut [f32], max: f32) -> f32 {
142 assert!(buf.len() % 64 == 0);
143 assert!(buf.len() > 0);
144 unsafe { x86_64_avx512_softmax2_fastcompact_f32_64n_run(buf, max) }
145 },
146 #[inline(never)]
147 fn reduce_two(a: f32, b: f32) -> f32 {
148 a + b
149 }
150);
151
152#[target_feature(enable = "avx512f")]
153unsafe fn x86_64_avx512_softmax2_fastcompact_f32_64n_run(buf: &mut [f32], max: f32) -> f32 {
154 unsafe {
155 let len = buf.len();
156 let ptr = buf.as_ptr();
157 let mut acc = 0f32;
158 const MLN2: f32 = 0.6931471805f32;
159 const A: f32 = 8388608.0f32;
160 const B: f32 = 1065353216.0f32;
161 const C: f32 = 60801.0f32;
162 const SLOPE: f32 = A / MLN2;
163 const OFFSET: f32 = B - C;
164 std::arch::asm!("
165 vbroadcastss zmm0, xmm0
166 vmovaps zmm1, zmm0
167 vmovaps zmm2, zmm0
168 vmovaps zmm3, zmm0
169
170 vpxord zmm28, zmm28, zmm28 // zero (clamp floor)
171 vbroadcastss zmm29, xmm29 // max
172 vbroadcastss zmm30, xmm30 // slope
173 vbroadcastss zmm31, xmm31 // offset
174 2:
175 vmovaps zmm4, [{ptr}]
176 vmovaps zmm5, [{ptr} + 64]
177 vmovaps zmm6, [{ptr} + 128]
178 vmovaps zmm7, [{ptr} + 192]
179
180 vsubps zmm4, zmm4, zmm29
181 vsubps zmm5, zmm5, zmm29
182 vsubps zmm6, zmm6, zmm29
183 vsubps zmm7, zmm7, zmm29
184
185 vmovaps zmm8, zmm31
186 vmovaps zmm9, zmm31
187 vmovaps zmm10, zmm31
188 vmovaps zmm11, zmm31
189
190 vfmadd231ps zmm8, zmm4, zmm30
191 vfmadd231ps zmm9, zmm5, zmm30
192 vfmadd231ps zmm10, zmm6, zmm30
193 vfmadd231ps zmm11, zmm7, zmm30
194
195 vmaxps zmm8, zmm8, zmm28
196 vmaxps zmm9, zmm9, zmm28
197 vmaxps zmm10, zmm10, zmm28
198 vmaxps zmm11, zmm11, zmm28
199
200 vcvttps2dq zmm8, zmm8
201 vcvttps2dq zmm9, zmm9
202 vcvttps2dq zmm10, zmm10
203 vcvttps2dq zmm11, zmm11
204
205 vmovaps [{ptr}] , zmm8
206 vmovaps [{ptr} + 64] , zmm9
207 vmovaps [{ptr} + 128], zmm10
208 vmovaps [{ptr} + 192], zmm11
209
210 vaddps zmm0, zmm0, zmm8
211 vaddps zmm1, zmm1, zmm9
212 vaddps zmm2, zmm2, zmm10
213 vaddps zmm3, zmm3, zmm11
214
215 add {ptr}, 256
216 sub {len}, 64
217 jnz 2b
218
219 vaddps zmm0, zmm0, zmm1
220 vaddps zmm2, zmm2, zmm3
221 vaddps zmm0, zmm0, zmm2 // zmm0 holds 16 partial sums
222 vextractf64x4 ymm1, zmm0, 1 // upper 256 bits (8xf32) -> ymm1 (avx512f)
223 vaddps ymm0, ymm0, ymm1 // ymm0 holds 8 values
224 vextractf128 xmm1, ymm0, 1 // upper 4xf32 -> xmm1
225 vaddps xmm0, xmm0, xmm1 // xmm0 holds 4 values
226 vpermilps xmm1, xmm0, 2 + (3 << 2)
227 vaddps xmm0, xmm0, xmm1 // xmm0 holds 2 values
228 vpermilps xmm1, xmm0, 1
229 vaddps xmm0, xmm0, xmm1
230 ",
231 len = inout(reg) len => _,
232 ptr = inout(reg) ptr => _,
233 inout("zmm0") acc,
234 out("zmm1") _, out("zmm2") _, out("zmm3") _,
235 out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
236 out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
237 out("zmm28") _,
238 inout("zmm29") max => _,
239 inout("zmm30") SLOPE => _,
240 inout("zmm31") OFFSET => _,
241 );
242 acc
243 }
244}
245
#[cfg(test)]
mod test_x86_64_avx512_softmax2_fastcompact_f32_64n {
    // Instantiates the shared softmax frame test-suite for the AVX-512F f32 kernel.
    // The first macro argument is the runtime CPU-feature condition handed to the suite.
    use super::*;

    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f32,
        x86_64_avx512_softmax2_fastcompact_f32_64n
    );
}
255
256map_reduce_impl_wrap!(
265 f16,
266 x86_64_avx512_softmax2_fastcompact_f16_64n,
267 64,
268 32,
269 f16,
270 f16::MIN,
271 f16::zero(),
272 #[inline(never)]
273 fn run(buf: &mut [f16], max: f16) -> f16 {
274 assert!(buf.len() % 64 == 0);
275 assert!(buf.len() > 0);
276 unsafe { x86_64_avx512_softmax2_fastcompact_f16_64n_run(buf, max) }
277 },
278 #[inline(never)]
279 fn reduce_two(a: f16, b: f16) -> f16 {
280 a + b
281 }
282);
283
284#[target_feature(enable = "avx512f")]
285unsafe fn x86_64_avx512_softmax2_fastcompact_f16_64n_run(
286 buf: &mut [tract_data::internal::f16],
287 max: tract_data::internal::f16,
288) -> tract_data::internal::f16 {
289 unsafe {
290 let len = buf.len();
291 let ptr = buf.as_ptr();
292 let max_f32: f32 = max.to_f32();
293 let mut acc = 0f32;
294 const MLN2: f32 = 0.6931471805f32;
295 const A: f32 = 8388608.0f32;
296 const B: f32 = 1065353216.0f32;
297 const C: f32 = 60801.0f32;
298 const SLOPE: f32 = A / MLN2;
299 const OFFSET: f32 = B - C;
300 std::arch::asm!("
301 vbroadcastss zmm0, xmm0
302 vmovaps zmm1, zmm0
303 vmovaps zmm2, zmm0
304 vmovaps zmm3, zmm0
305
306 vpxord zmm28, zmm28, zmm28 // 0 (clamp floor)
307 vbroadcastss zmm29, xmm29 // max (f32)
308 vbroadcastss zmm30, xmm30 // slope
309 vbroadcastss zmm31, xmm31 // offset
310 2:
311 // load 4 ymm of f16 (16 f16 per ymm = 32 bytes), convert to zmm f32
312 vcvtph2ps zmm4, [{ptr}]
313 vcvtph2ps zmm5, [{ptr} + 32]
314 vcvtph2ps zmm6, [{ptr} + 64]
315 vcvtph2ps zmm7, [{ptr} + 96]
316
317 // subtract max
318 vsubps zmm4, zmm4, zmm29
319 vsubps zmm5, zmm5, zmm29
320 vsubps zmm6, zmm6, zmm29
321 vsubps zmm7, zmm7, zmm29
322
323 // OFFSET + SLOPE * (x - max)
324 vmovaps zmm8, zmm31
325 vmovaps zmm9, zmm31
326 vmovaps zmm10, zmm31
327 vmovaps zmm11, zmm31
328 vfmadd231ps zmm8, zmm4, zmm30
329 vfmadd231ps zmm9, zmm5, zmm30
330 vfmadd231ps zmm10, zmm6, zmm30
331 vfmadd231ps zmm11, zmm7, zmm30
332
333 // max(0, ...)
334 vmaxps zmm8, zmm8, zmm28
335 vmaxps zmm9, zmm9, zmm28
336 vmaxps zmm10, zmm10, zmm28
337 vmaxps zmm11, zmm11, zmm28
338
339 // fast-compact-exp trick: the truncated i32 has the same bit
340 // pattern as the f32 ~exp(x), so accumulate AS f32 + store as f16
341 vcvttps2dq zmm8, zmm8
342 vcvttps2dq zmm9, zmm9
343 vcvttps2dq zmm10, zmm10
344 vcvttps2dq zmm11, zmm11
345
346 vaddps zmm0, zmm0, zmm8
347 vaddps zmm1, zmm1, zmm9
348 vaddps zmm2, zmm2, zmm10
349 vaddps zmm3, zmm3, zmm11
350
351 // convert back to f16 and store (4th operand 0 = round to nearest even)
352 vcvtps2ph [{ptr}], zmm8, 0
353 vcvtps2ph [{ptr} + 32], zmm9, 0
354 vcvtps2ph [{ptr} + 64], zmm10, 0
355 vcvtps2ph [{ptr} + 96], zmm11, 0
356
357 add {ptr}, 128
358 sub {len}, 64
359 jnz 2b
360
361 // reduce zmm0..3 to a scalar f32 in xmm0
362 vaddps zmm0, zmm0, zmm1
363 vaddps zmm2, zmm2, zmm3
364 vaddps zmm0, zmm0, zmm2
365 vextractf64x4 ymm1, zmm0, 1
366 vaddps ymm0, ymm0, ymm1
367 vextractf128 xmm1, ymm0, 1
368 vaddps xmm0, xmm0, xmm1
369 vpermilps xmm1, xmm0, 2 + (3 << 2)
370 vaddps xmm0, xmm0, xmm1
371 vpermilps xmm1, xmm0, 1
372 vaddps xmm0, xmm0, xmm1
373 ",
374 len = inout(reg) len => _,
375 ptr = inout(reg) ptr => _,
376 inout("zmm0") acc,
377 out("zmm1") _, out("zmm2") _, out("zmm3") _,
378 out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
379 out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
380 out("zmm28") _,
381 inout("zmm29") max_f32 => _,
382 inout("zmm30") SLOPE => _,
383 inout("zmm31") OFFSET => _,
384 );
385 f16::from_f32(acc)
386 }
387}
388
#[cfg(test)]
mod test_x86_64_avx512_softmax2_fastcompact_f16_64n {
    // Instantiates the shared softmax frame test-suite for the AVX-512F f16 kernel.
    // The first macro argument is the runtime CPU-feature condition handed to the suite.
    use super::*;
    use tract_data::internal::f16;

    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f16,
        x86_64_avx512_softmax2_fastcompact_f16_64n
    );
}