Skip to main content

tract_linalg/x86_64_fma/
softmax.rs

1use crate::num_traits::Zero;
2use tract_data::internal::f16;
3
// Registers the AVX/FMA f32 "fast compact" softmax kernel through
// `map_reduce_impl_wrap!` (macro defined outside this file).
// By analogy with the comments on the AVX-512 registrations later in this
// file, `32` is presumably `nr` (items consumed per kernel iteration) and `8`
// the required alignment expressed in items (8 x f32 = 32 bytes) — TODO confirm
// against the macro definition. `f32::MIN` / `0f32` look like the neutral
// values for the max and sum sides respectively — TODO confirm.
map_reduce_impl_wrap!(
    f32,
    x86_64_fma_softmax2_fastcompact_f32_32n,
    32,
    8,
    f32,
    f32::MIN,
    0f32,
    #[inline(never)]
    fn run(buf: &mut [f32], max: f32) -> f32 {
        // The asm loop in `..._run` consumes 32 floats per iteration and only
        // terminates when its counter hits exactly zero, so ragged or empty
        // buffers must be rejected before entering it.
        assert!(buf.len() % 32 == 0);
        assert!(buf.len() > 0);
        // SAFETY: length preconditions are asserted above. CPU feature support
        // and 32-byte alignment of `buf` are presumably guaranteed by the
        // wrapper macro / runtime dispatch — NOTE(review): confirm.
        unsafe { x86_64_fma_softmax2_fastcompact_f32_32n_run(buf, max) }
    },
    #[inline(never)]
    fn reduce_two(a: f32, b: f32) -> f32 {
        // Combines two partial sums returned by `run`.
        a + b
    }
);
23
/// In-place "fast compact" exponential over `buf`, returning the sum of the
/// values written back.
///
/// For every element `x` the asm computes `t = OFFSET + SLOPE * (x - max)`
/// with a fused multiply-add, clamps `t` at zero (`vmaxps`), truncates it to a
/// 32-bit integer (`vcvttps2dq`) and stores that integer's bit pattern back
/// into `buf`, where it is read as an `f32` approximation of `exp(x - max)`.
/// The same bit patterns are accumulated as `f32` in four independent vector
/// accumulators, which are reduced horizontally to the scalar return value.
///
/// # Safety
///
/// * The CPU must support every instruction used by the asm body.
///   NOTE(review): `vpxor ymm, ymm, ymm` and the register-source form
///   `vbroadcastss ymm, xmm` are AVX2 encodings, while only `avx,fma` is
///   enabled here — confirm that the runtime gate also implies AVX2.
/// * `buf.len()` must be a non-zero multiple of 32: the loop counter is
///   decremented by 32 and the loop only exits when it reaches exactly zero.
/// * `buf` must be 32-byte aligned: all loads and stores use `vmovaps`.
#[target_feature(enable = "avx,fma")]
unsafe fn x86_64_fma_softmax2_fastcompact_f32_32n_run(buf: &mut [f32], max: f32) -> f32 {
    // ln(2), used to rescale the natural exponent into a base-2 one.
    const MLN2: f32 = 0.6931471805f32;
    // 2^23: one unit of the f32 exponent field, seen as an integer.
    const A: f32 = 8388608.0f32;
    // 0x3F80_0000: the integer bit pattern of 1.0f32.
    const B: f32 = 1065353216.0f32;
    // Correction term subtracted from `B`; presumably an error-minimising
    // tuning constant — its derivation is not shown in this file.
    const C: f32 = 60801.0f32;
    const SLOPE: f32 = A / MLN2;
    const OFFSET: f32 = B - C;

    let len = buf.len();
    // The asm stores through this pointer, so it must be derived from the
    // mutable borrow (`as_mut_ptr`), not from a shared reborrow (`as_ptr`).
    let ptr = buf.as_mut_ptr();
    let mut acc = 0f32;
    // SAFETY: the caller upholds the contract in `# Safety`; with it, the loop
    // touches exactly `len` floats starting at `ptr`, all inside `buf`.
    // Every vector register written by the asm is declared as an output/clobber.
    unsafe {
        std::arch::asm!("
            vbroadcastss ymm0, xmm0
            vmovaps ymm1, ymm0
            vmovaps ymm2, ymm0
            vmovaps ymm3, ymm0

            vpxor   ymm12, ymm12, ymm12
            vbroadcastss ymm13, xmm13
            vbroadcastss ymm14, xmm14
            vbroadcastss ymm15, xmm15
            2:
                vmovaps ymm4, [{ptr}]
                vmovaps ymm5, [{ptr} + 32]
                vmovaps ymm6, [{ptr} + 64]
                vmovaps ymm7, [{ptr} + 96]

                vsubps ymm4, ymm4, ymm13
                vsubps ymm5, ymm5, ymm13
                vsubps ymm6, ymm6, ymm13
                vsubps ymm7, ymm7, ymm13

                vmovaps ymm8, ymm15
                vmovaps ymm9, ymm15
                vmovaps ymm10, ymm15
                vmovaps ymm11, ymm15

                vfmadd231ps ymm8, ymm4, ymm14
                vfmadd231ps ymm9, ymm5, ymm14
                vfmadd231ps ymm10, ymm6, ymm14
                vfmadd231ps ymm11, ymm7, ymm14

                vmaxps ymm8, ymm8, ymm12
                vmaxps ymm9, ymm9, ymm12
                vmaxps ymm10, ymm10, ymm12
                vmaxps ymm11, ymm11, ymm12

                vcvttps2dq ymm8, ymm8
                vcvttps2dq ymm9, ymm9
                vcvttps2dq ymm10, ymm10
                vcvttps2dq ymm11, ymm11

                vmovaps [{ptr}]     , ymm8
                vmovaps [{ptr} + 32], ymm9
                vmovaps [{ptr} + 64], ymm10
                vmovaps [{ptr} + 96], ymm11

                vaddps ymm0, ymm0, ymm8
                vaddps ymm1, ymm1, ymm9
                vaddps ymm2, ymm2, ymm10
                vaddps ymm3, ymm3, ymm11

                add {ptr}, 128
                sub {len}, 32
                jnz 2b

            vaddps ymm0, ymm0, ymm1
            vaddps ymm2, ymm2, ymm3
            vaddps ymm0, ymm0, ymm2
            vperm2f128 ymm1, ymm0, ymm0, 1
            vaddps xmm0, xmm0, xmm1
            vpermilps xmm1, xmm0, 2 + (3 << 2)
            vaddps xmm0, xmm0, xmm1
            vpermilps xmm1, xmm0, 1
            vaddps xmm0, xmm0, xmm1
            ",
        // Both are advanced/consumed by the loop; their final values are unused.
        len = inout(reg) len => _,
        ptr = inout(reg) ptr => _,
        // Lane 0 of ymm0 carries the initial accumulator in and the sum out.
        inout("ymm0") acc,
        out("ymm1") _, out("ymm2") _, out("ymm3") _,
        out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
        out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
        out("ymm12") _,
        // Scalars arrive in lane 0 and are broadcast in-register by the asm.
        inout("ymm13") max => _,
        inout("ymm14") SLOPE => _,
        inout("ymm15") OFFSET => _,
        );
    }
    acc
}
115
// Instantiates the shared `softmax_l2_frame_tests!` suite (defined elsewhere in
// the crate) for the FMA f32 kernel. The first argument is a runtime CPU
// feature probe; presumably the generated tests are skipped when it is false —
// TODO confirm against the macro.
#[cfg(test)]
mod test_x86_64_fma_softmax2_fastcompact_f32_32n {
    use super::*;
    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("fma"),
        f32,
        x86_64_fma_softmax2_fastcompact_f32_32n
    );
}
125
126// AVX-512 version: processes 64 f32 per loop iteration (4 zmm registers of 16
127// lanes each). Same fast-compact-exp algorithm as the FMA kernel above:
128//   y = bitcast_u32(max(0, SLOPE*(x-max) + OFFSET))   (via vcvttps2dq)
129// then writes y back and accumulates sum(y). Runtime-gated on avx512f (see
130// x86_64_fma.rs::plug_avx512f); non-AVX512 CPUs keep using the FMA kernel.
131// nr=64, 64-byte (16xf32) alignment.
// Macro arguments, in order: element type, kernel name, nr (64 items per
// iteration), alignment in items (16 x f32 = 64 bytes), then a second type and
// two values that look like the accumulator type and the neutral elements for
// max (`f32::MIN`) and sum (`0f32`) — TODO confirm against the macro definition.
map_reduce_impl_wrap!(
    f32,
    x86_64_avx512_softmax2_fastcompact_f32_64n,
    64,
    16,
    f32,
    f32::MIN,
    0f32,
    #[inline(never)]
    fn run(buf: &mut [f32], max: f32) -> f32 {
        // The asm loop in `..._run` consumes 64 floats per iteration and only
        // terminates when its counter hits exactly zero, so ragged or empty
        // buffers must be rejected before entering it.
        assert!(buf.len() % 64 == 0);
        assert!(buf.len() > 0);
        // SAFETY: length preconditions are asserted above. AVX-512F support and
        // 64-byte alignment of `buf` are presumably guaranteed by the wrapper
        // macro / runtime dispatch — NOTE(review): confirm.
        unsafe { x86_64_avx512_softmax2_fastcompact_f32_64n_run(buf, max) }
    },
    #[inline(never)]
    fn reduce_two(a: f32, b: f32) -> f32 {
        // Combines two partial sums returned by `run`.
        a + b
    }
);
151
152#[target_feature(enable = "avx512f")]
153unsafe fn x86_64_avx512_softmax2_fastcompact_f32_64n_run(buf: &mut [f32], max: f32) -> f32 {
154    unsafe {
155        let len = buf.len();
156        let ptr = buf.as_ptr();
157        let mut acc = 0f32;
158        const MLN2: f32 = 0.6931471805f32;
159        const A: f32 = 8388608.0f32;
160        const B: f32 = 1065353216.0f32;
161        const C: f32 = 60801.0f32;
162        const SLOPE: f32 = A / MLN2;
163        const OFFSET: f32 = B - C;
164        std::arch::asm!("
165            vbroadcastss zmm0, xmm0
166            vmovaps zmm1, zmm0
167            vmovaps zmm2, zmm0
168            vmovaps zmm3, zmm0
169
170            vpxord  zmm28, zmm28, zmm28          // zero (clamp floor)
171            vbroadcastss zmm29, xmm29            // max
172            vbroadcastss zmm30, xmm30            // slope
173            vbroadcastss zmm31, xmm31            // offset
174            2:
175                vmovaps zmm4, [{ptr}]
176                vmovaps zmm5, [{ptr} + 64]
177                vmovaps zmm6, [{ptr} + 128]
178                vmovaps zmm7, [{ptr} + 192]
179
180                vsubps zmm4, zmm4, zmm29
181                vsubps zmm5, zmm5, zmm29
182                vsubps zmm6, zmm6, zmm29
183                vsubps zmm7, zmm7, zmm29
184
185                vmovaps zmm8, zmm31
186                vmovaps zmm9, zmm31
187                vmovaps zmm10, zmm31
188                vmovaps zmm11, zmm31
189
190                vfmadd231ps zmm8, zmm4, zmm30
191                vfmadd231ps zmm9, zmm5, zmm30
192                vfmadd231ps zmm10, zmm6, zmm30
193                vfmadd231ps zmm11, zmm7, zmm30
194
195                vmaxps zmm8, zmm8, zmm28
196                vmaxps zmm9, zmm9, zmm28
197                vmaxps zmm10, zmm10, zmm28
198                vmaxps zmm11, zmm11, zmm28
199
200                vcvttps2dq zmm8, zmm8
201                vcvttps2dq zmm9, zmm9
202                vcvttps2dq zmm10, zmm10
203                vcvttps2dq zmm11, zmm11
204
205                vmovaps [{ptr}]      , zmm8
206                vmovaps [{ptr} + 64] , zmm9
207                vmovaps [{ptr} + 128], zmm10
208                vmovaps [{ptr} + 192], zmm11
209
210                vaddps zmm0, zmm0, zmm8
211                vaddps zmm1, zmm1, zmm9
212                vaddps zmm2, zmm2, zmm10
213                vaddps zmm3, zmm3, zmm11
214
215                add {ptr}, 256
216                sub {len}, 64
217                jnz 2b
218
219            vaddps zmm0, zmm0, zmm1
220            vaddps zmm2, zmm2, zmm3
221            vaddps zmm0, zmm0, zmm2             // zmm0 holds 16 partial sums
222            vextractf64x4 ymm1, zmm0, 1         // upper 256 bits (8xf32) -> ymm1 (avx512f)
223            vaddps ymm0, ymm0, ymm1            // ymm0 holds 8 values
224            vextractf128 xmm1, ymm0, 1          // upper 4xf32 -> xmm1
225            vaddps xmm0, xmm0, xmm1            // xmm0 holds 4 values
226            vpermilps xmm1, xmm0, 2 + (3 << 2)
227            vaddps xmm0, xmm0, xmm1            // xmm0 holds 2 values
228            vpermilps xmm1, xmm0, 1
229            vaddps xmm0, xmm0, xmm1
230            ",
231        len = inout(reg) len => _,
232        ptr = inout(reg) ptr => _,
233        inout("zmm0") acc,
234        out("zmm1") _, out("zmm2") _, out("zmm3") _,
235        out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
236        out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
237        out("zmm28") _,
238        inout("zmm29") max => _,
239        inout("zmm30") SLOPE => _,
240        inout("zmm31") OFFSET => _,
241        );
242        acc
243    }
244}
245
// Instantiates the shared `softmax_l2_frame_tests!` suite (defined elsewhere in
// the crate) for the AVX-512 f32 kernel. The first argument is a runtime CPU
// feature probe; presumably the generated tests are skipped when it is false —
// TODO confirm against the macro.
#[cfg(test)]
mod test_x86_64_avx512_softmax2_fastcompact_f32_64n {
    use super::*;
    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f32,
        x86_64_avx512_softmax2_fastcompact_f32_64n
    );
}
255
256// AVX-512 f16 softmax_l2: same fast-compact-exp algorithm as the FMA f32
257// kernel, with f16 <-> f32 conversion at the IO boundary. Each loop iteration
258// handles 64 f16 (128 bytes) through 4× (ymm f16 load -> vcvtph2ps -> zmm f32
259// compute -> vcvttps2dq -> vcvtps2ph -> ymm f16 store). The sum is accumulated
260// in f32 across the loop (higher precision than the generic HSoftMaxL2 which
261// accumulates in f16) and cast to f16 at return; the SuperApproximate test
262// tolerance covers the precision delta.
263// nr=64 (multiple of 4 ymm f16 loads); alignment_items=32 (64-byte aligned).
// Macro arguments, in order: element type, kernel name, nr (64 items per
// iteration), alignment in items (32 x f16 = 64 bytes), then a second type and
// two values that look like the accumulator type and the neutral elements for
// max (`f16::MIN`) and sum (`f16::zero()`) — TODO confirm against the macro
// definition.
map_reduce_impl_wrap!(
    f16,
    x86_64_avx512_softmax2_fastcompact_f16_64n,
    64,
    32,
    f16,
    f16::MIN,
    f16::zero(),
    #[inline(never)]
    fn run(buf: &mut [f16], max: f16) -> f16 {
        // The asm loop in `..._run` consumes 64 halves per iteration and only
        // terminates when its counter hits exactly zero, so ragged or empty
        // buffers must be rejected before entering it.
        assert!(buf.len() % 64 == 0);
        assert!(buf.len() > 0);
        // SAFETY: length preconditions are asserted above. AVX-512F support is
        // presumably guaranteed by the runtime dispatch — NOTE(review): confirm.
        unsafe { x86_64_avx512_softmax2_fastcompact_f16_64n_run(buf, max) }
    },
    #[inline(never)]
    fn reduce_two(a: f16, b: f16) -> f16 {
        // Combines two partial sums returned by `run`; note this addition is
        // performed in f16, unlike the f32 accumulation inside the kernel.
        a + b
    }
);
283
284#[target_feature(enable = "avx512f")]
285unsafe fn x86_64_avx512_softmax2_fastcompact_f16_64n_run(
286    buf: &mut [tract_data::internal::f16],
287    max: tract_data::internal::f16,
288) -> tract_data::internal::f16 {
289    unsafe {
290        let len = buf.len();
291        let ptr = buf.as_ptr();
292        let max_f32: f32 = max.to_f32();
293        let mut acc = 0f32;
294        const MLN2: f32 = 0.6931471805f32;
295        const A: f32 = 8388608.0f32;
296        const B: f32 = 1065353216.0f32;
297        const C: f32 = 60801.0f32;
298        const SLOPE: f32 = A / MLN2;
299        const OFFSET: f32 = B - C;
300        std::arch::asm!("
301            vbroadcastss zmm0, xmm0
302            vmovaps zmm1, zmm0
303            vmovaps zmm2, zmm0
304            vmovaps zmm3, zmm0
305
306            vpxord       zmm28, zmm28, zmm28      // 0 (clamp floor)
307            vbroadcastss zmm29, xmm29             // max (f32)
308            vbroadcastss zmm30, xmm30             // slope
309            vbroadcastss zmm31, xmm31             // offset
310            2:
311                // load 4 ymm of f16 (16 f16 per ymm = 32 bytes), convert to zmm f32
312                vcvtph2ps zmm4, [{ptr}]
313                vcvtph2ps zmm5, [{ptr} + 32]
314                vcvtph2ps zmm6, [{ptr} + 64]
315                vcvtph2ps zmm7, [{ptr} + 96]
316
317                // subtract max
318                vsubps zmm4, zmm4, zmm29
319                vsubps zmm5, zmm5, zmm29
320                vsubps zmm6, zmm6, zmm29
321                vsubps zmm7, zmm7, zmm29
322
323                // OFFSET + SLOPE * (x - max)
324                vmovaps zmm8,  zmm31
325                vmovaps zmm9,  zmm31
326                vmovaps zmm10, zmm31
327                vmovaps zmm11, zmm31
328                vfmadd231ps zmm8,  zmm4, zmm30
329                vfmadd231ps zmm9,  zmm5, zmm30
330                vfmadd231ps zmm10, zmm6, zmm30
331                vfmadd231ps zmm11, zmm7, zmm30
332
333                // max(0, ...)
334                vmaxps zmm8,  zmm8,  zmm28
335                vmaxps zmm9,  zmm9,  zmm28
336                vmaxps zmm10, zmm10, zmm28
337                vmaxps zmm11, zmm11, zmm28
338
339                // fast-compact-exp trick: the truncated i32 has the same bit
340                // pattern as the f32 ~exp(x), so accumulate AS f32 + store as f16
341                vcvttps2dq zmm8,  zmm8
342                vcvttps2dq zmm9,  zmm9
343                vcvttps2dq zmm10, zmm10
344                vcvttps2dq zmm11, zmm11
345
346                vaddps zmm0, zmm0, zmm8
347                vaddps zmm1, zmm1, zmm9
348                vaddps zmm2, zmm2, zmm10
349                vaddps zmm3, zmm3, zmm11
350
351                // convert back to f16 and store (4th operand 0 = round to nearest even)
352                vcvtps2ph [{ptr}],      zmm8,  0
353                vcvtps2ph [{ptr} + 32], zmm9,  0
354                vcvtps2ph [{ptr} + 64], zmm10, 0
355                vcvtps2ph [{ptr} + 96], zmm11, 0
356
357                add {ptr}, 128
358                sub {len}, 64
359                jnz 2b
360
361            // reduce zmm0..3 to a scalar f32 in xmm0
362            vaddps zmm0, zmm0, zmm1
363            vaddps zmm2, zmm2, zmm3
364            vaddps zmm0, zmm0, zmm2
365            vextractf64x4 ymm1, zmm0, 1
366            vaddps ymm0, ymm0, ymm1
367            vextractf128 xmm1, ymm0, 1
368            vaddps xmm0, xmm0, xmm1
369            vpermilps xmm1, xmm0, 2 + (3 << 2)
370            vaddps xmm0, xmm0, xmm1
371            vpermilps xmm1, xmm0, 1
372            vaddps xmm0, xmm0, xmm1
373            ",
374        len = inout(reg) len => _,
375        ptr = inout(reg) ptr => _,
376        inout("zmm0") acc,
377        out("zmm1") _, out("zmm2") _, out("zmm3") _,
378        out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
379        out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
380        out("zmm28") _,
381        inout("zmm29") max_f32 => _,
382        inout("zmm30") SLOPE => _,
383        inout("zmm31") OFFSET => _,
384        );
385        f16::from_f32(acc)
386    }
387}
388
// Instantiates the shared `softmax_l2_frame_tests!` suite (defined elsewhere in
// the crate) for the AVX-512 f16 kernel. The first argument is a runtime CPU
// feature probe; presumably the generated tests are skipped when it is false —
// TODO confirm against the macro.
#[cfg(test)]
mod test_x86_64_avx512_softmax2_fastcompact_f16_64n {
    use super::*;
    // Explicit import of the half-precision type used as the element type below.
    use tract_data::internal::f16;
    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f16,
        x86_64_avx512_softmax2_fastcompact_f16_64n
    );
}
398}