Skip to main content

velesdb_core/simd_native/
reduction.rs

1//! Shared horizontal reduction helpers and 4-accumulator loop macros for SIMD.
2//!
3//! Provides canonical `hsum_avx256` and `hsum_avx512` functions so that AVX2 and
4//! AVX-512 kernels can reduce accumulators through a single shared code path.
5//! Also provides [`simd_4acc_dot_loop!`] and [`simd_4acc_l2_loop!`] macros
6//! that encode the 4-accumulator ILP unrolling pattern used across all ISAs.
7
8#![allow(clippy::incompatible_msrv)] // SIMD intrinsics gated behind target_feature + runtime detection
9
10// =============================================================================
11// Horizontal sum helpers
12// =============================================================================
13
14/// Horizontal sum of 8 packed f32 values in an AVX2 `__m256` register.
15///
16/// Reduces `[a, b, c, d, e, f, g, h]` → `a + b + c + d + e + f + g + h`.
17///
18/// # Safety
19///
20/// Caller must ensure CPU supports AVX2 (enforced by `#[target_feature]`
21/// and runtime detection via `simd_level()`).
22#[cfg(target_arch = "x86_64")]
23#[target_feature(enable = "avx2")]
24#[inline]
25pub(crate) unsafe fn hsum_avx256(v: std::arch::x86_64::__m256) -> f32 {
26    // SAFETY: All intrinsics require AVX2 which is guaranteed by #[target_feature].
27    // No pointer arithmetic or memory access — operates purely on register values.
28    use std::arch::x86_64::{
29        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
30        _mm_movehdup_ps, _mm_movehl_ps,
31    };
32    let hi = _mm256_extractf128_ps(v, 1);
33    let lo = _mm256_castps256_ps128(v);
34    let sum128 = _mm_add_ps(lo, hi);
35    let shuf = _mm_movehdup_ps(sum128);
36    let sums = _mm_add_ps(sum128, shuf);
37    let shuf2 = _mm_movehl_ps(sums, sums);
38    _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
39}
40
41/// Horizontal sum of 16 packed f32 values in an AVX-512 `__m512` register.
42///
43/// Wraps the native `_mm512_reduce_add_ps` intrinsic for API symmetry
44/// with [`hsum_avx256`].
45///
46/// # Safety
47///
48/// Caller must ensure CPU supports AVX-512F (enforced by `#[target_feature]`
49/// and runtime detection via `simd_level()`).
50#[cfg(target_arch = "x86_64")]
51#[target_feature(enable = "avx512f")]
52#[inline]
53#[allow(dead_code)] // Available for AVX-512 kernels; currently they use _mm512_reduce_add_ps directly
54pub(crate) unsafe fn hsum_avx512(v: std::arch::x86_64::__m512) -> f32 {
55    // SAFETY: `_mm512_reduce_add_ps` requires AVX-512F, guaranteed by #[target_feature].
56    // No pointer arithmetic — operates purely on register values.
57    std::arch::x86_64::_mm512_reduce_add_ps(v)
58}
59
60// =============================================================================
61// 4-accumulator reduction helpers
62// =============================================================================
63
/// Merges four AVX2 accumulators lane-wise into a single `__m256`.
///
/// The additions form a balanced tree, `(a0 + a1) + (a2 + a3)`, rather than a
/// left fold. This keeps the two inner adds independent of each other.
///
/// # Safety
///
/// Caller must ensure CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)] // Available for custom 4-acc kernels outside the macro
pub(crate) unsafe fn reduce_4acc_avx256(
    a0: std::arch::x86_64::__m256,
    a1: std::arch::x86_64::__m256,
    a2: std::arch::x86_64::__m256,
    a3: std::arch::x86_64::__m256,
) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::_mm256_add_ps as vadd;

    // SAFETY: `_mm256_add_ps` is covered by the `avx2` target feature enabled
    // on this function. It operates purely on register values.
    vadd(vadd(a0, a1), vadd(a2, a3))
}
86
87/// Combines 4 AVX-512 accumulators into one via binary-tree addition.
88///
89/// # Safety
90///
91/// Caller must ensure CPU supports AVX-512F.
92#[cfg(target_arch = "x86_64")]
93#[target_feature(enable = "avx512f")]
94#[inline]
95#[allow(dead_code)] // Available for custom 4-acc kernels outside the macro
96pub(crate) unsafe fn reduce_4acc_avx512(
97    a0: std::arch::x86_64::__m512,
98    a1: std::arch::x86_64::__m512,
99    a2: std::arch::x86_64::__m512,
100    a3: std::arch::x86_64::__m512,
101) -> std::arch::x86_64::__m512 {
102    // SAFETY: `_mm512_add_ps` requires AVX-512F, guaranteed by #[target_feature].
103    // No pointer arithmetic — operates purely on register values.
104    use std::arch::x86_64::_mm512_add_ps;
105    let sum01 = _mm512_add_ps(a0, a1);
106    let sum23 = _mm512_add_ps(a2, a3);
107    _mm512_add_ps(sum01, sum23)
108}
109
110// =============================================================================
111// 4-accumulator loop macros
112// =============================================================================
113
/// 4-accumulator unrolled SIMD loop for dot product (ILP optimization).
///
/// Processes `4 × lane` elements per iteration using 4 independent
/// accumulators to hide FMA latency. Works across AVX2, AVX-512, and NEON
/// by accepting ISA-specific intrinsics as parameters.
///
/// Returns `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
/// - The accumulators are combined as `(acc0 + acc1) + (acc2 + acc3)`.
/// - The returned pointers point just past the last element the loop
///   consumed, so the caller can finish any scalar tail from there.
///
/// `$a_ptr`, `$b_ptr`, `$end` and `$lane` are each evaluated exactly once.
/// `$zero` is evaluated once per accumulator (four times).
///
/// # Arguments
///
/// - `$a_ptr`, `$b_ptr` — Starting pointers for the two input vectors
/// - `$end` — End pointer into the `a` buffer for the main loop.
///   `$end - $a_ptr` should be a whole multiple of `4 × lane` elements;
///   otherwise the final iteration reads past `$end`. `$b_ptr` must be
///   readable over the same number of elements.
/// - `$zero` — Zero-init expression (e.g., `_mm256_setzero_ps()`)
/// - `$load` — SIMD load intrinsic (e.g., `_mm256_loadu_ps`)
/// - `$fmadd` — FMA intrinsic with signature `fmadd(a, b, acc) → a*b + acc`
/// - `$add` — SIMD add intrinsic (e.g., `_mm256_add_ps`)
/// - `$lane` — Number of f32 elements per SIMD register (4/8/16), as `usize`
///
/// # Safety
///
/// Must be invoked inside an `unsafe` context where the specified
/// SIMD intrinsics are valid for the current CPU.
#[macro_export]
macro_rules! simd_4acc_dot_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut acc0 = $zero;
        let mut acc1 = $zero;
        let mut acc2 = $zero;
        let mut acc3 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // Bind the caller's `$end` and `$lane` expressions once. Otherwise the
        // loop condition and every pointer offset below would re-evaluate them
        // on each iteration.
        let end = $end;
        let lane: usize = $lane;
        let step = 4 * lane;

        while a_p < end {
            let va0 = $load(a_p);
            let vb0 = $load(b_p);
            acc0 = $fmadd(va0, vb0, acc0);

            let va1 = $load(a_p.add(lane));
            let vb1 = $load(b_p.add(lane));
            acc1 = $fmadd(va1, vb1, acc1);

            let va2 = $load(a_p.add(2 * lane));
            let vb2 = $load(b_p.add(2 * lane));
            acc2 = $fmadd(va2, vb2, acc2);

            let va3 = $load(a_p.add(3 * lane));
            let vb3 = $load(b_p.add(3 * lane));
            acc3 = $fmadd(va3, vb3, acc3);

            a_p = a_p.add(step);
            b_p = b_p.add(step);
        }

        // Pairwise combine: (acc0 + acc1) + (acc2 + acc3).
        let sum01 = $add(acc0, acc1);
        let sum23 = $add(acc2, acc3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
173
/// 4-accumulator unrolled SIMD loop for squared L2 distance.
///
/// Same structure as [`simd_4acc_dot_loop!`] but computes `sum((a-b)²)`
/// instead of `sum(a·b)`. Requires an additional `$sub` intrinsic
/// (`sub(a, b) → a - b`). Each squared difference is accumulated as
/// `fmadd(diff, diff, acc)`.
///
/// Returns `(combined_accumulator, updated_a_ptr, updated_b_ptr)`. The
/// accumulators are combined as `(acc0 + acc1) + (acc2 + acc3)`.
///
/// `$a_ptr`, `$b_ptr`, `$end` and `$lane` are each evaluated exactly once.
/// `$zero` is evaluated once per accumulator (four times).
///
/// `$end - $a_ptr` should be a whole multiple of `4 × lane` elements;
/// otherwise the final iteration reads past `$end`.
///
/// # Safety
///
/// Same requirements as [`simd_4acc_dot_loop!`].
#[macro_export]
macro_rules! simd_4acc_l2_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $sub:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        let mut acc0 = $zero;
        let mut acc1 = $zero;
        let mut acc2 = $zero;
        let mut acc3 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;
        // Bind the caller's `$end` and `$lane` expressions once. Otherwise the
        // loop condition and every pointer offset below would re-evaluate them
        // on each iteration.
        let end = $end;
        let lane: usize = $lane;
        let step = 4 * lane;

        while a_p < end {
            let va0 = $load(a_p);
            let vb0 = $load(b_p);
            let diff0 = $sub(va0, vb0);
            acc0 = $fmadd(diff0, diff0, acc0);

            let va1 = $load(a_p.add(lane));
            let vb1 = $load(b_p.add(lane));
            let diff1 = $sub(va1, vb1);
            acc1 = $fmadd(diff1, diff1, acc1);

            let va2 = $load(a_p.add(2 * lane));
            let vb2 = $load(b_p.add(2 * lane));
            let diff2 = $sub(va2, vb2);
            acc2 = $fmadd(diff2, diff2, acc2);

            let va3 = $load(a_p.add(3 * lane));
            let vb3 = $load(b_p.add(3 * lane));
            let diff3 = $sub(va3, vb3);
            acc3 = $fmadd(diff3, diff3, acc3);

            a_p = a_p.add(step);
            b_p = b_p.add(step);
        }

        // Pairwise combine: (acc0 + acc1) + (acc2 + acc3).
        let sum01 = $add(acc0, acc1);
        let sum23 = $add(acc2, acc3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
223
224// Re-export macros for crate-internal use
225#[allow(unused_imports)]
226pub(crate) use simd_4acc_dot_loop;
227#[allow(unused_imports)]
228pub(crate) use simd_4acc_l2_loop;