Skip to main content

velesdb_core/simd_native/
reduction.rs

//! Shared horizontal reduction helpers and multi-accumulator loop macros for SIMD.
//!
//! Provides canonical `hsum_avx256` and `hsum_avx512` functions so that
//! every AVX2/AVX-512 kernel reduces accumulators through a single code path,
//! plus `reduce_4acc_avx256` / `reduce_4acc_avx512` for combining four
//! accumulators into one register. Also provides 4-accumulator
//! ([`simd_4acc_dot_loop!`], [`simd_4acc_l2_loop!`]) and 8-accumulator
//! ([`simd_8acc_dot_loop!`], [`simd_8acc_l2_loop!`]) macros that encode the
//! ILP unrolling patterns used across all ISAs.
8
9#![allow(clippy::incompatible_msrv)] // SIMD intrinsics gated behind target_feature + runtime detection
10
// =============================================================================
// Horizontal sum helpers
// =============================================================================
14
/// Horizontal sum of 8 packed f32 values in an AVX2 `__m256` register.
///
/// Reduces `[v0, v1, v2, v3, v4, v5, v6, v7]` to a single `f32` using a fixed
/// association order:
///
/// `((v0 + v4) + (v1 + v5)) + ((v2 + v6) + (v3 + v7))`
///
/// The order is deterministic, so repeated calls give bit-identical results.
/// Because f32 addition is not associative, the result can differ in the last
/// bits from a left-to-right scalar sum of the same lanes.
///
/// # Safety
///
/// Caller must ensure the CPU supports AVX2 (enforced by `#[target_feature]`
/// and runtime detection via `simd_level()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
pub(crate) unsafe fn hsum_avx256(v: std::arch::x86_64::__m256) -> f32 {
    // SAFETY: The intrinsics below need only AVX (`_mm256_extractf128_ps`,
    // `_mm256_castps256_ps128`), SSE3 (`_mm_movehdup_ps`) and SSE (the rest).
    // All of these are implied by the AVX2 feature enabled on this function.
    // No memory is accessed; everything operates on register values.
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
        _mm_movehdup_ps, _mm_movehl_ps,
    };
    // Fold the upper 128-bit half onto the lower one:
    // sum128 = [v0+v4, v1+v5, v2+v6, v3+v7] =: [s0, s1, s2, s3].
    let hi = _mm256_extractf128_ps(v, 1);
    let lo = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(lo, hi);
    // movehdup duplicates the odd lanes ([s1, s1, s3, s3]), so after the add
    // lane 0 holds s0+s1 and lane 2 holds s2+s3.
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    // movehl brings lane 2 down to lane 0; a scalar add finishes the tree.
    let shuf2 = _mm_movehl_ps(sums, sums);
    _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
}
41
42/// Horizontal sum of 16 packed f32 values in an AVX-512 `__m512` register.
43///
44/// Wraps the native `_mm512_reduce_add_ps` intrinsic for API symmetry
45/// with [`hsum_avx256`].
46///
47/// # Safety
48///
49/// Caller must ensure CPU supports AVX-512F (enforced by `#[target_feature]`
50/// and runtime detection via `simd_level()`).
51#[cfg(target_arch = "x86_64")]
52#[target_feature(enable = "avx512f")]
53#[inline]
54pub(crate) unsafe fn hsum_avx512(v: std::arch::x86_64::__m512) -> f32 {
55    // SAFETY: `_mm512_reduce_add_ps` requires AVX-512F, guaranteed by #[target_feature].
56    // No pointer arithmetic — operates purely on register values.
57    std::arch::x86_64::_mm512_reduce_add_ps(v)
58}
59
// =============================================================================
// 4-accumulator reduction helpers
// =============================================================================
63
/// Combines 4 AVX2 accumulators into one via binary-tree addition.
///
/// Computes `(a0 + a1) + (a2 + a3)` lane-wise. The two first-level additions do
/// not depend on each other, and the shape matches the final reduction performed
/// by [`simd_4acc_dot_loop!`] and [`simd_4acc_l2_loop!`].
///
/// # Safety
///
/// Caller must ensure the CPU supports AVX2, the feature enabled on this
/// function.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)] // Available for custom 4-acc kernels outside the macro
pub(crate) unsafe fn reduce_4acc_avx256(
    a0: std::arch::x86_64::__m256,
    a1: std::arch::x86_64::__m256,
    a2: std::arch::x86_64::__m256,
    a3: std::arch::x86_64::__m256,
) -> std::arch::x86_64::__m256 {
    // SAFETY: `_mm256_add_ps` needs only AVX, which the AVX2 feature enabled by
    // #[target_feature] implies. No memory is accessed; register values only.
    use std::arch::x86_64::_mm256_add_ps;
    let sum01 = _mm256_add_ps(a0, a1);
    let sum23 = _mm256_add_ps(a2, a3);
    _mm256_add_ps(sum01, sum23)
}
86
87/// Combines 4 AVX-512 accumulators into one via binary-tree addition.
88///
89/// # Safety
90///
91/// Caller must ensure CPU supports AVX-512F.
92#[cfg(target_arch = "x86_64")]
93#[target_feature(enable = "avx512f")]
94#[inline]
95#[allow(dead_code)] // Available for custom 4-acc kernels outside the macro
96pub(crate) unsafe fn reduce_4acc_avx512(
97    a0: std::arch::x86_64::__m512,
98    a1: std::arch::x86_64::__m512,
99    a2: std::arch::x86_64::__m512,
100    a3: std::arch::x86_64::__m512,
101) -> std::arch::x86_64::__m512 {
102    // SAFETY: `_mm512_add_ps` requires AVX-512F, guaranteed by #[target_feature].
103    // No pointer arithmetic — operates purely on register values.
104    use std::arch::x86_64::_mm512_add_ps;
105    let sum01 = _mm512_add_ps(a0, a1);
106    let sum23 = _mm512_add_ps(a2, a3);
107    _mm512_add_ps(sum01, sum23)
108}
109
// =============================================================================
// 4-accumulator loop macros
// =============================================================================
113
/// 4-accumulator unrolled SIMD loop for dot product (ILP optimization).
///
/// Processes `4 × lane` elements per iteration using 4 independent
/// accumulators to hide FMA latency. Works across AVX2, AVX-512, and NEON
/// by accepting ISA-specific intrinsics as parameters.
///
/// Returns `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
/// - The combined accumulator is `(acc0 + acc1) + (acc2 + acc3)`.
/// - The updated pointers point at the first element the loop did not consume,
///   so the caller can finish the tail with a narrower loop.
///
/// # Arguments
///
/// - `$a_ptr`, `$b_ptr` — Starting pointers for the two input vectors
/// - `$end` — End pointer for the main loop, derived from `$a_ptr`. An iteration
///   runs while the current `a` pointer is `< $end`, and each iteration consumes
///   `4 × lane` elements. Normally this is `$a_ptr` advanced by a multiple of
///   `4 × lane`. Evaluated exactly once, before the loop.
/// - `$zero` — Zero-init expression (e.g., `_mm256_setzero_ps()`)
/// - `$load` — SIMD load intrinsic (e.g., `_mm256_loadu_ps`)
/// - `$fmadd` — FMA intrinsic with signature `fmadd(a, b, acc) → a*b + acc`
/// - `$add` — SIMD add intrinsic (e.g., `_mm256_add_ps`)
/// - `$lane` — Number of f32 elements per SIMD register (4/8/16). Evaluated
///   exactly once.
///
/// # Safety
///
/// Must be invoked inside an `unsafe` context where the specified
/// SIMD intrinsics are valid for the current CPU. At every position the loop
/// visits (each `a` pointer `< $end`), `4 × lane` elements must be readable
/// from both `a` and the matching offset of `b`.
#[macro_export]
macro_rules! simd_4acc_dot_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut acc0 = $zero;
        let mut acc1 = $zero;
        let mut acc2 = $zero;
        let mut acc3 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // Both are loop-invariant. Binding them once avoids re-evaluating the
        // caller's expressions on every iteration (and every offset computation).
        let end = $end;
        let lane = $lane;

        while a_p < end {
            let va0 = $load(a_p);
            let vb0 = $load(b_p);
            acc0 = $fmadd(va0, vb0, acc0);

            let va1 = $load(a_p.add(lane));
            let vb1 = $load(b_p.add(lane));
            acc1 = $fmadd(va1, vb1, acc1);

            let va2 = $load(a_p.add(2 * lane));
            let vb2 = $load(b_p.add(2 * lane));
            acc2 = $fmadd(va2, vb2, acc2);

            let va3 = $load(a_p.add(3 * lane));
            let vb3 = $load(b_p.add(3 * lane));
            acc3 = $fmadd(va3, vb3, acc3);

            a_p = a_p.add(4 * lane);
            b_p = b_p.add(4 * lane);
        }

        // Pairwise combine: (acc0 + acc1) + (acc2 + acc3).
        let sum01 = $add(acc0, acc1);
        let sum23 = $add(acc2, acc3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
173
/// 4-accumulator unrolled SIMD loop for squared L2 distance.
///
/// Same structure as [`simd_4acc_dot_loop!`] but computes `sum((a-b)²)`
/// instead of `sum(a·b)`. Requires an additional `$sub` intrinsic
/// (`sub(a, b) → a - b`), passed between `$load` and `$fmadd`.
///
/// Returns `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
/// - The combined accumulator is `(acc0 + acc1) + (acc2 + acc3)`.
/// - The pointers point at the first element the loop did not consume.
///
/// `$end` and `$lane` are each evaluated exactly once, before the loop.
///
/// # Safety
///
/// Same requirements as [`simd_4acc_dot_loop!`]. The macro must be invoked in an
/// `unsafe` context with intrinsics valid for the CPU. At every position
/// `a_ptr < $end` that the loop visits, `4 × lane` elements must be readable
/// from both inputs.
#[macro_export]
macro_rules! simd_4acc_l2_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $sub:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut acc0 = $zero;
        let mut acc1 = $zero;
        let mut acc2 = $zero;
        let mut acc3 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // Loop-invariant: evaluate the bound and lane width once, not per iteration.
        let end = $end;
        let lane = $lane;

        while a_p < end {
            let va0 = $load(a_p);
            let vb0 = $load(b_p);
            let diff0 = $sub(va0, vb0);
            acc0 = $fmadd(diff0, diff0, acc0);

            let va1 = $load(a_p.add(lane));
            let vb1 = $load(b_p.add(lane));
            let diff1 = $sub(va1, vb1);
            acc1 = $fmadd(diff1, diff1, acc1);

            let va2 = $load(a_p.add(2 * lane));
            let vb2 = $load(b_p.add(2 * lane));
            let diff2 = $sub(va2, vb2);
            acc2 = $fmadd(diff2, diff2, acc2);

            let va3 = $load(a_p.add(3 * lane));
            let vb3 = $load(b_p.add(3 * lane));
            let diff3 = $sub(va3, vb3);
            acc3 = $fmadd(diff3, diff3, acc3);

            a_p = a_p.add(4 * lane);
            b_p = b_p.add(4 * lane);
        }

        // Pairwise combine: (acc0 + acc1) + (acc2 + acc3).
        let sum01 = $add(acc0, acc1);
        let sum23 = $add(acc2, acc3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
223
// =============================================================================
// 8-accumulator loop macros
// =============================================================================
227
/// 8-accumulator unrolled SIMD loop for dot product (ILP optimization).
///
/// Processes `8 × lane` elements per iteration using 8 independent
/// accumulators to maximally hide FMA latency on wide-issue CPUs.
/// Targets AVX-512 kernels for very large vectors (>= 1024 dimensions).
///
/// Returns `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
/// - The combined accumulator is
///   `((s0 + s4) + (s1 + s5)) + ((s2 + s6) + (s3 + s7))`.
/// - The pointers point at the first element the loop did not consume.
///
/// # Arguments
///
/// - `$a_ptr`, `$b_ptr` — Starting pointers for the two input vectors
/// - `$end` — End pointer for the main loop, derived from `$a_ptr`. An iteration
///   runs while the current `a` pointer is `< $end`, and each iteration consumes
///   `8 × lane` elements. Normally this is `$a_ptr` advanced by a multiple of
///   `8 × lane`. Evaluated exactly once, before the loop.
/// - `$zero` — Zero-init expression (e.g., `_mm512_setzero_ps()`)
/// - `$load` — SIMD load intrinsic (e.g., `_mm512_loadu_ps`)
/// - `$fmadd` — FMA intrinsic with signature `fmadd(a, b, acc) → a*b + acc`
/// - `$add` — SIMD add intrinsic (e.g., `_mm512_add_ps`)
/// - `$lane` — Number of f32 elements per SIMD register (16 for AVX-512).
///   Evaluated exactly once.
///
/// # Safety
///
/// Must be invoked inside an `unsafe` context where the specified
/// SIMD intrinsics are valid for the current CPU. At every position the loop
/// visits (each `a` pointer `< $end`), `8 × lane` elements must be readable
/// from both `a` and the matching offset of `b`.
#[macro_export]
macro_rules! simd_8acc_dot_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut s0 = $zero;
        let mut s1 = $zero;
        let mut s2 = $zero;
        let mut s3 = $zero;
        let mut s4 = $zero;
        let mut s5 = $zero;
        let mut s6 = $zero;
        let mut s7 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // Loop-invariant: evaluate the bound and lane width once, not per iteration.
        let end = $end;
        let lane = $lane;

        while a_p < end {
            s0 = $fmadd($load(a_p), $load(b_p), s0);
            s1 = $fmadd($load(a_p.add(lane)), $load(b_p.add(lane)), s1);
            s2 = $fmadd($load(a_p.add(2 * lane)), $load(b_p.add(2 * lane)), s2);
            s3 = $fmadd($load(a_p.add(3 * lane)), $load(b_p.add(3 * lane)), s3);
            s4 = $fmadd($load(a_p.add(4 * lane)), $load(b_p.add(4 * lane)), s4);
            s5 = $fmadd($load(a_p.add(5 * lane)), $load(b_p.add(5 * lane)), s5);
            s6 = $fmadd($load(a_p.add(6 * lane)), $load(b_p.add(6 * lane)), s6);
            s7 = $fmadd($load(a_p.add(7 * lane)), $load(b_p.add(7 * lane)), s7);

            a_p = a_p.add(8 * lane);
            b_p = b_p.add(8 * lane);
        }

        // Binary-tree reduction: 8 → 4 → 2 → 1
        s0 = $add(s0, s4);
        s1 = $add(s1, s5);
        s2 = $add(s2, s6);
        s3 = $add(s3, s7);
        let sum01 = $add(s0, s1);
        let sum23 = $add(s2, s3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
289
/// 8-accumulator unrolled SIMD loop for squared L2 distance.
///
/// Same structure as [`simd_8acc_dot_loop!`] but computes `sum((a-b)²)`
/// instead of `sum(a·b)`. Requires an additional `$sub` intrinsic
/// (`sub(a, b) → a - b`), passed between `$load` and `$fmadd`.
///
/// Returns `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
/// - The combined accumulator is
///   `((s0 + s4) + (s1 + s5)) + ((s2 + s6) + (s3 + s7))`.
/// - The pointers point at the first element the loop did not consume.
///
/// `$end` and `$lane` are each evaluated exactly once, before the loop.
///
/// # Safety
///
/// Same requirements as [`simd_8acc_dot_loop!`]. The macro must be invoked in an
/// `unsafe` context with intrinsics valid for the CPU. At every position
/// `a_ptr < $end` that the loop visits, `8 × lane` elements must be readable
/// from both inputs.
#[macro_export]
macro_rules! simd_8acc_l2_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $sub:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut s0 = $zero;
        let mut s1 = $zero;
        let mut s2 = $zero;
        let mut s3 = $zero;
        let mut s4 = $zero;
        let mut s5 = $zero;
        let mut s6 = $zero;
        let mut s7 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // Loop-invariant: evaluate the bound and lane width once, not per iteration.
        let end = $end;
        let lane = $lane;

        while a_p < end {
            let d0 = $sub($load(a_p), $load(b_p));
            s0 = $fmadd(d0, d0, s0);
            let d1 = $sub($load(a_p.add(lane)), $load(b_p.add(lane)));
            s1 = $fmadd(d1, d1, s1);
            let d2 = $sub($load(a_p.add(2 * lane)), $load(b_p.add(2 * lane)));
            s2 = $fmadd(d2, d2, s2);
            let d3 = $sub($load(a_p.add(3 * lane)), $load(b_p.add(3 * lane)));
            s3 = $fmadd(d3, d3, s3);
            let d4 = $sub($load(a_p.add(4 * lane)), $load(b_p.add(4 * lane)));
            s4 = $fmadd(d4, d4, s4);
            let d5 = $sub($load(a_p.add(5 * lane)), $load(b_p.add(5 * lane)));
            s5 = $fmadd(d5, d5, s5);
            let d6 = $sub($load(a_p.add(6 * lane)), $load(b_p.add(6 * lane)));
            s6 = $fmadd(d6, d6, s6);
            let d7 = $sub($load(a_p.add(7 * lane)), $load(b_p.add(7 * lane)));
            s7 = $fmadd(d7, d7, s7);

            a_p = a_p.add(8 * lane);
            b_p = b_p.add(8 * lane);
        }

        // Binary-tree reduction: 8 → 4 → 2 → 1
        s0 = $add(s0, s4);
        s1 = $add(s1, s5);
        s2 = $add(s2, s6);
        s3 = $add(s3, s7);
        let sum01 = $add(s0, s1);
        let sum23 = $add(s2, s3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
345
346// Re-export macros for crate-internal use
347#[allow(unused_imports)]
348pub(crate) use simd_4acc_dot_loop;
349#[allow(unused_imports)]
350pub(crate) use simd_4acc_l2_loop;
351#[allow(unused_imports)]
352pub(crate) use simd_8acc_dot_loop;
353#[allow(unused_imports)]
354pub(crate) use simd_8acc_l2_loop;