// simd_lookup/simd_gather.rs
1//! SIMD gather operations for efficient indexed memory access
2//!
3//! This module provides vectorized gather functions that load multiple values from
4//! memory using SIMD indices. On AVX-512, these use hardware gather instructions
5//! (`_mm512_i32gather_epi32`, `_mm512_mask_i32gather_epi32`); on other platforms,
6//! they fall back to scalar loops.
7//!
8//! # CPU Feature Requirements (Intel x86_64)
9//!
10//! ## Optimal Performance (AVX-512)
11//!
12//! - **`gather_u32index_u8` / `gather_masked_u32index_u8`**: Requires **AVX512F** + **AVX512BW**
13//! - Uses `VGATHERDPS` (`_mm512_i32gather_epi32`) + `VPMOVDB` (`_mm512_cvtepi32_epi8`)
14//! - Available on: Intel Skylake-X (Xeon), Ice Lake, Tiger Lake, and later
15//! - Fallback: Scalar loop (works on all architectures)
16//!
17//! - **`gather_u32index_u32` / `gather_masked_u32index_u32`**: Requires **AVX512F**
18//! - Uses `VGATHERDPS` (`_mm512_i32gather_epi32`)
19//! - Available on: Intel Skylake-X (Xeon), Ice Lake, Tiger Lake, and later
20//! - Fallback: Scalar loop (works on all architectures)
21//!
22//! ## Fallback Behavior
23//!
24//! All functions automatically fall back to scalar implementations when AVX-512
25//! features are not available. The fallback implementations work on:
26//! - x86_64 without AVX-512 (uses AVX2 gather if available, or scalar)
27//! - aarch64 (ARM NEON) - scalar fallback
28//! - All other architectures (scalar fallback)
29//!
30//! # Functions
31//!
32//! - [`gather_u32index_u8`] - Gather 16 bytes using u32 indices
33//! - [`gather_masked_u32index_u8`] - Masked gather of bytes with fallback values
34//! - [`gather_u32index_u32`] - Gather 16 u32 values using u32 indices
35//! - [`gather_masked_u32index_u32`] - Masked gather of u32 values with fallback
36//!
37//! # Important: Masked Gather Behavior on Intel
38//!
39//! When using masked gather functions, be aware of two distinct behaviors:
40//!
41//! ## 1. Architectural Fault Suppression (AVX-512)
42//!
43//! AVX-512 masked gathers are *architecturally* designed to **suppress page faults**
44//! for masked-off elements. If a masked element (mask bit = 0) points to an invalid
//! address, it will NOT cause a page fault. This is documented in the Intel® 64 and
46//! IA-32 Architectures SDM, Vol. 1, Section 15.6.4.
47//!
48//! This means masked gathers are safe to use when some indices may be invalid, as long
49//! as those lanes are masked off.
50//!
51//! ## 2. Speculative Memory Access (Performance Reality)
52//!
53//! Despite the mask, the hardware may still **speculatively access all memory locations**
54//! regardless of mask state. This was the root cause of the Gather Data Sampling (GDS)
55//! vulnerability (CVE-2022-40982).
56//!
57//! From Intel's GDS documentation:
58//! > "When a gather instruction performs loads from memory, different data elements are
59//! > merged into the destination vector register according to the mask specified. In some
60//! > situations, due to hardware optimizations specific to gather instructions, stale data
61//! > from previous usage of architectural or internal vector registers may get transiently
62//! > forwarded to dependent instructions without being updated by the gather loads."
63//!
64//! **Practical implications:**
65//! - The mask **does NOT reduce memory bandwidth** - all lanes likely issue loads
66//! - The mask **does NOT skip cache misses** on masked lanes
67//! - Post-GDS microcode updates add latency but fix the speculation issue
68//!
69//! ## Architecture Comparison
70//!
71//! | Feature | AVX2 Gather | AVX-512 Gather |
72//! |---------------------------------|----------------|--------------------------|
73//! | Masked fault suppression | Limited/None | Architecturally guaranteed |
74//! | Speculative access (pre-GDS) | Yes | Yes |
75//! | Post-GDS microcode | N/A | Adds latency, fixes spec |
76//!
77//! ## When to Use Masked Gathers
78//!
79//! **Good use cases:**
80//! - Conditional semantics (keeping fallback values for some lanes)
81//! - Fault suppression (safe to have invalid pointers in masked lanes on AVX-512)
82//! - Avoiding branching in vectorized code
83//!
84//! **NOT useful for:**
85//! - Reducing memory bandwidth (all locations still accessed)
86//! - Skipping expensive cache misses on masked lanes
87//! - Performance gains from partial masking
88//!
89//! # References
90//!
//! - Intel® 64 and IA-32 Architectures SDM, Vol. 1, Section 15.6.4 (AVX-512 Masking)
92//! - [Intel Gather Data Sampling (GDS) Documentation](https://www.intel.com/content/www/us/en/developer/articles/technical/software-security-guidance/technical-documentation/gather-data-sampling.html)
93//!
94//! # Example
95//!
96//! ```rust
97//! use wide::u32x16;
98//! use simd_lookup::simd_gather::gather_u32index_u8;
99//!
100//! let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
101//! let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
102//! let result = gather_u32index_u8(indices, &data, 1);
103//! assert_eq!(result.to_array(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
104//! ```
105
106use wide::{u8x16, u32x16};
107
108// =============================================================================
109// Public API
110// =============================================================================
111
112/// Gather 16 bytes from memory using u32 indices.
113///
114/// Computes: `result[i] = base[indices[i] * scale]` for each lane.
115///
116/// # Arguments
117/// * `indices` - Vector of 16 u32 indices
118/// * `base` - Base slice to gather from
119/// * `scale` - Scale factor applied to each index (1, 2, 4, or 8)
120///
121/// # Safety
122/// The caller must ensure that `indices[i] * scale < base.len()` for all lanes.
123/// Out-of-bounds access is undefined behavior.
124#[inline]
125pub fn gather_u32index_u8(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
126 #[cfg(target_arch = "x86_64")]
127 {
128 // Requires AVX512BW for _mm512_cvtepi32_epi8
129 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
130 return unsafe { gather_u32index_u8_avx512(indices, base, scale) };
131 }
132 }
133
134 gather_u32index_u8_scalar(indices, base, scale)
135}
136
137/// Masked gather of 16 bytes from memory using u32 indices.
138///
139/// For lanes where mask bit is 1: `result[i] = base[indices[i] * scale]`
140/// For lanes where mask bit is 0: `result[i] = fallback[i] as u8`
141///
142/// # Arguments
143/// * `indices` - Vector of 16 u32 indices
144/// * `base` - Base slice to gather from
145/// * `scale` - Scale factor applied to each index (1, 2, 4, or 8)
146/// * `mask` - 16-bit mask indicating which lanes to gather (1 = gather, 0 = use fallback)
147/// * `fallback` - Fallback values (low byte used) for masked-off lanes
148///
149/// # Safety
150/// For lanes where the mask bit is 1, the caller must ensure that
151/// `indices[i] * scale < base.len()`. Out-of-bounds access is undefined behavior.
152#[inline]
153pub fn gather_masked_u32index_u8(
154 indices: u32x16,
155 base: &[u8],
156 scale: u8,
157 mask: u16,
158 fallback: u32x16,
159) -> u8x16 {
160 #[cfg(target_arch = "x86_64")]
161 {
162 // Requires AVX512BW for _mm512_cvtepi32_epi8
163 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
164 return unsafe { gather_masked_u32index_u8_avx512(indices, base, scale, mask, fallback) };
165 }
166 }
167
168 gather_masked_u32index_u8_scalar(indices, base, scale, mask, fallback)
169}
170
171/// Gather 16 u32 values from memory using u32 indices.
172///
173/// Computes: `result[i] = base[indices[i] * scale / 4]` for each lane.
174///
175/// # Arguments
176/// * `indices` - Vector of 16 u32 indices
177/// * `base` - Base slice to gather from
178/// * `scale` - Scale factor applied to each index (1, 2, 4, or 8)
179///
180/// # Safety
181/// The caller must ensure that the computed byte offset is valid.
182/// Out-of-bounds access is undefined behavior.
183#[inline]
184pub fn gather_u32index_u32(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
185 #[cfg(target_arch = "x86_64")]
186 {
187 if is_x86_feature_detected!("avx512f") {
188 return unsafe { gather_u32index_u32_avx512(indices, base, scale) };
189 }
190 }
191
192 gather_u32index_u32_scalar(indices, base, scale)
193}
194
195/// Masked gather of 16 u32 values from memory using u32 indices.
196///
197/// For lanes where mask bit is 1: `result[i] = base[indices[i] * scale / 4]`
198/// For lanes where mask bit is 0: `result[i] = fallback[i]`
199///
200/// # Arguments
201/// * `indices` - Vector of 16 u32 indices
202/// * `base` - Base slice to gather from
203/// * `scale` - Scale factor applied to each index (1, 2, 4, or 8)
204/// * `mask` - 16-bit mask indicating which lanes to gather (1 = gather, 0 = use fallback)
205/// * `fallback` - Fallback values for masked-off lanes
206///
207/// # Safety
208/// For lanes where the mask bit is 1, the caller must ensure that
209/// the computed byte offset is valid. Out-of-bounds access is undefined behavior.
210#[inline]
211pub fn gather_masked_u32index_u32(
212 indices: u32x16,
213 base: &[u32],
214 scale: u8,
215 mask: u16,
216 fallback: u32x16,
217) -> u32x16 {
218 #[cfg(target_arch = "x86_64")]
219 {
220 if is_x86_feature_detected!("avx512f") {
221 return unsafe { gather_masked_u32index_u32_avx512(indices, base, scale, mask, fallback) };
222 }
223 }
224
225 gather_masked_u32index_u32_scalar(indices, base, scale, mask, fallback)
226}
227
228// =============================================================================
229// x86_64 AVX512 Implementations
230// =============================================================================
231
232#[cfg(target_arch = "x86_64")]
233use std::arch::x86_64::*;
234
235#[cfg(target_arch = "x86_64")]
236use std::arch::is_x86_feature_detected;
237
/// AVX512 implementation of gather_u32index_u8
///
/// Uses `_mm512_i32gather_epi32` to gather 32-bit words, then extracts low bytes
/// using `_mm512_cvtepi32_epi8`.
///
/// NOTE(review): each lane loads a full 32-bit word at byte offset
/// `indices[i] * scale`, so up to 3 bytes beyond the addressed byte are read.
/// Callers must guarantee `indices[i] * scale + 4 <= base.len()` — stricter
/// than the scalar fallback, which reads exactly one byte per lane.
///
/// # Safety
/// - AVX512F and AVX512BW must be available (checked by the public wrapper).
/// - Every 4-byte word load described above must stay inside `base`.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn gather_u32index_u8_avx512(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
    unsafe {
        // SAFETY: `wide::u32x16` and `__m512i` are both 64-byte vectors;
        // assumed bit-compatible for reinterpretation — TODO confirm against
        // the `wide` crate's repr guarantees.
        let idx = std::mem::transmute::<u32x16, __m512i>(indices);

        // Gather 32-bit values (we only care about the low byte of each).
        // The intrinsic takes the scale as a const generic, hence the match;
        // an unsupported scale silently falls back to 1.
        let gathered = match scale {
            1 => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
            2 => _mm512_i32gather_epi32::<2>(idx, base.as_ptr() as *const i32),
            4 => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
            8 => _mm512_i32gather_epi32::<8>(idx, base.as_ptr() as *const i32),
            _ => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
        };

        // Extract low byte from each 32-bit lane; on little-endian x86 the low
        // byte of each gathered word is the byte at the addressed offset.
        extract_low_bytes_avx512(gathered)
    }
}
262
/// AVX512 implementation of gather_masked_u32index_u8
///
/// Uses `_mm512_cvtepi32_epi8` to extract low bytes from gathered 32-bit values.
///
/// NOTE(review): each active lane loads a full 32-bit word at byte offset
/// `indices[i] * scale`, so active lanes must satisfy
/// `indices[i] * scale + 4 <= base.len()`. Masked-off lanes keep `src`
/// (fault suppression for them is architectural on AVX-512; see module docs).
///
/// # Safety
/// - AVX512F and AVX512BW must be available (checked by the public wrapper).
/// - For every lane whose mask bit is 1, the 4-byte word load above must
///   stay inside `base`.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn gather_masked_u32index_u8_avx512(
    indices: u32x16,
    base: &[u8],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u8x16 {
    unsafe {
        // SAFETY: `wide::u32x16` and `__m512i` are both 64-byte vectors;
        // assumed bit-compatible for reinterpretation — TODO confirm against
        // the `wide` crate's repr guarantees.
        let idx = std::mem::transmute::<u32x16, __m512i>(indices);
        let src = std::mem::transmute::<u32x16, __m512i>(fallback);

        // Masked gather: where mask bit is 1, gather from memory; where 0, use src.
        // The intrinsic takes the scale as a const generic, hence the match;
        // an unsupported scale silently falls back to 1.
        let gathered = match scale {
            1 => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
            2 => _mm512_mask_i32gather_epi32::<2>(src, mask, idx, base.as_ptr() as *const i32),
            4 => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
            8 => _mm512_mask_i32gather_epi32::<8>(src, mask, idx, base.as_ptr() as *const i32),
            _ => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
        };

        // Extract low byte from each 32-bit lane
        extract_low_bytes_avx512(gathered)
    }
}
293
/// AVX512 implementation of gather_u32index_u32
///
/// NOTE(review): the gather is byte-addressed — each lane loads 4 bytes at
/// byte offset `indices[i] * scale`, possibly unaligned. The scalar fallback
/// instead indexes `base[(indices[i] * scale) / 4]`, so the two paths only
/// agree when `indices[i] * scale` is a multiple of 4.
///
/// # Safety
/// - AVX512F must be available (checked by the public wrapper).
/// - Every 4-byte load at byte offset `indices[i] * scale` must stay inside
///   `base`.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn gather_u32index_u32_avx512(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
    unsafe {
        // SAFETY: `wide::u32x16` and `__m512i` are both 64-byte vectors;
        // assumed bit-compatible for reinterpretation — TODO confirm against
        // the `wide` crate's repr guarantees.
        let idx = std::mem::transmute::<u32x16, __m512i>(indices);

        // The intrinsic takes the scale as a const generic, hence the match;
        // an unsupported scale silently falls back to 4 (note: the u8 variant
        // defaults to 1 instead — the two defaults are inconsistent).
        let gathered = match scale {
            1 => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
            2 => _mm512_i32gather_epi32::<2>(idx, base.as_ptr() as *const i32),
            4 => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
            8 => _mm512_i32gather_epi32::<8>(idx, base.as_ptr() as *const i32),
            _ => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
        };

        std::mem::transmute::<__m512i, u32x16>(gathered)
    }
}
313
/// AVX512 implementation of gather_masked_u32index_u32
///
/// NOTE(review): the gather is byte-addressed — each active lane loads 4
/// bytes at byte offset `indices[i] * scale`, possibly unaligned, while the
/// scalar fallback indexes `base[(indices[i] * scale) / 4]`; the two paths
/// only agree when the byte offset is a multiple of 4. Masked-off lanes keep
/// `src` (fault suppression for them is architectural; see module docs).
///
/// # Safety
/// - AVX512F must be available (checked by the public wrapper).
/// - For every lane whose mask bit is 1, the 4-byte load at byte offset
///   `indices[i] * scale` must stay inside `base`.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn gather_masked_u32index_u32_avx512(
    indices: u32x16,
    base: &[u32],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u32x16 {
    unsafe {
        // SAFETY: `wide::u32x16` and `__m512i` are both 64-byte vectors;
        // assumed bit-compatible for reinterpretation — TODO confirm against
        // the `wide` crate's repr guarantees.
        let idx = std::mem::transmute::<u32x16, __m512i>(indices);
        let src = std::mem::transmute::<u32x16, __m512i>(fallback);

        // The intrinsic takes the scale as a const generic, hence the match;
        // an unsupported scale silently falls back to 4 (note: the u8 variant
        // defaults to 1 instead — the two defaults are inconsistent).
        let gathered = match scale {
            1 => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
            2 => _mm512_mask_i32gather_epi32::<2>(src, mask, idx, base.as_ptr() as *const i32),
            4 => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
            8 => _mm512_mask_i32gather_epi32::<8>(src, mask, idx, base.as_ptr() as *const i32),
            _ => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
        };

        std::mem::transmute::<__m512i, u32x16>(gathered)
    }
}
340
/// Extract the low byte from each 32-bit lane of a 512-bit vector.
///
/// Uses `_mm512_cvtepi32_epi8` which truncates each 32-bit element to 8 bits
/// and packs all 16 results into a 128-bit vector (u8x16).
///
/// # Safety
/// AVX512F and AVX512BW must be available (guaranteed by every caller in this
/// module, which only dispatches here after runtime feature detection).
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn extract_low_bytes_avx512(gathered: __m512i) -> u8x16 {
    unsafe {
        // Truncate each 32-bit lane to 8 bits, pack into __m128i
        let packed = _mm512_cvtepi32_epi8(gathered);
        // SAFETY: `__m128i` and `wide::u8x16` are both 16-byte vectors;
        // assumed bit-compatible — TODO confirm `wide`'s repr guarantees.
        std::mem::transmute::<__m128i, u8x16>(packed)
    }
}
355
356// =============================================================================
357// Scalar Fallback Implementations
358// =============================================================================
359
360/// Scalar fallback for gather_u32index_u8
361#[inline]
362fn gather_u32index_u8_scalar(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
363 let idx_arr = indices.to_array();
364 let scale = scale as usize;
365 let mut result = [0u8; 16];
366
367 for i in 0..16 {
368 let offset = idx_arr[i] as usize * scale;
369 result[i] = base[offset];
370 }
371
372 u8x16::from(result)
373}
374
375/// Scalar fallback for gather_masked_u32index_u8
376#[inline]
377fn gather_masked_u32index_u8_scalar(
378 indices: u32x16,
379 base: &[u8],
380 scale: u8,
381 mask: u16,
382 fallback: u32x16,
383) -> u8x16 {
384 let idx_arr = indices.to_array();
385 let fallback_arr = fallback.to_array();
386 let scale = scale as usize;
387 let mut result = [0u8; 16];
388
389 for i in 0..16 {
390 if (mask >> i) & 1 != 0 {
391 let offset = idx_arr[i] as usize * scale;
392 result[i] = base[offset];
393 } else {
394 result[i] = fallback_arr[i] as u8;
395 }
396 }
397
398 u8x16::from(result)
399}
400
401/// Scalar fallback for gather_u32index_u32
402#[inline]
403fn gather_u32index_u32_scalar(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
404 let idx_arr = indices.to_array();
405 let scale = scale as usize;
406 let mut result = [0u32; 16];
407
408 for i in 0..16 {
409 // Scale is in bytes, so divide by 4 for u32 indexing
410 let offset = (idx_arr[i] as usize * scale) / 4;
411 result[i] = base[offset];
412 }
413
414 u32x16::from(result)
415}
416
417/// Scalar fallback for gather_masked_u32index_u32
418#[inline]
419fn gather_masked_u32index_u32_scalar(
420 indices: u32x16,
421 base: &[u32],
422 scale: u8,
423 mask: u16,
424 fallback: u32x16,
425) -> u32x16 {
426 let idx_arr = indices.to_array();
427 let fallback_arr = fallback.to_array();
428 let scale = scale as usize;
429 let mut result = [0u32; 16];
430
431 for i in 0..16 {
432 if (mask >> i) & 1 != 0 {
433 // Scale is in bytes, so divide by 4 for u32 indexing
434 let offset = (idx_arr[i] as usize * scale) / 4;
435 result[i] = base[offset];
436 } else {
437 result[i] = fallback_arr[i];
438 }
439 }
440
441 u32x16::from(result)
442}
443
444// =============================================================================
445// Tests
446// =============================================================================
447
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gather_u32index_u8_basic() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);

        let result = gather_u32index_u8(indices, &data, 1);
        assert_eq!(
            result.to_array(),
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        );
    }

    #[test]
    fn test_gather_u32index_u8_scaled() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        // With scale=2, indices are multiplied by 2
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);

        let result = gather_u32index_u8(indices, &data, 2);
        assert_eq!(
            result.to_array(),
            [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
        );
    }

    #[test]
    fn test_gather_u32index_u8_non_sequential() {
        // Pad 3 trailing bytes: the AVX-512 path loads a full 32-bit word per
        // lane, so gathering index 255 from an exactly-256-byte slice would
        // read past the end of the allocation (undefined behavior).
        let data: Vec<u8> = (0..256).map(|i| i as u8).chain([0u8; 3]).collect();
        let indices = u32x16::from([100, 50, 200, 25, 150, 75, 225, 10, 0, 255, 128, 64, 192, 32, 96, 160]);

        let result = gather_u32index_u8(indices, &data, 1);
        assert_eq!(
            result.to_array(),
            [100, 50, 200, 25, 150, 75, 225, 10, 0, 255, 128, 64, 192, 32, 96, 160]
        );
    }

    #[test]
    fn test_gather_masked_u32index_u8() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let fallback = u32x16::from([255; 16]);
        // Mask: only gather even lanes (bits 0, 2, 4, 6, 8, 10, 12, 14)
        let mask = 0b0101010101010101u16;

        let result = gather_masked_u32index_u8(indices, &data, 1, mask, fallback);
        assert_eq!(
            result.to_array(),
            [0, 255, 2, 255, 4, 255, 6, 255, 8, 255, 10, 255, 12, 255, 14, 255]
        );
    }

    #[test]
    fn test_gather_masked_u32index_u8_all_masked() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let fallback = u32x16::from([42; 16]);
        let mask = 0u16; // No lanes active

        let result = gather_masked_u32index_u8(indices, &data, 1, mask, fallback);
        assert_eq!(result.to_array(), [42; 16]);
    }

    #[test]
    fn test_gather_u32index_u32_basic() {
        let data: Vec<u32> = (0..256).map(|i| i as u32 * 1000).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);

        // Scale 4 means each index addresses a u32 directly
        let result = gather_u32index_u32(indices, &data, 4);
        assert_eq!(
            result.to_array(),
            [0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000]
        );
    }

    #[test]
    fn test_gather_masked_u32index_u32() {
        let data: Vec<u32> = (0..256).map(|i| i as u32 * 100).collect();
        let indices = u32x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let fallback = u32x16::from([999; 16]);
        // Mask: only odd lanes (bits 1, 3, 5, 7, 9, 11, 13, 15)
        let mask = 0b1010101010101010u16;

        let result = gather_masked_u32index_u32(indices, &data, 4, mask, fallback);
        assert_eq!(
            result.to_array(),
            [999, 100, 999, 300, 999, 500, 999, 700, 999, 900, 999, 1100, 999, 1300, 999, 1500]
        );
    }

    #[test]
    fn test_gather_u32index_u32_non_sequential() {
        let data: Vec<u32> = (0..256).map(|i| i as u32).collect();
        let indices = u32x16::from([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);

        let result = gather_u32index_u32(indices, &data, 4);
        assert_eq!(
            result.to_array(),
            [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
        );
    }
}
555