velesdb_core/simd_native/reduction.rs
1//! Shared horizontal reduction helpers and 4-accumulator loop macros for SIMD.
2//!
//! Provides canonical `hsum_avx256` and `hsum_avx512` functions so that
//! AVX2/AVX-512 kernels can reduce accumulators through a single shared code
//! path (some AVX-512 kernels still call `_mm512_reduce_add_ps` directly).
5//! Also provides [`simd_4acc_dot_loop!`] and [`simd_4acc_l2_loop!`] macros
6//! that encode the 4-accumulator ILP unrolling pattern used across all ISAs.
7
8#![allow(clippy::incompatible_msrv)] // SIMD intrinsics gated behind target_feature + runtime detection
9
10// =============================================================================
11// Horizontal sum helpers
12// =============================================================================
13
/// Sums the eight `f32` lanes of an AVX `__m256` register into a scalar.
///
/// The upper 128-bit half is folded onto the lower half first, and the
/// remaining four lanes are then combined pairwise. For input
/// `[a, b, c, d, e, f, g, h]` the result is
/// `((a + e) + (b + f)) + ((c + g) + (d + h))`.
///
/// # Safety
///
/// The CPU must support AVX2. This is required by `#[target_feature]` and is
/// expected to be checked at runtime via `simd_level()` before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
pub(crate) unsafe fn hsum_avx256(v: std::arch::x86_64::__m256) -> f32 {
    // SAFETY: every intrinsic below works purely on register values (no memory
    // access) and needs at most AVX/SSE3, both implied by the `avx2` feature.
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
        _mm_movehdup_ps, _mm_movehl_ps,
    };

    // 256 -> 128 bits: [a+e, b+f, c+g, d+h].
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    // Duplicate the odd lanes and add: lane 0 = q0+q1, lane 2 = q2+q3.
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    // Move lane 2 down to lane 0 and add the two partial sums in lane 0.
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));
    _mm_cvtss_f32(total)
}
40
41/// Horizontal sum of 16 packed f32 values in an AVX-512 `__m512` register.
42///
43/// Wraps the native `_mm512_reduce_add_ps` intrinsic for API symmetry
44/// with [`hsum_avx256`].
45///
46/// # Safety
47///
48/// Caller must ensure CPU supports AVX-512F (enforced by `#[target_feature]`
49/// and runtime detection via `simd_level()`).
50#[cfg(target_arch = "x86_64")]
51#[target_feature(enable = "avx512f")]
52#[inline]
53#[allow(dead_code)] // Available for AVX-512 kernels; currently they use _mm512_reduce_add_ps directly
54pub(crate) unsafe fn hsum_avx512(v: std::arch::x86_64::__m512) -> f32 {
55 // SAFETY: `_mm512_reduce_add_ps` requires AVX-512F, guaranteed by #[target_feature].
56 // No pointer arithmetic — operates purely on register values.
57 std::arch::x86_64::_mm512_reduce_add_ps(v)
58}
59
60// =============================================================================
61// 4-accumulator reduction helpers
62// =============================================================================
63
/// Merges four AVX2 accumulators into one lane-wise sum.
///
/// The additions are paired as `(a0 + a1) + (a2 + a3)`, which matches the
/// combine step of the 4-accumulator loop macros in this module.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)] // Available for custom 4-acc kernels outside the macro
pub(crate) unsafe fn reduce_4acc_avx256(
    a0: std::arch::x86_64::__m256,
    a1: std::arch::x86_64::__m256,
    a2: std::arch::x86_64::__m256,
    a3: std::arch::x86_64::__m256,
) -> std::arch::x86_64::__m256 {
    // SAFETY: register-only additions; the required CPU feature is guaranteed
    // by #[target_feature].
    use std::arch::x86_64::_mm256_add_ps as add;
    add(add(a0, a1), add(a2, a3))
}
86
87/// Combines 4 AVX-512 accumulators into one via binary-tree addition.
88///
89/// # Safety
90///
91/// Caller must ensure CPU supports AVX-512F.
92#[cfg(target_arch = "x86_64")]
93#[target_feature(enable = "avx512f")]
94#[inline]
95#[allow(dead_code)] // Available for custom 4-acc kernels outside the macro
96pub(crate) unsafe fn reduce_4acc_avx512(
97 a0: std::arch::x86_64::__m512,
98 a1: std::arch::x86_64::__m512,
99 a2: std::arch::x86_64::__m512,
100 a3: std::arch::x86_64::__m512,
101) -> std::arch::x86_64::__m512 {
102 // SAFETY: `_mm512_add_ps` requires AVX-512F, guaranteed by #[target_feature].
103 // No pointer arithmetic — operates purely on register values.
104 use std::arch::x86_64::_mm512_add_ps;
105 let sum01 = _mm512_add_ps(a0, a1);
106 let sum23 = _mm512_add_ps(a2, a3);
107 _mm512_add_ps(sum01, sum23)
108}
109
110// =============================================================================
111// 4-accumulator loop macros
112// =============================================================================
113
/// 4-accumulator unrolled SIMD loop for dot product (ILP optimization).
///
/// Processes `4 × lane` elements per iteration using 4 independent
/// accumulators to hide FMA latency. Works across AVX2, AVX-512, and NEON
/// by accepting ISA-specific intrinsics as parameters.
///
/// Evaluates to `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
/// The combined accumulator is `(acc0 + acc1) + (acc2 + acc3)`. The returned
/// pointers point at the first element not consumed by the loop, so the
/// caller can process any remaining tail from there.
///
/// # Arguments
///
/// - `$a_ptr`, `$b_ptr` — Starting pointers for the two input vectors
/// - `$end` — Upper bound for the `a` pointer. Only complete `4 × lane` blocks
///   that end at or before `$end` are processed, so an `$end` that is not
///   block-aligned never causes reads past it. Evaluated exactly once.
/// - `$zero` — Zero-init expression (e.g., `_mm256_setzero_ps()`), evaluated
///   once per accumulator
/// - `$load` — SIMD load intrinsic (e.g., `_mm256_loadu_ps`)
/// - `$fmadd` — FMA intrinsic with signature `fmadd(a, b, acc) → a*b + acc`
/// - `$add` — SIMD add intrinsic (e.g., `_mm256_add_ps`)
/// - `$lane` — Number of f32 elements per SIMD register (4/8/16), evaluated
///   exactly once
///
/// # Safety
///
/// Must be invoked inside an `unsafe` context where the specified
/// SIMD intrinsics are valid for the current CPU. All elements in
/// `[$a_ptr, $end)` must be readable, and so must the same number of
/// elements starting at `$b_ptr`.
#[macro_export]
macro_rules! simd_4acc_dot_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        // Bind the arguments once so that `$end` and `$lane` are not
        // re-evaluated on every iteration.
        let end = $end;
        let lane: usize = $lane;
        let block = 4 * lane;

        let mut acc0 = $zero;
        let mut acc1 = $zero;
        let mut acc2 = $zero;
        let mut acc3 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;

        // Only enter the body when a full block fits before `end`.
        // `wrapping_add` keeps the bound computation itself free of UB even
        // when the next block would extend past the allocation.
        while a_p.wrapping_add(block) <= end {
            let va0 = $load(a_p);
            let vb0 = $load(b_p);
            acc0 = $fmadd(va0, vb0, acc0);

            let va1 = $load(a_p.add(lane));
            let vb1 = $load(b_p.add(lane));
            acc1 = $fmadd(va1, vb1, acc1);

            let va2 = $load(a_p.add(2 * lane));
            let vb2 = $load(b_p.add(2 * lane));
            acc2 = $fmadd(va2, vb2, acc2);

            let va3 = $load(a_p.add(3 * lane));
            let vb3 = $load(b_p.add(3 * lane));
            acc3 = $fmadd(va3, vb3, acc3);

            a_p = a_p.add(block);
            b_p = b_p.add(block);
        }

        let sum01 = $add(acc0, acc1);
        let sum23 = $add(acc2, acc3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
173
/// 4-accumulator unrolled SIMD loop for squared L2 distance.
///
/// Same structure as [`simd_4acc_dot_loop!`] but computes `sum((a-b)²)`
/// instead of `sum(a·b)`. Requires an additional `$sub` intrinsic.
///
/// Evaluates to `(combined_accumulator, updated_a_ptr, updated_b_ptr)`.
/// The combined accumulator is `(acc0 + acc1) + (acc2 + acc3)`. The returned
/// pointers point at the first element not consumed by the loop, so the
/// caller can process any remaining tail from there.
///
/// `$end` is an upper bound for the `a` pointer. Only complete `4 × lane`
/// blocks ending at or before it are processed, so the loop never reads past
/// `$end`. Both `$end` and `$lane` are evaluated exactly once.
///
/// # Safety
///
/// Must be invoked inside an `unsafe` context where the specified
/// SIMD intrinsics are valid for the current CPU. All elements in
/// `[$a_ptr, $end)` must be readable, and so must the same number of
/// elements starting at `$b_ptr`.
#[macro_export]
macro_rules! simd_4acc_l2_loop {
    ($a_ptr:expr, $b_ptr:expr, $end:expr,
     $zero:expr, $load:ident, $sub:ident, $fmadd:ident, $add:ident, $lane:expr) => {{
        // Bind the arguments once so that `$end` and `$lane` are not
        // re-evaluated on every iteration.
        let end = $end;
        let lane: usize = $lane;
        let block = 4 * lane;

        let mut acc0 = $zero;
        let mut acc1 = $zero;
        let mut acc2 = $zero;
        let mut acc3 = $zero;
        let mut a_p = $a_ptr;
        let mut b_p = $b_ptr;

        // Only enter the body when a full block fits before `end`.
        // `wrapping_add` keeps the bound computation itself free of UB even
        // when the next block would extend past the allocation.
        while a_p.wrapping_add(block) <= end {
            let va0 = $load(a_p);
            let vb0 = $load(b_p);
            let diff0 = $sub(va0, vb0);
            acc0 = $fmadd(diff0, diff0, acc0);

            let va1 = $load(a_p.add(lane));
            let vb1 = $load(b_p.add(lane));
            let diff1 = $sub(va1, vb1);
            acc1 = $fmadd(diff1, diff1, acc1);

            let va2 = $load(a_p.add(2 * lane));
            let vb2 = $load(b_p.add(2 * lane));
            let diff2 = $sub(va2, vb2);
            acc2 = $fmadd(diff2, diff2, acc2);

            let va3 = $load(a_p.add(3 * lane));
            let vb3 = $load(b_p.add(3 * lane));
            let diff3 = $sub(va3, vb3);
            acc3 = $fmadd(diff3, diff3, acc3);

            a_p = a_p.add(block);
            b_p = b_p.add(block);
        }

        let sum01 = $add(acc0, acc1);
        let sum23 = $add(acc2, acc3);
        ($add(sum01, sum23), a_p, b_p)
    }};
}
223
224// Re-export macros for crate-internal use
225#[allow(unused_imports)]
226pub(crate) use simd_4acc_dot_loop;
227#[allow(unused_imports)]
228pub(crate) use simd_4acc_l2_loop;