// Registers the `x86_64_fma_softmax2_fastcompact_f32_32n` kernel through
// `map_reduce_impl_wrap!` (defined elsewhere in the crate).
// NOTE(review): the meanings of the positional arguments below are inferred
// from how they are used in this invocation; confirm against the macro.
map_reduce_impl_wrap!(
    // Element type of the buffer handed to `run`.
    f32,
    // Kernel name.
    x86_64_fma_softmax2_fastcompact_f32_32n,
    // Item granule: `run` asserts `buf.len() % 32 == 0`.
    32,
    // Presumably the required buffer alignment in items (8 x f32 = 32 bytes,
    // which the kernel's aligned `vmovaps` loads/stores need). TODO confirm.
    8,
    // Type of the extra parameter passed to `run` (`max: f32`).
    f32,
    // Presumably the neutral/padding value for input items. TODO confirm.
    f32::MIN,
    // Identity element for `reduce_two` (addition).
    0f32,
    #[inline(never)]
    fn run(buf: &mut [f32], max: f32) -> f32 {
        // The asm loop consumes exactly 32 items per iteration and is a
        // do-while (decrement, then test), so an empty or ragged buffer would
        // run past the end. Both preconditions are therefore enforced here.
        assert!(buf.len() % 32 == 0);
        assert!(buf.len() > 0);
        // SAFETY: the length preconditions are checked above.
        // NOTE(review): AVX+FMA availability and 32-byte alignment of `buf`
        // are not checked here; presumably the wrapper guarantees both.
        unsafe { x86_64_fma_softmax2_fastcompact_f32_32n_run(buf, max) }
    },
    // Combines two partial sums produced by `run`.
    #[inline(never)]
    fn reduce_two(a: f32, b: f32) -> f32 {
        a + b
    }
);
20
/// Fused softmax step: overwrite each element `x` of `buf` with an
/// approximation of `exp(x - max)` and return the sum of the written values.
///
/// The approximation manipulates the float exponent field directly. It computes
/// `y = (x - max) * 2^23 / ln(2) + (127 * 2^23 - C)`, clamps `y` below at `0`,
/// truncates it to an integer, and reinterprets that integer's bits as an `f32`.
/// `127 * 2^23` is the bit pattern of `1.0f32`, so `x == max` maps to roughly 1.0.
/// The clamp makes very negative inputs produce the bit pattern `0`, i.e. `+0.0`.
///
/// # Safety
///
/// The caller must guarantee all of the following:
/// - the CPU supports AVX and FMA (only AVX1 and FMA instructions are used);
/// - `buf.len()` is non-zero and a multiple of 32 (the loop consumes 32
///   items per iteration and only stops when the remaining count hits 0);
/// - `buf` starts on a 32-byte boundary (`vmovaps` faults otherwise).
#[target_feature(enable = "avx,fma")]
unsafe fn x86_64_fma_softmax2_fastcompact_f32_32n_run(buf: &mut [f32], max: f32) -> f32 {
    debug_assert!(!buf.is_empty() && buf.len() % 32 == 0);
    debug_assert!((buf.as_ptr() as usize) % 32 == 0);

    let len = buf.len();
    // The asm stores through this pointer, so it must come from the `&mut`
    // borrow (`as_mut_ptr`), not from a shared reborrow (`as_ptr`).
    let ptr = buf.as_mut_ptr();
    let acc: f32;

    // ln(2)
    const MLN2: f32 = 0.6931471805f32;
    // 2^23: one unit of the f32 exponent field, in integer-bit space.
    const A: f32 = 8388608.0f32;
    // 127 * 2^23: the bit pattern of 1.0f32.
    const B: f32 = 1065353216.0f32;
    // Empirical bias correction for the linear mantissa approximation.
    const C: f32 = 60801.0f32;
    // Converts natural-log units into exponent-field bit units.
    const SLOPE: f32 = A / MLN2;
    const OFFSET: f32 = B - C;

    // SAFETY: the contract in `# Safety` holds:
    // - AVX+FMA are available;
    // - `len` is a non-zero multiple of 32, so the loop ends exactly at the end
    //   of `buf`;
    // - `ptr` is 32-byte aligned as `vmovaps` requires.
    // The asm does not touch the stack.
    unsafe {
        std::arch::asm!(
            // Zero the four 8-lane partial sums and the clamp floor (ymm12).
            // `vxorps` is AVX1; `vpxor ymm` would require AVX2.
            "vxorps ymm0, ymm0, ymm0",
            "vxorps ymm1, ymm1, ymm1",
            "vxorps ymm2, ymm2, ymm2",
            "vxorps ymm3, ymm3, ymm3",
            "vxorps ymm12, ymm12, ymm12",
            // Splat max / SLOPE / OFFSET across all 8 lanes. Register-source
            // `vbroadcastss` is AVX2, so use an AVX1 shuffle + 128-bit insert.
            "vshufps xmm13, xmm13, xmm13, 0",
            "vinsertf128 ymm13, ymm13, xmm13, 1",
            "vshufps xmm14, xmm14, xmm14, 0",
            "vinsertf128 ymm14, ymm14, xmm14, 1",
            "vshufps xmm15, xmm15, xmm15, 0",
            "vinsertf128 ymm15, ymm15, xmm15, 1",
            "2:",
            // Load 32 items (4 x 8 lanes).
            "vmovaps ymm4, [{ptr}]",
            "vmovaps ymm5, [{ptr} + 32]",
            "vmovaps ymm6, [{ptr} + 64]",
            "vmovaps ymm7, [{ptr} + 96]",
            // x - max
            "vsubps ymm4, ymm4, ymm13",
            "vsubps ymm5, ymm5, ymm13",
            "vsubps ymm6, ymm6, ymm13",
            "vsubps ymm7, ymm7, ymm13",
            // y = (x - max) * SLOPE + OFFSET
            "vmovaps ymm8, ymm15",
            "vmovaps ymm9, ymm15",
            "vmovaps ymm10, ymm15",
            "vmovaps ymm11, ymm15",
            "vfmadd231ps ymm8, ymm4, ymm14",
            "vfmadd231ps ymm9, ymm5, ymm14",
            "vfmadd231ps ymm10, ymm6, ymm14",
            "vfmadd231ps ymm11, ymm7, ymm14",
            // Clamp at 0 so very negative inputs become the bit pattern of +0.0.
            "vmaxps ymm8, ymm8, ymm12",
            "vmaxps ymm9, ymm9, ymm12",
            "vmaxps ymm10, ymm10, ymm12",
            "vmaxps ymm11, ymm11, ymm12",
            // Truncate to int; those bits, read as f32, approximate exp(x - max).
            "vcvttps2dq ymm8, ymm8",
            "vcvttps2dq ymm9, ymm9",
            "vcvttps2dq ymm10, ymm10",
            "vcvttps2dq ymm11, ymm11",
            // Write the results back in place.
            "vmovaps [{ptr}], ymm8",
            "vmovaps [{ptr} + 32], ymm9",
            "vmovaps [{ptr} + 64], ymm10",
            "vmovaps [{ptr} + 96], ymm11",
            // Accumulate them as floats into four independent partial sums.
            "vaddps ymm0, ymm0, ymm8",
            "vaddps ymm1, ymm1, ymm9",
            "vaddps ymm2, ymm2, ymm10",
            "vaddps ymm3, ymm3, ymm11",
            // Advance 32 items (128 bytes); loop until none remain.
            "add {ptr}, 128",
            "sub {len}, 32",
            "jnz 2b",
            // Reduce the 4 x 8 partial sums into lane 0 of xmm0.
            "vaddps ymm0, ymm0, ymm1",
            "vaddps ymm2, ymm2, ymm3",
            "vaddps ymm0, ymm0, ymm2",
            "vextractf128 xmm1, ymm0, 1",
            "vaddps xmm0, xmm0, xmm1",
            "vpermilps xmm1, xmm0, 2 + (3 << 2)",
            "vaddps xmm0, xmm0, xmm1",
            "vpermilps xmm1, xmm0, 1",
            "vaddps xmm0, xmm0, xmm1",
            len = inout(reg) len => _,
            ptr = inout(reg) ptr => _,
            out("ymm0") acc,
            out("ymm1") _, out("ymm2") _, out("ymm3") _,
            out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
            out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
            out("ymm12") _,
            inout("ymm13") max => _,
            inout("ymm14") SLOPE => _,
            inout("ymm15") OFFSET => _,
            options(nostack),
        );
    }
    acc
}
112
// Shared softmax test suite instantiated for this kernel through
// `softmax_l2_frame_tests!` (defined elsewhere in the crate).
// The first argument is presumably the runtime condition gating the tests
// (here: the CPU reports FMA support). TODO confirm against the macro.
// NOTE(review): the kernel is compiled with `target_feature(enable = "avx,fma")`;
// confirm that a positive `fma` detection also implies AVX on every target
// this runs on.
#[cfg(test)]
mod test_x86_64_fma_softmax2_fastcompact_f32_32n {
    use super::*;
    crate::softmax_l2_frame_tests!(
        is_x86_feature_detected!("fma"),
        f32,
        x86_64_fma_softmax2_fastcompact_f32_32n
    );
}