1map_reduce_impl_wrap!(
2 f32,
3 x86_64_fma_softmax2_fastcompact_f32_32n,
4 32,
5 8,
6 f32,
7 f32::MIN,
8 0f32,
9 #[inline(never)]
10 fn run(buf: &mut [f32], max: f32) -> f32 {
11 assert!(buf.len() % 32 == 0);
12 assert!(buf.len() > 0);
13 unsafe { x86_64_fma_softmax2_fastcompact_f32_32n_run(buf, max) }
14 },
15 #[inline(never)]
16 fn reduce_two(a: f32, b: f32) -> f32 {
17 a + b
18 }
19);
20
/// Replaces every element `x` of `buf` with a fast approximation of
/// `exp(x - max)` and returns the sum of the values written.
///
/// The approximation builds the IEEE-754 bit pattern of the result directly:
/// `SLOPE * (x - max) + OFFSET` is clamped at zero, truncated to a 32-bit
/// integer, and those integer bits are stored back as the `f32` result.
/// With `SLOPE = 2^23 / ln 2` and `OFFSET` close to `127 * 2^23`, the integer's
/// upper bits land in the exponent field, which yields roughly
/// `2^((x - max) / ln 2) = exp(x - max)`.
///
/// The loop handles 32 floats (four 256-bit vectors) per iteration and keeps
/// four independent vector accumulators that are summed horizontally at the end.
///
/// # Safety
///
/// * The CPU must support AVX and FMA.
/// * `buf` must be non-empty and its length a multiple of 32; the loop counter
///   is decremented by 32 until it reaches exactly zero.
/// * `buf` must be 32-byte aligned: all loads and stores use `vmovaps`.
#[target_feature(enable = "avx,fma")]
unsafe fn x86_64_fma_softmax2_fastcompact_f32_32n_run(buf: &mut [f32], max: f32) -> f32 {
    let len = buf.len();
    // The assembly stores its results back into `buf`, so the pointer has to be
    // derived from the mutable borrow (not from `as_ptr`, which is read-only).
    let ptr = buf.as_mut_ptr();
    debug_assert!(len > 0 && len % 32 == 0, "buffer length must be a non-zero multiple of 32");
    debug_assert!((ptr as usize) % 32 == 0, "buffer must be 32-byte aligned");

    let mut acc = 0f32;
    // ln(2)
    const MLN2: f32 = 0.6931471805f32;
    // 2^23: one unit of the f32 exponent field, seen as an integer.
    const A: f32 = 8388608.0f32;
    // 127 * 2^23: the integer bit pattern of 1.0f32.
    const B: f32 = 1065353216.0f32;
    // Correction term subtracted from the bias; presumably tuned to reduce the
    // approximation error -- the code does not say how it was derived.
    const C: f32 = 60801.0f32;
    const SLOPE: f32 = A / MLN2;
    const OFFSET: f32 = B - C;

    // Only AVX/FMA encodings are used below, matching the declared target
    // features: register-source `vbroadcastss` and 256-bit `vpxor` would need
    // AVX2, so broadcasts are done with `vpermilps` + `vinsertf128` and the
    // zero vector is produced with `vxorps`.
    std::arch::asm!(
        // ymm0..ymm3: four running sums, all lanes initialised from `acc` (0.0).
        "vpermilps xmm0, xmm0, 0",
        "vinsertf128 ymm0, ymm0, xmm0, 1",
        "vmovaps ymm1, ymm0",
        "vmovaps ymm2, ymm0",
        "vmovaps ymm3, ymm0",
        // ymm12 = 0.0 in every lane (lower clamp).
        "vxorps ymm12, ymm12, ymm12",
        // ymm13 = max, ymm14 = SLOPE, ymm15 = OFFSET, broadcast to all lanes.
        "vpermilps xmm13, xmm13, 0",
        "vinsertf128 ymm13, ymm13, xmm13, 1",
        "vpermilps xmm14, xmm14, 0",
        "vinsertf128 ymm14, ymm14, xmm14, 1",
        "vpermilps xmm15, xmm15, 0",
        "vinsertf128 ymm15, ymm15, xmm15, 1",
        "2:",
        // Load 32 inputs.
        "vmovaps ymm4, [{ptr}]",
        "vmovaps ymm5, [{ptr} + 32]",
        "vmovaps ymm6, [{ptr} + 64]",
        "vmovaps ymm7, [{ptr} + 96]",
        // x - max
        "vsubps ymm4, ymm4, ymm13",
        "vsubps ymm5, ymm5, ymm13",
        "vsubps ymm6, ymm6, ymm13",
        "vsubps ymm7, ymm7, ymm13",
        // OFFSET + (x - max) * SLOPE
        "vmovaps ymm8, ymm15",
        "vmovaps ymm9, ymm15",
        "vmovaps ymm10, ymm15",
        "vmovaps ymm11, ymm15",
        "vfmadd231ps ymm8, ymm4, ymm14",
        "vfmadd231ps ymm9, ymm5, ymm14",
        "vfmadd231ps ymm10, ymm6, ymm14",
        "vfmadd231ps ymm11, ymm7, ymm14",
        // Clamp at zero so very negative inputs become 0.0 instead of garbage bits.
        "vmaxps ymm8, ymm8, ymm12",
        "vmaxps ymm9, ymm9, ymm12",
        "vmaxps ymm10, ymm10, ymm12",
        "vmaxps ymm11, ymm11, ymm12",
        // Truncate to i32; from here on the integer bits *are* the f32 result.
        "vcvttps2dq ymm8, ymm8",
        "vcvttps2dq ymm9, ymm9",
        "vcvttps2dq ymm10, ymm10",
        "vcvttps2dq ymm11, ymm11",
        // Store the results in place.
        "vmovaps [{ptr}], ymm8",
        "vmovaps [{ptr} + 32], ymm9",
        "vmovaps [{ptr} + 64], ymm10",
        "vmovaps [{ptr} + 96], ymm11",
        // Accumulate the results (reinterpreted as floats).
        "vaddps ymm0, ymm0, ymm8",
        "vaddps ymm1, ymm1, ymm9",
        "vaddps ymm2, ymm2, ymm10",
        "vaddps ymm3, ymm3, ymm11",
        // Advance by 128 bytes = 32 floats; loop until the counter hits zero.
        "add {ptr}, 128",
        "sub {len}, 32",
        "jnz 2b",
        // Fold the four accumulators, then the eight lanes, into lane 0 of xmm0.
        "vaddps ymm0, ymm0, ymm1",
        "vaddps ymm2, ymm2, ymm3",
        "vaddps ymm0, ymm0, ymm2",
        "vperm2f128 ymm1, ymm0, ymm0, 1",
        "vaddps xmm0, xmm0, xmm1",
        "vpermilps xmm1, xmm0, 2 + (3 << 2)",
        "vaddps xmm0, xmm0, xmm1",
        "vpermilps xmm1, xmm0, 1",
        "vaddps xmm0, xmm0, xmm1",
        len = inout(reg) len => _,
        ptr = inout(reg) ptr => _,
        inout("ymm0") acc,
        out("ymm1") _, out("ymm2") _, out("ymm3") _,
        out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
        out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
        out("ymm12") _,
        inout("ymm13") max => _,
        inout("ymm14") SLOPE => _,
        inout("ymm15") OFFSET => _,
    );
    acc
}
110
#[cfg(test)]
mod test_x86_64_fma_softmax2_fastcompact_f32_32n {
    use super::*;

    // Instantiates the crate's softmax test macro for this kernel. The macro is
    // defined elsewhere; the first argument is presumably the runtime condition
    // under which the generated tests actually exercise the kernel -- confirm
    // against the macro definition.
    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("fma"),
        f32,
        x86_64_fma_softmax2_fastcompact_f32_32n
    );
}