Skip to main content

tract_linalg/x86_64_fma/
softmax.rs

// Declares the `x86_64_fma_softmax2_fastcompact_f32_32n` kernel through the
// project's `map_reduce_impl_wrap!` macro.
// NOTE(review): the macro is defined outside this file; the meaning of the
// positional arguments below is inferred from their values — confirm against
// the macro definition. Presumably: element type, kernel name, items per
// iteration (32), alignment in items (8 x f32 = 32 bytes), parameter type,
// map neutral (f32::MIN), reduce neutral (0).
map_reduce_impl_wrap!(
    f32,
    x86_64_fma_softmax2_fastcompact_f32_32n,
    32,
    8,
    f32,
    f32::MIN,
    0f32,
    // Checked entry point: rejects empty buffers and lengths that are not a
    // multiple of 32, then hands over to the hand-written assembly kernel.
    // The kernel loop consumes exactly 32 floats per iteration and only stops
    // when its counter reaches zero, hence both assertions.
    #[inline(never)]
    fn run(buf: &mut [f32], max: f32) -> f32 {
        assert!(buf.len() % 32 == 0);
        assert!(buf.len() > 0);
        unsafe { x86_64_fma_softmax2_fastcompact_f32_32n_run(buf, max) }
    },
    // Combines two partial results of `run` by adding them.
    #[inline(never)]
    fn reduce_two(a: f32, b: f32) -> f32 {
        a + b
    }
);
20
/// In-place "fast exp and sum" kernel, processing 32 `f32` values per iteration.
///
/// Every element `x` of `buf` is overwritten with an approximation of
/// `exp(x - max)`, and the sum of all the values written is returned.
///
/// The approximation builds the IEEE-754 bit pattern of the result directly:
/// `SLOPE * (x - max) + OFFSET` is computed in floating point, clamped to be
/// non-negative, truncated to a 32-bit integer, and that integer is stored as
/// the bits of the output float. `A` is 2^23 (one unit of the exponent field),
/// `B` is the bit pattern of `1.0f32` read as an integer, and `C` is a small
/// correction term subtracted from it. Inputs far below `max` clamp to the
/// integer 0, i.e. exactly `0.0`.
///
/// # Safety
///
/// * The CPU must support AVX and FMA.
/// * `buf.len()` must be a non-zero multiple of 32: the loop subtracts 32 from
///   the remaining length and exits only when it reaches exactly zero.
/// * `buf` must start on a 32-byte boundary: all loads and stores use the
///   aligned `vmovaps` form.
#[target_feature(enable = "avx,fma")]
unsafe fn x86_64_fma_softmax2_fastcompact_f32_32n_run(buf: &mut [f32], max: f32) -> f32 {
    const MLN2: f32 = 0.6931471805f32;
    const A: f32 = 8388608.0f32;
    const B: f32 = 1065353216.0f32;
    const C: f32 = 60801.0f32;
    const SLOPE: f32 = A / MLN2;
    const OFFSET: f32 = B - C;

    let len = buf.len();
    // The assembly stores results back through this pointer, so it has to be
    // derived from the mutable borrow (`as_mut_ptr`), not from `as_ptr`.
    let ptr = buf.as_mut_ptr();
    debug_assert!(len > 0 && len % 32 == 0, "length must be a non-zero multiple of 32");
    debug_assert!(ptr as usize % 32 == 0, "buffer must be 32-byte aligned");

    let mut acc = 0f32;
    // Register roles inside the assembly:
    //   ymm0..ymm3   four independent running sums (seeded from `acc`)
    //   ymm4..ymm7   loaded inputs, then `x - max`
    //   ymm8..ymm11  OFFSET + SLOPE * (x - max), clamped, converted to integers
    //   ymm12        zero (lower clamp), ymm13 max, ymm14 SLOPE, ymm15 OFFSET
    // Scalars are splatted with `vpermilps` + `vinsertf128` and the zero is made
    // with `vxorps`, so that only AVX/FMA encodings are used, matching the
    // `target_feature` declaration (the register-source `vbroadcastss` and the
    // 256-bit `vpxor` are AVX2 encodings).
    //
    // SAFETY: the caller guarantees AVX+FMA support, a 32-byte aligned buffer and
    // a non-zero length that is a multiple of 32, so every 128-byte step reads and
    // writes inside `buf` and the counter hits zero exactly. All vector registers
    // touched are declared as outputs/clobbers below.
    unsafe {
        std::arch::asm!("
            vpermilps xmm0, xmm0, 0
            vinsertf128 ymm0, ymm0, xmm0, 1
            vmovaps ymm1, ymm0
            vmovaps ymm2, ymm0
            vmovaps ymm3, ymm0

            vxorps  ymm12, ymm12, ymm12
            vpermilps xmm13, xmm13, 0
            vinsertf128 ymm13, ymm13, xmm13, 1
            vpermilps xmm14, xmm14, 0
            vinsertf128 ymm14, ymm14, xmm14, 1
            vpermilps xmm15, xmm15, 0
            vinsertf128 ymm15, ymm15, xmm15, 1
            2:
                vmovaps ymm4, [{ptr}]
                vmovaps ymm5, [{ptr} + 32]
                vmovaps ymm6, [{ptr} + 64]
                vmovaps ymm7, [{ptr} + 96]

                vsubps ymm4, ymm4, ymm13
                vsubps ymm5, ymm5, ymm13
                vsubps ymm6, ymm6, ymm13
                vsubps ymm7, ymm7, ymm13

                vmovaps ymm8, ymm15
                vmovaps ymm9, ymm15
                vmovaps ymm10, ymm15
                vmovaps ymm11, ymm15

                vfmadd231ps ymm8, ymm4, ymm14
                vfmadd231ps ymm9, ymm5, ymm14
                vfmadd231ps ymm10, ymm6, ymm14
                vfmadd231ps ymm11, ymm7, ymm14

                vmaxps ymm8, ymm8, ymm12
                vmaxps ymm9, ymm9, ymm12
                vmaxps ymm10, ymm10, ymm12
                vmaxps ymm11, ymm11, ymm12

                vcvttps2dq ymm8, ymm8
                vcvttps2dq ymm9, ymm9
                vcvttps2dq ymm10, ymm10
                vcvttps2dq ymm11, ymm11

                vmovaps [{ptr}]     , ymm8
                vmovaps [{ptr} + 32], ymm9
                vmovaps [{ptr} + 64], ymm10
                vmovaps [{ptr} + 96], ymm11

                vaddps ymm0, ymm0, ymm8
                vaddps ymm1, ymm1, ymm9
                vaddps ymm2, ymm2, ymm10
                vaddps ymm3, ymm3, ymm11

                add {ptr}, 128
                sub {len}, 32
                jnz 2b

            vaddps ymm0, ymm0, ymm1
            vaddps ymm2, ymm2, ymm3
            vaddps ymm0, ymm0, ymm2
            vperm2f128 ymm1, ymm0, ymm0, 1
            vaddps xmm0, xmm0, xmm1
            vpermilps xmm1, xmm0, 2 + (3 << 2)
            vaddps xmm0, xmm0, xmm1
            vpermilps xmm1, xmm0, 1
            vaddps xmm0, xmm0, xmm1
            ",
        len = inout(reg) len => _,
        ptr = inout(reg) ptr => _,
        inout("ymm0") acc,
        out("ymm1") _, out("ymm2") _, out("ymm3") _,
        out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
        out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
        out("ymm12") _,
        inout("ymm13") max => _,
        inout("ymm14") SLOPE => _,
        inout("ymm15") OFFSET => _,
        );
    }
    // The horizontal reduction above leaves the total in the low lane of xmm0.
    acc
}
112
// Test module for the `x86_64_fma_softmax2_fastcompact_f32_32n` kernel. The
// test bodies come from the project's `softmax_l2_frame_tests!` macro, which
// receives a runtime CPU-feature probe, the element type and the kernel name.
// Presumably the probe gates whether the tests actually run on the current
// machine — confirm against the macro definition.
// NOTE(review): only "fma" is probed here; verify that this covers every
// instruction-set extension the kernel's assembly relies on.
#[cfg(test)]
mod test_x86_64_fma_softmax2_fastcompact_f32_32n {
    use super::*;
    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("fma"),
        f32,
        x86_64_fma_softmax2_fastcompact_f32_32n
    );
}