velesdb_core/simd_native/reduction.rs
1//! Shared horizontal reduction helpers and multi-accumulator loop macros for SIMD.
2//!
3//! Provides canonical `hsum_avx256` and `hsum_avx512` functions so that
4//! every AVX2/AVX-512 kernel reduces accumulators through a single code path.
5//! Also provides 4-accumulator ([`simd_4acc_dot_loop!`], [`simd_4acc_l2_loop!`])
6//! and 8-accumulator ([`simd_8acc_dot_loop!`], [`simd_8acc_l2_loop!`]) macros
7//! that encode the ILP unrolling patterns used across all ISAs.
8
9#![allow(clippy::incompatible_msrv)] // SIMD intrinsics gated behind target_feature + runtime detection
10
11// =============================================================================
12// Horizontal sum helpers
13// =============================================================================
14
/// Reduces the eight f32 lanes of an AVX `__m256` register to their sum.
///
/// The upper 128-bit half is first added onto the lower half, giving
/// `lk = v[k] + v[k + 4]` for `k in 0..4`. Those four partials are then
/// combined as `(l0 + l1) + (l2 + l3)`.
///
/// # Safety
///
/// The CPU must support AVX2 (the function is compiled with
/// `#[target_feature(enable = "avx2")]`; callers gate it on runtime
/// detection via `simd_level()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
pub(crate) unsafe fn hsum_avx256(v: std::arch::x86_64::__m256) -> f32 {
    // SAFETY: register-only intrinsics (no pointers, no memory access); their
    // AVX/SSE3 requirements are covered by the AVX2 target feature above.
    use std::arch::x86_64 as x86;

    // quad = [v0+v4, v1+v5, v2+v6, v3+v7]
    let quad = x86::_mm_add_ps(
        x86::_mm256_castps256_ps128(v),
        x86::_mm256_extractf128_ps(v, 1),
    );
    // Duplicate odd lanes and add: lane 0 = l0 + l1, lane 2 = l2 + l3.
    let pairs = x86::_mm_add_ps(quad, x86::_mm_movehdup_ps(quad));
    // Bring lane 2 down to lane 0 and add the two scalars.
    x86::_mm_cvtss_f32(x86::_mm_add_ss(pairs, x86::_mm_movehl_ps(pairs, pairs)))
}
41
42/// Horizontal sum of 16 packed f32 values in an AVX-512 `__m512` register.
43///
44/// Wraps the native `_mm512_reduce_add_ps` intrinsic for API symmetry
45/// with [`hsum_avx256`].
46///
47/// # Safety
48///
49/// Caller must ensure CPU supports AVX-512F (enforced by `#[target_feature]`
50/// and runtime detection via `simd_level()`).
51#[cfg(target_arch = "x86_64")]
52#[target_feature(enable = "avx512f")]
53#[inline]
54pub(crate) unsafe fn hsum_avx512(v: std::arch::x86_64::__m512) -> f32 {
55 // SAFETY: `_mm512_reduce_add_ps` requires AVX-512F, guaranteed by #[target_feature].
56 // No pointer arithmetic — operates purely on register values.
57 std::arch::x86_64::_mm512_reduce_add_ps(v)
58}
59
60// =============================================================================
61// 4-accumulator reduction helpers
62// =============================================================================
63
/// Folds four AVX2 accumulators into a single `__m256` with lane-wise adds.
///
/// The pairs are combined as `(a0 + a1) + (a2 + a3)`. This is a two-level
/// tree rather than a serial chain, so the two inner additions do not depend
/// on each other.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)] // Available for custom 4-acc kernels outside the macro
pub(crate) unsafe fn reduce_4acc_avx256(
    a0: std::arch::x86_64::__m256,
    a1: std::arch::x86_64::__m256,
    a2: std::arch::x86_64::__m256,
    a3: std::arch::x86_64::__m256,
) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::_mm256_add_ps as add;
    // SAFETY: register-only adds (no memory access); AVX2 is guaranteed by
    // the target feature above.
    add(add(a0, a1), add(a2, a3))
}
86
87/// Combines 4 AVX-512 accumulators into one via binary-tree addition.
88///
89/// # Safety
90///
91/// Caller must ensure CPU supports AVX-512F.
92#[cfg(target_arch = "x86_64")]
93#[target_feature(enable = "avx512f")]
94#[inline]
95#[allow(dead_code)] // Available for custom 4-acc kernels outside the macro
96pub(crate) unsafe fn reduce_4acc_avx512(
97 a0: std::arch::x86_64::__m512,
98 a1: std::arch::x86_64::__m512,
99 a2: std::arch::x86_64::__m512,
100 a3: std::arch::x86_64::__m512,
101) -> std::arch::x86_64::__m512 {
102 // SAFETY: `_mm512_add_ps` requires AVX-512F, guaranteed by #[target_feature].
103 // No pointer arithmetic — operates purely on register values.
104 use std::arch::x86_64::_mm512_add_ps;
105 let sum01 = _mm512_add_ps(a0, a1);
106 let sum23 = _mm512_add_ps(a2, a3);
107 _mm512_add_ps(sum01, sum23)
108}
109
110// =============================================================================
111// 4-accumulator loop macros
112// =============================================================================
113
/// 4-accumulator unrolled SIMD loop for dot product (ILP optimization).
///
/// Processes `4 × lane` elements per iteration using 4 independent
/// accumulators to hide FMA latency. Works across AVX2, AVX-512, and NEON
/// by accepting ISA-specific intrinsics as parameters.
///
/// Evaluates to `(combined_accumulator, updated_a_ptr, updated_b_ptr)`. The
/// returned pointers are where the caller's remainder (tail) handling starts.
///
/// # Arguments
///
/// - `$a_ptr`, `$b_ptr` — Starting pointers for the two input vectors
/// - `$end` — End pointer for the main loop (aligned to `4 × lane`);
///   evaluated exactly once, before the loop
/// - `$zero` — Zero-init expression (e.g., `_mm256_setzero_ps()`), evaluated
///   once per accumulator
/// - `$load` — SIMD load intrinsic (e.g., `_mm256_loadu_ps`)
/// - `$fmadd` — FMA intrinsic with signature `fmadd(a, b, acc) → a*b + acc`
/// - `$add` — SIMD add intrinsic (e.g., `_mm256_add_ps`)
/// - `$lane` — Number of f32 elements per SIMD register (4/8/16); evaluated
///   exactly once
///
/// # Safety
///
/// Must be invoked inside an `unsafe` context where the specified
/// SIMD intrinsics are valid for the current CPU. The loop runs while the
/// `a` cursor is `< $end`, and each iteration reads `4 × $lane` elements
/// from both cursors. Every such window must therefore be readable in both
/// inputs.
#[macro_export]
macro_rules! simd_4acc_dot_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut acc0 = $zero;
        let mut acc1 = $zero;
        let mut acc2 = $zero;
        let mut acc3 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // `$end` and `$lane` may be arbitrary expressions. Bind them once so
        // they are not re-evaluated on every loop check or pointer offset.
        let end = $end;
        let lane: usize = $lane;

        while a_p < end {
            let va0 = $load(a_p);
            let vb0 = $load(b_p);
            acc0 = $fmadd(va0, vb0, acc0);

            let va1 = $load(a_p.add(lane));
            let vb1 = $load(b_p.add(lane));
            acc1 = $fmadd(va1, vb1, acc1);

            let va2 = $load(a_p.add(2 * lane));
            let vb2 = $load(b_p.add(2 * lane));
            acc2 = $fmadd(va2, vb2, acc2);

            let va3 = $load(a_p.add(3 * lane));
            let vb3 = $load(b_p.add(3 * lane));
            acc3 = $fmadd(va3, vb3, acc3);

            a_p = a_p.add(4 * lane);
            b_p = b_p.add(4 * lane);
        }

        // Binary-tree reduction: 4 → 2 → 1
        let sum01 = $add(acc0, acc1);
        let sum23 = $add(acc2, acc3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
173
/// 4-accumulator unrolled SIMD loop for squared L2 distance.
///
/// Same structure as [`simd_4acc_dot_loop!`] but computes `sum((a-b)²)`
/// instead of `sum(a·b)`. Requires an additional `$sub` intrinsic.
///
/// Evaluates to `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
///
/// # Arguments
///
/// - `$a_ptr`, `$b_ptr` — Starting pointers for the two input vectors
/// - `$end` — End pointer for the main loop (aligned to `4 × lane`);
///   evaluated exactly once, before the loop
/// - `$zero` — Zero-init expression, evaluated once per accumulator
/// - `$load` — SIMD load intrinsic
/// - `$sub` — SIMD subtract intrinsic with signature `sub(a, b) → a - b`
/// - `$fmadd` — FMA intrinsic with signature `fmadd(a, b, acc) → a*b + acc`
/// - `$add` — SIMD add intrinsic
/// - `$lane` — Number of f32 elements per SIMD register; evaluated exactly once
///
/// # Safety
///
/// Same requirements as [`simd_4acc_dot_loop!`]. Each iteration reads
/// `4 × $lane` elements from both cursors while the `a` cursor is `< $end`.
#[macro_export]
macro_rules! simd_4acc_l2_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $sub:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut acc0 = $zero;
        let mut acc1 = $zero;
        let mut acc2 = $zero;
        let mut acc3 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // `$end` and `$lane` may be arbitrary expressions. Bind them once so
        // they are not re-evaluated on every loop check or pointer offset.
        let end = $end;
        let lane: usize = $lane;

        while a_p < end {
            let va0 = $load(a_p);
            let vb0 = $load(b_p);
            let diff0 = $sub(va0, vb0);
            acc0 = $fmadd(diff0, diff0, acc0);

            let va1 = $load(a_p.add(lane));
            let vb1 = $load(b_p.add(lane));
            let diff1 = $sub(va1, vb1);
            acc1 = $fmadd(diff1, diff1, acc1);

            let va2 = $load(a_p.add(2 * lane));
            let vb2 = $load(b_p.add(2 * lane));
            let diff2 = $sub(va2, vb2);
            acc2 = $fmadd(diff2, diff2, acc2);

            let va3 = $load(a_p.add(3 * lane));
            let vb3 = $load(b_p.add(3 * lane));
            let diff3 = $sub(va3, vb3);
            acc3 = $fmadd(diff3, diff3, acc3);

            a_p = a_p.add(4 * lane);
            b_p = b_p.add(4 * lane);
        }

        // Binary-tree reduction: 4 → 2 → 1
        let sum01 = $add(acc0, acc1);
        let sum23 = $add(acc2, acc3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
223
224// =============================================================================
225// 8-accumulator loop macros
226// =============================================================================
227
/// 8-accumulator unrolled SIMD loop for dot product (ILP optimization).
///
/// Processes `8 × lane` elements per iteration using 8 independent
/// accumulators to maximally hide FMA latency on wide-issue CPUs.
/// Targets AVX-512 kernels for very large vectors (>= 1024 dimensions).
///
/// Evaluates to `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
///
/// # Arguments
///
/// - `$a_ptr`, `$b_ptr` — Starting pointers for the two input vectors
/// - `$end` — End pointer for the main loop (aligned to `8 × lane`);
///   evaluated exactly once, before the loop
/// - `$zero` — Zero-init expression (e.g., `_mm512_setzero_ps()`), evaluated
///   once per accumulator
/// - `$load` — SIMD load intrinsic (e.g., `_mm512_loadu_ps`)
/// - `$fmadd` — FMA intrinsic with signature `fmadd(a, b, acc) → a*b + acc`
/// - `$add` — SIMD add intrinsic (e.g., `_mm512_add_ps`)
/// - `$lane` — Number of f32 elements per SIMD register (16 for AVX-512);
///   evaluated exactly once
///
/// # Safety
///
/// Must be invoked inside an `unsafe` context where the specified
/// SIMD intrinsics are valid for the current CPU. Each iteration reads
/// `8 × $lane` elements from both cursors while the `a` cursor is `< $end`.
#[macro_export]
macro_rules! simd_8acc_dot_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut s0 = $zero;
        let mut s1 = $zero;
        let mut s2 = $zero;
        let mut s3 = $zero;
        let mut s4 = $zero;
        let mut s5 = $zero;
        let mut s6 = $zero;
        let mut s7 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // `$end` and `$lane` may be arbitrary expressions. Bind them once so
        // they are not re-evaluated on every loop check or pointer offset.
        let end = $end;
        let lane: usize = $lane;

        while a_p < end {
            s0 = $fmadd($load(a_p), $load(b_p), s0);
            s1 = $fmadd($load(a_p.add(lane)), $load(b_p.add(lane)), s1);
            s2 = $fmadd($load(a_p.add(2 * lane)), $load(b_p.add(2 * lane)), s2);
            s3 = $fmadd($load(a_p.add(3 * lane)), $load(b_p.add(3 * lane)), s3);
            s4 = $fmadd($load(a_p.add(4 * lane)), $load(b_p.add(4 * lane)), s4);
            s5 = $fmadd($load(a_p.add(5 * lane)), $load(b_p.add(5 * lane)), s5);
            s6 = $fmadd($load(a_p.add(6 * lane)), $load(b_p.add(6 * lane)), s6);
            s7 = $fmadd($load(a_p.add(7 * lane)), $load(b_p.add(7 * lane)), s7);

            a_p = a_p.add(8 * lane);
            b_p = b_p.add(8 * lane);
        }

        // Binary-tree reduction: 8 → 4 → 2 → 1
        s0 = $add(s0, s4);
        s1 = $add(s1, s5);
        s2 = $add(s2, s6);
        s3 = $add(s3, s7);
        let sum01 = $add(s0, s1);
        let sum23 = $add(s2, s3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
289
/// 8-accumulator unrolled SIMD loop for squared L2 distance.
///
/// Same structure as [`simd_8acc_dot_loop!`] but computes `sum((a-b)²)`
/// instead of `sum(a·b)`. Requires an additional `$sub` intrinsic.
///
/// Evaluates to `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
///
/// # Arguments
///
/// - `$a_ptr`, `$b_ptr` — Starting pointers for the two input vectors
/// - `$end` — End pointer for the main loop (aligned to `8 × lane`);
///   evaluated exactly once, before the loop
/// - `$zero` — Zero-init expression, evaluated once per accumulator
/// - `$load` — SIMD load intrinsic
/// - `$sub` — SIMD subtract intrinsic with signature `sub(a, b) → a - b`
/// - `$fmadd` — FMA intrinsic with signature `fmadd(a, b, acc) → a*b + acc`
/// - `$add` — SIMD add intrinsic
/// - `$lane` — Number of f32 elements per SIMD register; evaluated exactly once
///
/// # Safety
///
/// Same requirements as [`simd_8acc_dot_loop!`]. Each iteration reads
/// `8 × $lane` elements from both cursors while the `a` cursor is `< $end`.
#[macro_export]
macro_rules! simd_8acc_l2_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $sub:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut s0 = $zero;
        let mut s1 = $zero;
        let mut s2 = $zero;
        let mut s3 = $zero;
        let mut s4 = $zero;
        let mut s5 = $zero;
        let mut s6 = $zero;
        let mut s7 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // `$end` and `$lane` may be arbitrary expressions. Bind them once so
        // they are not re-evaluated on every loop check or pointer offset.
        let end = $end;
        let lane: usize = $lane;

        while a_p < end {
            let d0 = $sub($load(a_p), $load(b_p));
            s0 = $fmadd(d0, d0, s0);
            let d1 = $sub($load(a_p.add(lane)), $load(b_p.add(lane)));
            s1 = $fmadd(d1, d1, s1);
            let d2 = $sub($load(a_p.add(2 * lane)), $load(b_p.add(2 * lane)));
            s2 = $fmadd(d2, d2, s2);
            let d3 = $sub($load(a_p.add(3 * lane)), $load(b_p.add(3 * lane)));
            s3 = $fmadd(d3, d3, s3);
            let d4 = $sub($load(a_p.add(4 * lane)), $load(b_p.add(4 * lane)));
            s4 = $fmadd(d4, d4, s4);
            let d5 = $sub($load(a_p.add(5 * lane)), $load(b_p.add(5 * lane)));
            s5 = $fmadd(d5, d5, s5);
            let d6 = $sub($load(a_p.add(6 * lane)), $load(b_p.add(6 * lane)));
            s6 = $fmadd(d6, d6, s6);
            let d7 = $sub($load(a_p.add(7 * lane)), $load(b_p.add(7 * lane)));
            s7 = $fmadd(d7, d7, s7);

            a_p = a_p.add(8 * lane);
            b_p = b_p.add(8 * lane);
        }

        // Binary-tree reduction: 8 → 4 → 2 → 1
        s0 = $add(s0, s4);
        s1 = $add(s1, s5);
        s2 = $add(s2, s6);
        s3 = $add(s3, s7);
        let sum01 = $add(s0, s1);
        let sum23 = $add(s2, s3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
345
346// Re-export macros for crate-internal use
347#[allow(unused_imports)]
348pub(crate) use simd_4acc_dot_loop;
349#[allow(unused_imports)]
350pub(crate) use simd_4acc_l2_loop;
351#[allow(unused_imports)]
352pub(crate) use simd_8acc_dot_loop;
353#[allow(unused_imports)]
354pub(crate) use simd_8acc_l2_loop;