1ew_impl_wrap!(
23 f32,
24 x86_64_avx512_erf_f32_64n,
25 64,
26 16,
27 (),
28 #[inline(never)]
29 fn run(buf: &mut [f32], _: ()) {
30 debug_assert!(buf.len() % Self::nr() == 0);
31 debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
32 if buf.is_empty() {
33 return;
34 }
35 unsafe { x86_64_avx512_erf_f32_64n_run(buf) }
36 }
37);
38
39#[target_feature(enable = "avx512f")]
40unsafe fn x86_64_avx512_erf_f32_64n_run(buf: &mut [f32]) {
41 unsafe {
42 let len = buf.len();
43 let ptr = buf.as_ptr();
44 const A1: f32 = 0.0705230784;
45 const A2: f32 = 0.0422820123;
46 const A3: f32 = 0.0092705272;
47 const A4: f32 = 0.0001520143;
48 const A5: f32 = 0.0002765672;
49 const A6: f32 = 0.0000430638;
50 const ABS_MASK: f32 = f32::from_bits(0x7fffffff);
53 const SIGN_MASK: f32 = f32::from_bits(0x80000000);
54 std::arch::asm!("
55 // broadcast constants (xmmN -> zmmN, broadcast across all 16 lanes)
56 vbroadcastss zmm0, xmm0 // a1
57 vbroadcastss zmm1, xmm1 // a2
58 vbroadcastss zmm2, xmm2 // a3
59 vbroadcastss zmm3, xmm3 // a4
60 vbroadcastss zmm4, xmm4 // a5
61 vbroadcastss zmm5, xmm5 // a6
62 vbroadcastss zmm6, xmm6 // 1.0
63 vbroadcastss zmm7, xmm7 // abs mask (0x7fffffff)
64 vbroadcastss zmm8, xmm8 // sign mask (0x80000000)
65 2:
66 // load 4 zmm of input
67 vmovaps zmm9, [{ptr}]
68 vmovaps zmm10, [{ptr} + 64]
69 vmovaps zmm11, [{ptr} + 128]
70 vmovaps zmm12, [{ptr} + 192]
71
72 // sign[i] = x[i] & SIGN_MASK (keeps only the sign bit)
73 vandps zmm13, zmm9, zmm8
74 vandps zmm14, zmm10, zmm8
75 vandps zmm15, zmm11, zmm8
76 vandps zmm16, zmm12, zmm8
77
78 // abs[i] = x[i] & ABS_MASK (clears the sign bit)
79 vandps zmm9, zmm9, zmm7
80 vandps zmm10, zmm10, zmm7
81 vandps zmm11, zmm11, zmm7
82 vandps zmm12, zmm12, zmm7
83
84 // y = a6 (in zmm17..20, 4 independent channels)
85 vmovaps zmm17, zmm5
86 vmovaps zmm18, zmm5
87 vmovaps zmm19, zmm5
88 vmovaps zmm20, zmm5
89
90 // y = y*abs + a5
91 vfmadd213ps zmm17, zmm9, zmm4
92 vfmadd213ps zmm18, zmm10, zmm4
93 vfmadd213ps zmm19, zmm11, zmm4
94 vfmadd213ps zmm20, zmm12, zmm4
95
96 // y = y*abs + a4
97 vfmadd213ps zmm17, zmm9, zmm3
98 vfmadd213ps zmm18, zmm10, zmm3
99 vfmadd213ps zmm19, zmm11, zmm3
100 vfmadd213ps zmm20, zmm12, zmm3
101
102 // y = y*abs + a3
103 vfmadd213ps zmm17, zmm9, zmm2
104 vfmadd213ps zmm18, zmm10, zmm2
105 vfmadd213ps zmm19, zmm11, zmm2
106 vfmadd213ps zmm20, zmm12, zmm2
107
108 // y = y*abs + a2
109 vfmadd213ps zmm17, zmm9, zmm1
110 vfmadd213ps zmm18, zmm10, zmm1
111 vfmadd213ps zmm19, zmm11, zmm1
112 vfmadd213ps zmm20, zmm12, zmm1
113
114 // y = y*abs + a1
115 vfmadd213ps zmm17, zmm9, zmm0
116 vfmadd213ps zmm18, zmm10, zmm0
117 vfmadd213ps zmm19, zmm11, zmm0
118 vfmadd213ps zmm20, zmm12, zmm0
119
120 // y = y * abs (final factor)
121 vmulps zmm17, zmm17, zmm9
122 vmulps zmm18, zmm18, zmm10
123 vmulps zmm19, zmm19, zmm11
124 vmulps zmm20, zmm20, zmm12
125
126 // y = y + 1
127 vaddps zmm17, zmm17, zmm6
128 vaddps zmm18, zmm18, zmm6
129 vaddps zmm19, zmm19, zmm6
130 vaddps zmm20, zmm20, zmm6
131
132 // y^16: square 4 times
133 vmulps zmm17, zmm17, zmm17
134 vmulps zmm18, zmm18, zmm18
135 vmulps zmm19, zmm19, zmm19
136 vmulps zmm20, zmm20, zmm20
137
138 vmulps zmm17, zmm17, zmm17
139 vmulps zmm18, zmm18, zmm18
140 vmulps zmm19, zmm19, zmm19
141 vmulps zmm20, zmm20, zmm20
142
143 vmulps zmm17, zmm17, zmm17
144 vmulps zmm18, zmm18, zmm18
145 vmulps zmm19, zmm19, zmm19
146 vmulps zmm20, zmm20, zmm20
147
148 vmulps zmm17, zmm17, zmm17
149 vmulps zmm18, zmm18, zmm18
150 vmulps zmm19, zmm19, zmm19
151 vmulps zmm20, zmm20, zmm20
152
153 // y = 1 / y (full-precision reciprocal, matches generic .recip())
154 vdivps zmm21, zmm6, zmm17
155 vdivps zmm22, zmm6, zmm18
156 vdivps zmm23, zmm6, zmm19
157 vdivps zmm24, zmm6, zmm20
158
159 // y = 1 - y
160 vsubps zmm21, zmm6, zmm21
161 vsubps zmm22, zmm6, zmm22
162 vsubps zmm23, zmm6, zmm23
163 vsubps zmm24, zmm6, zmm24
164
165 // copysign: stamp the original sign bit onto the (positive) result
166 vorps zmm21, zmm21, zmm13
167 vorps zmm22, zmm22, zmm14
168 vorps zmm23, zmm23, zmm15
169 vorps zmm24, zmm24, zmm16
170
171 // store
172 vmovaps [{ptr}], zmm21
173 vmovaps [{ptr} + 64], zmm22
174 vmovaps [{ptr} + 128], zmm23
175 vmovaps [{ptr} + 192], zmm24
176
177 add {ptr}, 256
178 sub {len}, 64
179 jnz 2b
180 ",
181 len = inout(reg) len => _,
182 ptr = inout(reg) ptr => _,
183 inout("xmm0") A1 => _,
184 inout("xmm1") A2 => _,
185 inout("xmm2") A3 => _,
186 inout("xmm3") A4 => _,
187 inout("xmm4") A5 => _,
188 inout("xmm5") A6 => _,
189 inout("xmm6") 1f32 => _,
190 inout("xmm7") ABS_MASK => _,
191 inout("xmm8") SIGN_MASK => _,
192 out("zmm9") _, out("zmm10") _, out("zmm11") _, out("zmm12") _,
193 out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _,
194 out("zmm17") _, out("zmm18") _, out("zmm19") _, out("zmm20") _,
195 out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _,
196 );
197 }
198}
199
200#[cfg(test)]
201pub mod test_x86_64_avx512_erf_f32_64n {
202 use super::*;
203 crate::erf_frame_tests!(is_x86_feature_detected!("avx512f"), f32, x86_64_avx512_erf_f32_64n);
204}