// simd_lookup/simd_gather.rs

//! SIMD gather operations for efficient indexed memory access
//!
//! This module provides vectorized gather functions that load multiple values from
//! memory using SIMD indices. On AVX-512, these use hardware gather instructions
//! (`_mm512_i32gather_epi32`, `_mm512_mask_i32gather_epi32`); on other platforms,
//! they fall back to scalar loops.
//!
//! # CPU Feature Requirements (Intel x86_64)
//!
//! ## Optimal Performance (AVX-512)
//!
//! - **`gather_u32index_u8` / `gather_masked_u32index_u8`**: Requires **AVX512F** + **AVX512BW**
//!   - Uses `VPGATHERDD` (`_mm512_i32gather_epi32`) + `VPMOVDB` (`_mm512_cvtepi32_epi8`)
//!   - Available on: Intel Skylake-X (Xeon), Ice Lake, Tiger Lake, and later
//!   - Fallback: Scalar loop (works on all architectures)
//!
//! - **`gather_u32index_u32` / `gather_masked_u32index_u32`**: Requires **AVX512F**
//!   - Uses `VPGATHERDD` (`_mm512_i32gather_epi32`)
//!   - Available on: Intel Skylake-X (Xeon), Ice Lake, Tiger Lake, and later
//!   - Fallback: Scalar loop (works on all architectures)
//!
//! ## Fallback Behavior
//!
//! All functions automatically fall back to scalar implementations when AVX-512
//! features are not available. The fallback implementations work on:
//! - x86_64 without AVX-512 (scalar fallback; no AVX2 gather path is currently implemented)
//! - aarch64 (ARM NEON) - scalar fallback
//! - All other architectures (scalar fallback)
//!
//! # Functions
//!
//! - [`gather_u32index_u8`] - Gather 16 bytes using u32 indices
//! - [`gather_masked_u32index_u8`] - Masked gather of bytes with fallback values
//! - [`gather_u32index_u32`] - Gather 16 u32 values using u32 indices
//! - [`gather_masked_u32index_u32`] - Masked gather of u32 values with fallback
//!
//! # Important: Masked Gather Behavior on Intel
//!
//! When using masked gather functions, be aware of two distinct behaviors:
//!
//! ## 1. Architectural Fault Suppression (AVX-512)
//!
//! AVX-512 masked gathers are *architecturally* designed to **suppress page faults**
//! for masked-off elements. If a masked element (mask bit = 0) points to an invalid
//! address, it will NOT cause a page fault. This is documented in the Intel® 64 and
//! IA-32 Architectures SDM, Vol. 1, Section 15.6.4.
//!
//! This means masked gathers are safe to use when some indices may be invalid, as long
//! as those lanes are masked off.
//!
//! ## 2. Speculative Memory Access (Performance Reality)
//!
//! Despite the mask, the hardware may still **speculatively access all memory locations**
//! regardless of mask state. This was the root cause of the Gather Data Sampling (GDS)
//! vulnerability (CVE-2022-40982).
//!
//! From Intel's GDS documentation:
//! > "When a gather instruction performs loads from memory, different data elements are
//! > merged into the destination vector register according to the mask specified. In some
//! > situations, due to hardware optimizations specific to gather instructions, stale data
//! > from previous usage of architectural or internal vector registers may get transiently
//! > forwarded to dependent instructions without being updated by the gather loads."
//!
//! **Practical implications:**
//! - The mask **does NOT reduce memory bandwidth** - all lanes likely issue loads
//! - The mask **does NOT skip cache misses** on masked lanes
//! - Post-GDS microcode updates add latency but fix the speculation issue
//!
//! ## Architecture Comparison
//!
//! | Feature                         | AVX2 Gather    | AVX-512 Gather             |
//! |---------------------------------|----------------|----------------------------|
//! | Masked fault suppression        | Limited/None   | Architecturally guaranteed |
//! | Speculative access (pre-GDS)    | Yes            | Yes                        |
//! | Post-GDS microcode              | N/A            | Adds latency, fixes spec   |
//!
//! ## When to Use Masked Gathers
//!
//! **Good use cases:**
//! - Conditional semantics (keeping fallback values for some lanes)
//! - Fault suppression (safe to have invalid pointers in masked lanes on AVX-512)
//! - Avoiding branching in vectorized code
//!
//! **NOT useful for:**
//! - Reducing memory bandwidth (all locations still accessed)
//! - Skipping expensive cache misses on masked lanes
//! - Performance gains from partial masking
//!
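//! The mask-merge semantics above can be written out as a plain scalar model (a
//! sketch for illustration only; `masked_merge_model` is a hypothetical name, not
//! part of this module's API). Bit `i` of the mask, counting from the least
//! significant bit, controls lane `i`:
//!
//! ```rust
//! fn masked_merge_model(
//!     base: &[u32],
//!     indices: &[u32; 16],
//!     mask: u16,
//!     fallback: &[u32; 16],
//! ) -> [u32; 16] {
//!     let mut out = *fallback;
//!     for lane in 0..16 {
//!         // Masked-off lanes keep the fallback and never dereference their index.
//!         if (mask >> lane) & 1 != 0 {
//!             out[lane] = base[indices[lane] as usize];
//!         }
//!     }
//!     out
//! }
//! ```
//!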
//! # References
//!
//! - Intel® 64 and IA-32 Architectures SDM, Vol. 1, Section 15.6.4 (AVX-512 Masking)
//! - [Intel Gather Data Sampling (GDS) Documentation](https://www.intel.com/content/www/us/en/developer/articles/technical/software-security-guidance/technical-documentation/gather-data-sampling.html)
//!
//! # Example
//!
//! ```rust
//! use wide::u32x16;
//! use simd_lookup::simd_gather::gather_u32index_u8;
//!
//! let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
//! let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
//! let result = gather_u32index_u8(indices, &data, 1);
//! assert_eq!(result.to_array(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
//! ```

use wide::{u8x16, u32x16};

// =============================================================================
// Public API
// =============================================================================

/// Gather 16 bytes from memory using u32 indices.
///
/// Computes: `result[i] = base[indices[i] * scale]` for each lane.
///
/// # Arguments
/// * `indices` - Vector of 16 u32 indices
/// * `base` - Base slice to gather from
/// * `scale` - Scale factor applied to each index (1, 2, 4, or 8)
///
/// # Safety
/// The caller must ensure that `indices[i] * scale < base.len()` for all lanes.
/// Note that the AVX-512 path gathers a full 32-bit word per lane, so the three
/// bytes following each addressed byte must also be readable
/// (`indices[i] * scale + 3 < base.len()` is a sufficient condition).
/// Out-of-bounds access is undefined behavior.
#[inline]
pub fn gather_u32index_u8(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
    #[cfg(target_arch = "x86_64")]
    {
        // Requires AVX512BW for _mm512_cvtepi32_epi8
        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
            return unsafe { gather_u32index_u8_avx512(indices, base, scale) };
        }
    }

    gather_u32index_u8_scalar(indices, base, scale)
}

/// Masked gather of 16 bytes from memory using u32 indices.
///
/// For lanes where the mask bit is 1: `result[i] = base[indices[i] * scale]`
/// For lanes where the mask bit is 0: `result[i] = fallback[i] as u8`
///
/// # Arguments
/// * `indices` - Vector of 16 u32 indices
/// * `base` - Base slice to gather from
/// * `scale` - Scale factor applied to each index (1, 2, 4, or 8)
/// * `mask` - 16-bit mask indicating which lanes to gather (bit i = 1 gathers lane i; 0 uses the fallback)
/// * `fallback` - Fallback values (low byte used) for masked-off lanes
///
/// # Safety
/// For lanes where the mask bit is 1, the caller must ensure that
/// `indices[i] * scale < base.len()`. As with [`gather_u32index_u8`], the AVX-512
/// path gathers a full 32-bit word per lane, so the three bytes after each
/// addressed byte must also be readable. Out-of-bounds access is undefined behavior.
#[inline]
pub fn gather_masked_u32index_u8(
    indices: u32x16,
    base: &[u8],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u8x16 {
    #[cfg(target_arch = "x86_64")]
    {
        // Requires AVX512BW for _mm512_cvtepi32_epi8
        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
            return unsafe { gather_masked_u32index_u8_avx512(indices, base, scale, mask, fallback) };
        }
    }

    gather_masked_u32index_u8_scalar(indices, base, scale, mask, fallback)
}

/// Gather 16 u32 values from memory using u32 indices.
///
/// Computes: `result[i] = base[indices[i] * scale / 4]` for each lane.
///
/// # Arguments
/// * `indices` - Vector of 16 u32 indices
/// * `base` - Base slice to gather from
/// * `scale` - Scale factor applied to each index, in bytes (1, 2, 4, or 8; 4 = native u32 indexing)
///
/// # Safety
/// The caller must ensure that each computed byte offset `indices[i] * scale`, plus
/// the four gathered bytes, stays within `base`
/// (`indices[i] * scale + 4 <= base.len() * 4` is a sufficient condition).
/// Out-of-bounds access is undefined behavior.
#[inline]
pub fn gather_u32index_u32(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            return unsafe { gather_u32index_u32_avx512(indices, base, scale) };
        }
    }

    gather_u32index_u32_scalar(indices, base, scale)
}

/// Masked gather of 16 u32 values from memory using u32 indices.
///
/// For lanes where the mask bit is 1: `result[i] = base[indices[i] * scale / 4]`
/// For lanes where the mask bit is 0: `result[i] = fallback[i]`
///
/// # Arguments
/// * `indices` - Vector of 16 u32 indices
/// * `base` - Base slice to gather from
/// * `scale` - Scale factor applied to each index, in bytes (1, 2, 4, or 8; 4 = native u32 indexing)
/// * `mask` - 16-bit mask indicating which lanes to gather (bit i = 1 gathers lane i; 0 uses the fallback)
/// * `fallback` - Fallback values for masked-off lanes
///
/// # Safety
/// For lanes where the mask bit is 1, the caller must ensure that the computed
/// byte offset stays within `base` (`indices[i] * scale + 4 <= base.len() * 4`
/// is a sufficient condition). Out-of-bounds access is undefined behavior.
#[inline]
pub fn gather_masked_u32index_u32(
    indices: u32x16,
    base: &[u32],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u32x16 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            return unsafe { gather_masked_u32index_u32_avx512(indices, base, scale, mask, fallback) };
        }
    }

    gather_masked_u32index_u32_scalar(indices, base, scale, mask, fallback)
}

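// A possible AVX2 path for the u32 gather is sketched below for reference. It is
// NOT wired into the dispatchers above, and `gather_u32index_u32_avx2` is a
// hypothetical name. AVX2's `VPGATHERDD` (`_mm256_i32gather_epi32`) handles only
// 8 lanes, so the 16 lanes are processed as two 256-bit halves. Plain arrays are
// used instead of `wide` types to keep the sketch self-contained, and only
// scale 4 (native u32 indexing) is shown.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
#[target_feature(enable = "avx2")]
unsafe fn gather_u32index_u32_avx2(indices: [u32; 16], base: &[u32]) -> [u32; 16] {
    use std::arch::x86_64::{
        __m256i, _mm256_i32gather_epi32, _mm256_loadu_si256, _mm256_storeu_si256,
    };
    unsafe {
        // Load the two 8-lane index halves.
        let lo = _mm256_loadu_si256(indices.as_ptr() as *const __m256i);
        let hi = _mm256_loadu_si256(indices.as_ptr().add(8) as *const __m256i);
        // Gather 8 u32 values per half (const parameter 4 = byte scale per index).
        let g_lo = _mm256_i32gather_epi32::<4>(base.as_ptr() as *const i32, lo);
        let g_hi = _mm256_i32gather_epi32::<4>(base.as_ptr() as *const i32, hi);
        let mut out = [0u32; 16];
        _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, g_lo);
        _mm256_storeu_si256(out.as_mut_ptr().add(8) as *mut __m256i, g_hi);
        out
    }
}
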
// =============================================================================
// x86_64 AVX512 Implementations
// =============================================================================

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "x86_64")]
use std::arch::is_x86_feature_detected;

/// AVX512 implementation of gather_u32index_u8
///
/// Uses `_mm512_i32gather_epi32` to gather 32-bit words, then extracts low bytes
/// using `_mm512_cvtepi32_epi8`.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn gather_u32index_u8_avx512(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
    unsafe {
        let idx = std::mem::transmute::<u32x16, __m512i>(indices);

        // Gather 32-bit values (we only care about the low byte of each)
        let gathered = match scale {
            1 => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
            2 => _mm512_i32gather_epi32::<2>(idx, base.as_ptr() as *const i32),
            4 => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
            8 => _mm512_i32gather_epi32::<8>(idx, base.as_ptr() as *const i32),
            // Unsupported scale values default to 1 (documented scales are 1, 2, 4, 8)
            _ => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
        };

        // Extract low byte from each 32-bit lane
        extract_low_bytes_avx512(gathered)
    }
}

/// AVX512 implementation of gather_masked_u32index_u8
///
/// Uses `_mm512_cvtepi32_epi8` to extract low bytes from gathered 32-bit values.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn gather_masked_u32index_u8_avx512(
    indices: u32x16,
    base: &[u8],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u8x16 {
    unsafe {
        let idx = std::mem::transmute::<u32x16, __m512i>(indices);
        let src = std::mem::transmute::<u32x16, __m512i>(fallback);

        // Masked gather: where mask bit is 1, gather from memory; where 0, use src
        let gathered = match scale {
            1 => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
            2 => _mm512_mask_i32gather_epi32::<2>(src, mask, idx, base.as_ptr() as *const i32),
            4 => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
            8 => _mm512_mask_i32gather_epi32::<8>(src, mask, idx, base.as_ptr() as *const i32),
            // Unsupported scale values default to 1 (documented scales are 1, 2, 4, 8)
            _ => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
        };

        // Extract low byte from each 32-bit lane
        extract_low_bytes_avx512(gathered)
    }
}

/// AVX512 implementation of gather_u32index_u32
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn gather_u32index_u32_avx512(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
    unsafe {
        let idx = std::mem::transmute::<u32x16, __m512i>(indices);

        let gathered = match scale {
            1 => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
            2 => _mm512_i32gather_epi32::<2>(idx, base.as_ptr() as *const i32),
            4 => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
            8 => _mm512_i32gather_epi32::<8>(idx, base.as_ptr() as *const i32),
            // Unsupported scale values default to 4 (native u32 indexing)
            _ => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
        };

        std::mem::transmute::<__m512i, u32x16>(gathered)
    }
}

/// AVX512 implementation of gather_masked_u32index_u32
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn gather_masked_u32index_u32_avx512(
    indices: u32x16,
    base: &[u32],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u32x16 {
    unsafe {
        let idx = std::mem::transmute::<u32x16, __m512i>(indices);
        let src = std::mem::transmute::<u32x16, __m512i>(fallback);

        let gathered = match scale {
            1 => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
            2 => _mm512_mask_i32gather_epi32::<2>(src, mask, idx, base.as_ptr() as *const i32),
            4 => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
            8 => _mm512_mask_i32gather_epi32::<8>(src, mask, idx, base.as_ptr() as *const i32),
            // Unsupported scale values default to 4 (native u32 indexing)
            _ => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
        };

        std::mem::transmute::<__m512i, u32x16>(gathered)
    }
}

/// Extract the low byte from each 32-bit lane of a 512-bit vector.
///
/// Uses `_mm512_cvtepi32_epi8` which truncates each 32-bit element to 8 bits
/// and packs all 16 results into a 128-bit vector (u8x16).
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn extract_low_bytes_avx512(gathered: __m512i) -> u8x16 {
    unsafe {
        // Truncate each 32-bit lane to 8 bits, pack into __m128i
        let packed = _mm512_cvtepi32_epi8(gathered);
        std::mem::transmute::<__m128i, u8x16>(packed)
    }
}

// =============================================================================
// Scalar Fallback Implementations
// =============================================================================

/// Scalar fallback for gather_u32index_u8
#[inline]
fn gather_u32index_u8_scalar(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
    let idx_arr = indices.to_array();
    let scale = scale as usize;
    let mut result = [0u8; 16];

    for i in 0..16 {
        let offset = idx_arr[i] as usize * scale;
        result[i] = base[offset];
    }

    u8x16::from(result)
}

/// Scalar fallback for gather_masked_u32index_u8
#[inline]
fn gather_masked_u32index_u8_scalar(
    indices: u32x16,
    base: &[u8],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u8x16 {
    let idx_arr = indices.to_array();
    let fallback_arr = fallback.to_array();
    let scale = scale as usize;
    let mut result = [0u8; 16];

    for i in 0..16 {
        if (mask >> i) & 1 != 0 {
            let offset = idx_arr[i] as usize * scale;
            result[i] = base[offset];
        } else {
            result[i] = fallback_arr[i] as u8;
        }
    }

    u8x16::from(result)
}

/// Scalar fallback for gather_u32index_u32
#[inline]
fn gather_u32index_u32_scalar(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
    let idx_arr = indices.to_array();
    let scale = scale as usize;
    let mut result = [0u32; 16];

    for i in 0..16 {
        // Scale is in bytes, so divide by 4 for u32 indexing
        let offset = (idx_arr[i] as usize * scale) / 4;
        result[i] = base[offset];
    }

    u32x16::from(result)
}

/// Scalar fallback for gather_masked_u32index_u32
#[inline]
fn gather_masked_u32index_u32_scalar(
    indices: u32x16,
    base: &[u32],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u32x16 {
    let idx_arr = indices.to_array();
    let fallback_arr = fallback.to_array();
    let scale = scale as usize;
    let mut result = [0u32; 16];

    for i in 0..16 {
        if (mask >> i) & 1 != 0 {
            // Scale is in bytes, so divide by 4 for u32 indexing
            let offset = (idx_arr[i] as usize * scale) / 4;
            result[i] = base[offset];
        } else {
            result[i] = fallback_arr[i];
        }
    }

    u32x16::from(result)
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gather_u32index_u8_basic() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);

        let result = gather_u32index_u8(indices, &data, 1);
        assert_eq!(
            result.to_array(),
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        );
    }

    #[test]
    fn test_gather_u32index_u8_scaled() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        // With scale=2, indices are multiplied by 2
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);

        let result = gather_u32index_u8(indices, &data, 2);
        assert_eq!(
            result.to_array(),
            [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
        );
    }

    #[test]
    fn test_gather_u32index_u8_non_sequential() {
        // 300 bytes (rather than 256) so the AVX-512 path's 32-bit loads stay in
        // bounds when gathering at index 255.
        let data: Vec<u8> = (0..300).map(|i| i as u8).collect();
        let indices = u32x16::from([100, 50, 200, 25, 150, 75, 225, 10, 0, 255, 128, 64, 192, 32, 96, 160]);

        let result = gather_u32index_u8(indices, &data, 1);
        assert_eq!(
            result.to_array(),
            [100, 50, 200, 25, 150, 75, 225, 10, 0, 255, 128, 64, 192, 32, 96, 160]
        );
    }

    #[test]
    fn test_gather_masked_u32index_u8() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let fallback = u32x16::from([255; 16]);
        // Mask: only gather even lanes (bits 0, 2, 4, 6, 8, 10, 12, 14)
        let mask = 0b0101010101010101u16;

        let result = gather_masked_u32index_u8(indices, &data, 1, mask, fallback);
        assert_eq!(
            result.to_array(),
            [0, 255, 2, 255, 4, 255, 6, 255, 8, 255, 10, 255, 12, 255, 14, 255]
        );
    }

    #[test]
    fn test_gather_masked_u32index_u8_all_masked() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let fallback = u32x16::from([42; 16]);
        let mask = 0u16; // No lanes active

        let result = gather_masked_u32index_u8(indices, &data, 1, mask, fallback);
        assert_eq!(result.to_array(), [42; 16]);
    }

    #[test]
    fn test_gather_u32index_u32_basic() {
        let data: Vec<u32> = (0..256).map(|i| i as u32 * 1000).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);

        // Scale 4 means each index addresses a u32 directly
        let result = gather_u32index_u32(indices, &data, 4);
        assert_eq!(
            result.to_array(),
            [0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000]
        );
    }

    #[test]
    fn test_gather_masked_u32index_u32() {
        let data: Vec<u32> = (0..256).map(|i| i as u32 * 100).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let fallback = u32x16::from([999; 16]);
        // Mask: only odd lanes (bits 1, 3, 5, 7, 9, 11, 13, 15)
        let mask = 0b1010101010101010u16;

        let result = gather_masked_u32index_u32(indices, &data, 4, mask, fallback);
        assert_eq!(
            result.to_array(),
            [999, 100, 999, 300, 999, 500, 999, 700, 999, 900, 999, 1100, 999, 1300, 999, 1500]
        );
    }

    #[test]
    fn test_gather_u32index_u32_non_sequential() {
        let data: Vec<u32> = (0..256).map(|i| i as u32).collect();
        let indices = u32x16::from([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);

        let result = gather_u32index_u32(indices, &data, 4);
        assert_eq!(
            result.to_array(),
            [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
        );
    }
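
    /// Per the module docs, masked-off lanes never need valid indices: the scalar
    /// fallback skips them entirely, and AVX-512 masked gathers architecturally
    /// suppress faults for mask-bit-0 lanes. This test leans on that documented
    /// behavior by placing far out-of-bounds indices in the masked-off lanes.
    #[test]
    fn test_gather_masked_u32index_u8_oob_indices_in_masked_lanes() {
        let data: Vec<u8> = (0..64).map(|i| i as u8).collect();
        // Even lanes are valid; odd lanes point far out of bounds but are masked off.
        let indices = u32x16::from([
            0, 1_000_000, 2, 1_000_000, 4, 1_000_000, 6, 1_000_000,
            8, 1_000_000, 10, 1_000_000, 12, 1_000_000, 14, 1_000_000,
        ]);
        let fallback = u32x16::from([7; 16]);
        let mask = 0b0101_0101_0101_0101u16; // gather even lanes only

        let result = gather_masked_u32index_u8(indices, &data, 1, mask, fallback);
        assert_eq!(
            result.to_array(),
            [0, 7, 2, 7, 4, 7, 6, 7, 8, 7, 10, 7, 12, 7, 14, 7]
        );
    }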
}