// tract_linalg: x86_64_fma/softmax.rs

1map_reduce_impl_wrap!(
2    f32,
3    x86_64_fma_softmax2_fastcompact_f32_32n,
4    32,
5    8,
6    f32,
7    f32::MIN,
8    0f32,
9    #[inline(never)]
10    fn run(buf: &mut [f32], max: f32) -> f32 {
11        assert!(buf.len() % 32 == 0);
12        assert!(buf.len() > 0);
13        unsafe { x86_64_fma_softmax2_fastcompact_f32_32n_run(buf, max) }
14    },
15    #[inline(never)]
16    fn reduce_two(a: f32, b: f32) -> f32 {
17        a + b
18    }
19);
20
/// Approximates `exp(x - max)` for every element `x` of `buf`, in place, and
/// returns the sum of the approximated values.
///
/// The exponential uses the bit-level trick described by Schraudolph (1999):
/// `SLOPE * (x - max) + OFFSET` is clamped at zero, truncated to an `i32`, and
/// that integer's bit pattern is reinterpreted as an `f32`. With `A = 2^23`,
/// `B = 127 * 2^23` (the bit pattern of `1.0f32`) and the correction constant
/// `C`, the result lands within a few percent of the true exponential. Inputs
/// far below `max` clamp to `+0.0`, so outputs are never negative.
///
/// Only AVX and FMA instructions are used, matching the `target_feature` list.
/// The register-source form of `vbroadcastss` and the 256-bit `vpxor` are
/// AVX2 instructions and are deliberately avoided. They raise `#UD` on CPUs
/// that have FMA but not AVX2 (e.g. AMD Piledriver).
///
/// # Safety
///
/// * The CPU must support AVX and FMA.
/// * `buf.len()` must be a non-zero multiple of 32. The loop handles 32 lanes
///   per iteration and stops only when its down-counter reaches exactly zero,
///   so any other length runs past the end of `buf`.
/// * `buf` must be 32-byte aligned, because `vmovaps` faults on unaligned
///   256-bit memory operands.
#[target_feature(enable = "avx,fma")]
unsafe fn x86_64_fma_softmax2_fastcompact_f32_32n_run(buf: &mut [f32], max: f32) -> f32 {
    let len = buf.len();
    // `as_mut_ptr`, not `as_ptr`: the kernel stores its results through this
    // pointer, and writing through a pointer obtained from `as_ptr` is UB.
    let ptr = buf.as_mut_ptr();
    debug_assert!(len != 0 && len % 32 == 0, "len must be a non-zero multiple of 32");
    debug_assert!((ptr as usize) % 32 == 0, "buf must be 32-byte aligned");
    let mut acc = 0f32;
    const MLN2: f32 = 0.6931471805f32; // ln(2)
    const A: f32 = 8388608.0f32; // 2^23: one unit of the f32 exponent field
    const B: f32 = 1065353216.0f32; // 127 * 2^23: bit pattern of 1.0f32
    const C: f32 = 60801.0f32; // Schraudolph's error-reducing correction
    const SLOPE: f32 = A / MLN2;
    const OFFSET: f32 = B - C;
    // Register roles:
    //   ymm0..ymm3   four independent 8-lane partial sums (seeded from `acc`)
    //   ymm4..ymm7   loaded lanes, then `x - max`
    //   ymm8..ymm11  exp approximations (integer bit patterns read as f32)
    //   ymm12        all zeros, for the clamp
    //   ymm13..ymm15 `max`, `SLOPE`, `OFFSET` splatted across all lanes
    // Scalar operands arrive in lane 0 of an xmm register. `vpermilps` copies
    // lane 0 across the low 128 bits and `vinsertf128` duplicates them into the
    // high 128 bits; both are AVX1 instructions.
    std::arch::asm!("
            vpermilps    xmm0, xmm0, 0
            vinsertf128  ymm0, ymm0, xmm0, 1
            vmovaps      ymm1, ymm0
            vmovaps      ymm2, ymm0
            vmovaps      ymm3, ymm0

            vxorps       ymm12, ymm12, ymm12
            vpermilps    xmm13, xmm13, 0
            vinsertf128  ymm13, ymm13, xmm13, 1
            vpermilps    xmm14, xmm14, 0
            vinsertf128  ymm14, ymm14, xmm14, 1
            vpermilps    xmm15, xmm15, 0
            vinsertf128  ymm15, ymm15, xmm15, 1
            2:
                vmovaps ymm4, [{ptr}]
                vmovaps ymm5, [{ptr} + 32]
                vmovaps ymm6, [{ptr} + 64]
                vmovaps ymm7, [{ptr} + 96]

                vsubps ymm4, ymm4, ymm13
                vsubps ymm5, ymm5, ymm13
                vsubps ymm6, ymm6, ymm13
                vsubps ymm7, ymm7, ymm13

                vmovaps ymm8, ymm15
                vmovaps ymm9, ymm15
                vmovaps ymm10, ymm15
                vmovaps ymm11, ymm15

                vfmadd231ps ymm8, ymm4, ymm14
                vfmadd231ps ymm9, ymm5, ymm14
                vfmadd231ps ymm10, ymm6, ymm14
                vfmadd231ps ymm11, ymm7, ymm14

                vmaxps ymm8, ymm8, ymm12
                vmaxps ymm9, ymm9, ymm12
                vmaxps ymm10, ymm10, ymm12
                vmaxps ymm11, ymm11, ymm12

                vcvttps2dq ymm8, ymm8
                vcvttps2dq ymm9, ymm9
                vcvttps2dq ymm10, ymm10
                vcvttps2dq ymm11, ymm11

                vmovaps [{ptr}]     , ymm8
                vmovaps [{ptr} + 32], ymm9
                vmovaps [{ptr} + 64], ymm10
                vmovaps [{ptr} + 96], ymm11

                vaddps ymm0, ymm0, ymm8
                vaddps ymm1, ymm1, ymm9
                vaddps ymm2, ymm2, ymm10
                vaddps ymm3, ymm3, ymm11

                add {ptr}, 128
                sub {len}, 32
                jnz 2b

            vaddps ymm0, ymm0, ymm1
            vaddps ymm2, ymm2, ymm3
            vaddps ymm0, ymm0, ymm2
            vperm2f128 ymm1, ymm0, ymm0, 1
            vaddps xmm0, xmm0, xmm1
            vpermilps xmm1, xmm0, 2 + (3 << 2)
            vaddps xmm0, xmm0, xmm1
            vpermilps xmm1, xmm0, 1
            vaddps xmm0, xmm0, xmm1
            ",
    len = inout(reg) len => _,
    ptr = inout(reg) ptr => _,
    inout("ymm0") acc,
    out("ymm1") _, out("ymm2") _, out("ymm3") _,
    out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
    out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
    out("ymm12") _,
    inout("ymm13") max => _,
    inout("ymm14") SLOPE => _,
    inout("ymm15") OFFSET => _,
    options(nostack),
    );
    acc
}
// Instantiates the crate's shared softmax test suite (`softmax_l2_frame_tests!`)
// for this kernel with element type `f32`. The first argument is a runtime
// CPU-feature check; presumably the macro uses it to skip the tests on CPUs
// without FMA -- NOTE(review): confirm in the macro definition.
#[cfg(test)]
mod test_x86_64_fma_softmax2_fastcompact_f32_32n {
    use super::*;
    crate::softmax_l2_frame_tests!(is_x86_feature_detected!("fma"), f32, x86_64_fma_softmax2_fastcompact_f32_32n);
}