Skip to main content

tract_linalg/x86_64_fma/
erf.rs

// AVX-512 (zmm, 16-wide) error function kernel. Mirrors generic/erf.rs::serf
// (Abramowitz & Stegun 7.1.28: for x >= 0,
//   erf(x) ~= 1 - 1 / (1 + a1*x + a2*x^2 + a3*x^3 + a4*x^4 + a5*x^5 + a6*x^6)^16,
// with |error| <= 3e-7) but evaluates the polynomial with FMA chains over 4 zmm
// registers per iteration (64 lanes per loop step). Validated against the generic
// scalar reference via erf_frame_tests! at SuperApproximate tolerance.
//
// Algorithm (per lane):
//   sign = sign bit of x;  abs = |x|
//   y = a6
//   y = y*abs + a5            (Horner FMA)
//   y = y*abs + a4
//   y = y*abs + a3
//   y = y*abs + a2
//   y = y*abs + a1
//   y = y * abs               (final factor of abs)
//   y = y + 1                 (now y = 1 + a1*abs + ... + a6*abs^6)
//   y = y^16                  (4 sequential squares)
//   y = 1 / y                 (vdivps, full IEEE precision)
//   y = 1 - y
//   result = copysign(y, x)   (OR the saved sign bit back in)

22ew_impl_wrap!(
23    f32,
24    x86_64_avx512_erf_f32_64n,
25    64,
26    16,
27    (),
28    #[inline(never)]
29    fn run(buf: &mut [f32], _: ()) {
30        debug_assert!(buf.len() % Self::nr() == 0);
31        debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
32        if buf.is_empty() {
33            return;
34        }
35        unsafe { x86_64_avx512_erf_f32_64n_run(buf) }
36    }
37);
38
39#[target_feature(enable = "avx512f")]
40unsafe fn x86_64_avx512_erf_f32_64n_run(buf: &mut [f32]) {
41    unsafe {
42        let len = buf.len();
43        let ptr = buf.as_ptr();
44        const A1: f32 = 0.0705230784;
45        const A2: f32 = 0.0422820123;
46        const A3: f32 = 0.0092705272;
47        const A4: f32 = 0.0001520143;
48        const A5: f32 = 0.0002765672;
49        const A6: f32 = 0.0000430638;
50        // 0x7fffffff: positive-finite mask (clears sign bit). As f32 bits, this
51        // is NaN; we never use it as a numeric value — only as a bit mask via vandps.
52        const ABS_MASK: f32 = f32::from_bits(0x7fffffff);
53        const SIGN_MASK: f32 = f32::from_bits(0x80000000);
54        std::arch::asm!("
55            // broadcast constants (xmmN -> zmmN, broadcast across all 16 lanes)
56            vbroadcastss zmm0, xmm0           // a1
57            vbroadcastss zmm1, xmm1           // a2
58            vbroadcastss zmm2, xmm2           // a3
59            vbroadcastss zmm3, xmm3           // a4
60            vbroadcastss zmm4, xmm4           // a5
61            vbroadcastss zmm5, xmm5           // a6
62            vbroadcastss zmm6, xmm6           // 1.0
63            vbroadcastss zmm7, xmm7           // abs mask (0x7fffffff)
64            vbroadcastss zmm8, xmm8           // sign mask (0x80000000)
65            2:
66                // load 4 zmm of input
67                vmovaps zmm9,  [{ptr}]
68                vmovaps zmm10, [{ptr} + 64]
69                vmovaps zmm11, [{ptr} + 128]
70                vmovaps zmm12, [{ptr} + 192]
71
72                // sign[i] = x[i] & SIGN_MASK   (keeps only the sign bit)
73                vandps zmm13, zmm9,  zmm8
74                vandps zmm14, zmm10, zmm8
75                vandps zmm15, zmm11, zmm8
76                vandps zmm16, zmm12, zmm8
77
78                // abs[i] = x[i] & ABS_MASK     (clears the sign bit)
79                vandps zmm9,  zmm9,  zmm7
80                vandps zmm10, zmm10, zmm7
81                vandps zmm11, zmm11, zmm7
82                vandps zmm12, zmm12, zmm7
83
84                // y = a6 (in zmm17..20, 4 independent channels)
85                vmovaps zmm17, zmm5
86                vmovaps zmm18, zmm5
87                vmovaps zmm19, zmm5
88                vmovaps zmm20, zmm5
89
90                // y = y*abs + a5
91                vfmadd213ps zmm17, zmm9,  zmm4
92                vfmadd213ps zmm18, zmm10, zmm4
93                vfmadd213ps zmm19, zmm11, zmm4
94                vfmadd213ps zmm20, zmm12, zmm4
95
96                // y = y*abs + a4
97                vfmadd213ps zmm17, zmm9,  zmm3
98                vfmadd213ps zmm18, zmm10, zmm3
99                vfmadd213ps zmm19, zmm11, zmm3
100                vfmadd213ps zmm20, zmm12, zmm3
101
102                // y = y*abs + a3
103                vfmadd213ps zmm17, zmm9,  zmm2
104                vfmadd213ps zmm18, zmm10, zmm2
105                vfmadd213ps zmm19, zmm11, zmm2
106                vfmadd213ps zmm20, zmm12, zmm2
107
108                // y = y*abs + a2
109                vfmadd213ps zmm17, zmm9,  zmm1
110                vfmadd213ps zmm18, zmm10, zmm1
111                vfmadd213ps zmm19, zmm11, zmm1
112                vfmadd213ps zmm20, zmm12, zmm1
113
114                // y = y*abs + a1
115                vfmadd213ps zmm17, zmm9,  zmm0
116                vfmadd213ps zmm18, zmm10, zmm0
117                vfmadd213ps zmm19, zmm11, zmm0
118                vfmadd213ps zmm20, zmm12, zmm0
119
120                // y = y * abs  (final factor)
121                vmulps zmm17, zmm17, zmm9
122                vmulps zmm18, zmm18, zmm10
123                vmulps zmm19, zmm19, zmm11
124                vmulps zmm20, zmm20, zmm12
125
126                // y = y + 1
127                vaddps zmm17, zmm17, zmm6
128                vaddps zmm18, zmm18, zmm6
129                vaddps zmm19, zmm19, zmm6
130                vaddps zmm20, zmm20, zmm6
131
132                // y^16: square 4 times
133                vmulps zmm17, zmm17, zmm17
134                vmulps zmm18, zmm18, zmm18
135                vmulps zmm19, zmm19, zmm19
136                vmulps zmm20, zmm20, zmm20
137
138                vmulps zmm17, zmm17, zmm17
139                vmulps zmm18, zmm18, zmm18
140                vmulps zmm19, zmm19, zmm19
141                vmulps zmm20, zmm20, zmm20
142
143                vmulps zmm17, zmm17, zmm17
144                vmulps zmm18, zmm18, zmm18
145                vmulps zmm19, zmm19, zmm19
146                vmulps zmm20, zmm20, zmm20
147
148                vmulps zmm17, zmm17, zmm17
149                vmulps zmm18, zmm18, zmm18
150                vmulps zmm19, zmm19, zmm19
151                vmulps zmm20, zmm20, zmm20
152
153                // y = 1 / y      (full-precision reciprocal, matches generic .recip())
154                vdivps zmm21, zmm6, zmm17
155                vdivps zmm22, zmm6, zmm18
156                vdivps zmm23, zmm6, zmm19
157                vdivps zmm24, zmm6, zmm20
158
159                // y = 1 - y
160                vsubps zmm21, zmm6, zmm21
161                vsubps zmm22, zmm6, zmm22
162                vsubps zmm23, zmm6, zmm23
163                vsubps zmm24, zmm6, zmm24
164
165                // copysign: stamp the original sign bit onto the (positive) result
166                vorps zmm21, zmm21, zmm13
167                vorps zmm22, zmm22, zmm14
168                vorps zmm23, zmm23, zmm15
169                vorps zmm24, zmm24, zmm16
170
171                // store
172                vmovaps [{ptr}],       zmm21
173                vmovaps [{ptr} + 64],  zmm22
174                vmovaps [{ptr} + 128], zmm23
175                vmovaps [{ptr} + 192], zmm24
176
177                add {ptr}, 256
178                sub {len}, 64
179                jnz 2b
180            ",
181            len = inout(reg) len => _,
182            ptr = inout(reg) ptr => _,
183            inout("xmm0") A1 => _,
184            inout("xmm1") A2 => _,
185            inout("xmm2") A3 => _,
186            inout("xmm3") A4 => _,
187            inout("xmm4") A5 => _,
188            inout("xmm5") A6 => _,
189            inout("xmm6") 1f32 => _,
190            inout("xmm7") ABS_MASK => _,
191            inout("xmm8") SIGN_MASK => _,
192            out("zmm9")  _, out("zmm10") _, out("zmm11") _, out("zmm12") _,
193            out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _,
194            out("zmm17") _, out("zmm18") _, out("zmm19") _, out("zmm20") _,
195            out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _,
196        );
197    }
198}
199
#[cfg(test)]
pub mod test_x86_64_avx512_erf_f32_64n {
    //! Frame tests for the AVX-512 erf kernel against the generic scalar erf
    //! reference. The first macro argument is the runtime CPU-feature gate
    //! (presumably the tests are skipped when it evaluates to false).
    use super::*;

    crate::erf_frame_tests!(
        is_x86_feature_detected!("avx512f"),
        f32,
        x86_64_avx512_erf_f32_64n
    );
}