// simd_lookup/wide_utils.rs

1//! SIMD utilities and trait extensions for the `wide` crate
2//!
3//! This module provides optimized platform-specific implementations for common SIMD operations
4//! that are not directly available in the `wide` crate, including:
5//! - Widening operations (u32x8 → u64x8)
6//! - Bitmask to vector conversion
7//! - Shuffle/permute operations
8//! - Vector splitting (high/low extraction)
9//! - Cross-platform optimizations for x86_64 and aarch64
10//!
11//! # Examples
12//!
13//! ```rust
14//! use simd_lookup::wide_utils::{WideUtilsExt, FromBitmask};
15//! use wide::{u32x8, u64x8};
16//!
17//! let input = u32x8::from([1, 2, 3, 4, 5, 6, 7, 8]);
18//! let widened: u64x8 = input.widen_to_u64x8();
19//!
20//! let mask = 0b10101010u8;
21//! let mask_vec: u64x8 = u64x8::from_bitmask(mask);
22//! ```
23
24use wide::{u8x16, u32x4, u32x8, u32x16, u64x4, u64x8};
25
26// =============================================================================
27// Shuffle Lookup Tables for Compress Operations
28// =============================================================================
29
/// Shuffle indices for compressing u32x8 based on an 8-bit mask.
/// Entry `m` holds, lowest bit first, the lane number of each set bit of `m`,
/// packed to the front; unused trailing slots are padded with 7 so they still
/// name a valid lane.
/// Stored as arrays; use `get_compress_indices_u32x8()` for zero-cost SIMD access.
pub static SHUFFLE_COMPRESS_IDX_U32X8: [[u32; 8]; 256] = {
    // Built at compile time; const contexts only allow `while`, not `for`.
    let mut entries = [[7u32; 8]; 256];
    let mut m = 0usize;
    while m < 256 {
        let mut write = 0usize;
        let mut lane = 0usize;
        while lane < 8 {
            if m & (1 << lane) != 0 {
                entries[m][write] = lane as u32;
                write += 1;
            }
            lane += 1;
        }
        m += 1;
    }
    entries
};
50
51/// Get compress indices as u32x8 (zero-cost transmute from array)
52#[inline(always)]
53pub fn get_compress_indices_u32x8(mask: u8) -> u32x8 {
54    // Safety: [u32; 8] has same size and alignment as u32x8
55    unsafe { std::mem::transmute(SHUFFLE_COMPRESS_IDX_U32X8[mask as usize]) }
56}
57
/// Legacy alias for backwards compatibility
/// (kept as a `&'static` reference so existing `SHUFFLE_COMPRESS_IDX_U32[mask]`
/// call sites continue to work unchanged).
pub static SHUFFLE_COMPRESS_IDX_U32: &[[u32; 8]; 256] = &SHUFFLE_COMPRESS_IDX_U32X8;
60
/// Shuffle indices for compressing u8x16 (low half).
/// Entry `m` packs the lane numbers of `m`'s set bits (lowest first) to the
/// front; unused trailing slots are 0.
pub static SHUFFLE_COMPRESS_IDX_U8_LO: [[u8; 8]; 256] = {
    let mut table = [[0u8; 8]; 256];
    let mut mask = 0usize;
    while mask < 256 {
        let mut dest_pos = 0usize;
        let mut src_pos = 0usize;
        while src_pos < 8 {
            if (mask >> src_pos) & 1 != 0 {
                table[mask][dest_pos] = src_pos as u8;
                dest_pos += 1;
            }
            src_pos += 1;
        }
        // The table is zero-initialized, so the explicit trailing zero-fill
        // pass that used to live here was dead code and has been removed.
        mask += 1;
    }
    table
};
83
/// High byte shuffle indices (with +8 offset baked in); unused trailing slots
/// hold the padding value 8.
pub static SHUFFLE_COMPRESS_IDX_U8_HI: [[u8; 8]; 256] = {
    // Start every slot at the padding value 8; the selection pass below then
    // only overwrites the chosen lanes, so no separate fill loop is needed.
    let mut entries = [[8u8; 8]; 256];
    let mut m = 0usize;
    while m < 256 {
        let mut write = 0usize;
        let mut lane = 0usize;
        while lane < 8 {
            if m & (1 << lane) != 0 {
                entries[m][write] = (lane + 8) as u8;
                write += 1;
            }
            lane += 1;
        }
        m += 1;
    }
    entries
};
106
107// =============================================================================
108// Main Trait: WideUtilsExt - All SIMD utility operations in one place
109// =============================================================================
110
/// Trait extension for `wide` SIMD types providing additional utility operations
pub trait WideUtilsExt: Sized {
    /// The output type for widening operations.
    /// NOTE: despite the method name below, this is `u64x4` for `u32x4`
    /// and `()` for `u8x16` (widening is not applicable there).
    type Widened;

    /// Widen (zero-extend) each lane to the larger element type.
    /// The name is historical — see the note on [`Self::Widened`].
    fn widen_to_u64x8(self) -> Self::Widened;

    /// Shuffle elements according to the given indices (indices are same SIMD type).
    /// `result[i] = self[indices[i]]`
    /// NOTE(review): out-of-range index handling differs per backend in this
    /// file (u32 scalar paths mask the index, u8 paths yield 0) — callers
    /// should only pass in-range indices; confirm before relying on OOB behavior.
    fn shuffle(self, indices: Self) -> Self;

    /// Double each element (self + self). Wraps on overflow.
    ///
    /// This is the most efficient way to multiply by 2 since addition is
    /// well-supported on all SIMD architectures (NEON `vaddq`, SSE `paddq`).
    ///
    /// For multiply by powers of 2, chain calls:
    /// - `x.double()` = x * 2
    /// - `x.double().double()` = x * 4
    /// - `x.double().double().double()` = x * 8
    #[inline(always)]
    fn double(self) -> Self
    where
        Self: std::ops::Add<Output = Self> + Copy,
    {
        self + self
    }
}
140
/// Trait for creating SIMD vectors from bitmasks
pub trait FromBitmask<T> {
    /// Create a SIMD vector from a bitmask where each bit becomes 0 or T::MAX.
    /// Bit `i` of `mask` controls lane `i`: set -> all-ones lane, clear -> zero lane.
    fn from_bitmask(mask: u8) -> Self;
}
146
/// Trait for splitting SIMD vectors into high/low halves efficiently
pub trait SimdSplit: Sized {
    /// The half-width type (e.g., u32x8 for u32x16)
    type Half;

    /// Split into (low, high) halves: `low` holds lanes 0..N/2,
    /// `high` holds lanes N/2..N. Uses platform intrinsics where available.
    fn split_low_high(self) -> (Self::Half, Self::Half);

    /// Extract the low half (lanes 0..N/2).
    #[inline(always)]
    fn low_half(self) -> Self::Half {
        self.split_low_high().0
    }

    /// Extract the high half (lanes N/2..N).
    #[inline(always)]
    fn high_half(self) -> Self::Half {
        self.split_low_high().1
    }
}
167
168
169// =============================================================================
170// SimdSplit Implementations
171// =============================================================================
172
impl SimdSplit for u32x16 {
    type Half = u32x8;

    // Runtime-dispatched: AVX-512 register split when available, otherwise a
    // portable rebuild. `is_x86_feature_detected!` caches after the first query.
    #[inline(always)]
    fn split_low_high(self) -> (u32x8, u32x8) {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") {
                return unsafe { split_u32x16_avx512(self) };
            }
        }

        // Portable fallback (compiles down to simple moves).
        split_u32x16_cast(self)
    }
}
189
impl SimdSplit for u64x8 {
    type Half = u64x4;

    // Same dispatch strategy as the u32x16 impl above.
    #[inline(always)]
    fn split_low_high(self) -> (u64x4, u64x4) {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") {
                return unsafe { split_u64x8_avx512(self) };
            }
        }

        // Portable fallback (compiles down to simple moves).
        split_u64x8_cast(self)
    }
}
206
207impl SimdSplit for u8x16 {
208    type Half = [u8; 8];
209
210    #[inline(always)]
211    fn split_low_high(self) -> ([u8; 8], [u8; 8]) {
212        let arr = self.to_array();
213        let mut lo = [0u8; 8];
214        let mut hi = [0u8; 8];
215        lo.copy_from_slice(&arr[0..8]);
216        hi.copy_from_slice(&arr[8..16]);
217        (lo, hi)
218    }
219}
220
221// =============================================================================
222// WideUtilsExt Implementations
223// =============================================================================
224
impl WideUtilsExt for u32x8 {
    type Widened = u64x8;

    // Zero-extend all 8 lanes to u64, picking the best backend at runtime.
    #[inline(always)]
    fn widen_to_u64x8(self) -> u64x8 {
        #[cfg(target_arch = "x86_64")]
        {
            // Tiered dispatch: AVX-512, then AVX2, then scalar. The feature
            // detection macros cache their result after the first call.
            if is_x86_feature_detected!("avx512f") {
                return unsafe { widen_u32x8_to_u64x8_avx512(self) };
            } else if is_x86_feature_detected!("avx2") {
                return unsafe { widen_u32x8_to_u64x8_avx2(self) };
            }
            return widen_u32x8_to_u64x8_scalar(self);
        }

        #[cfg(target_arch = "aarch64")]
        {
            // NEON is baseline on aarch64, so no runtime check is needed.
            return unsafe { widen_u32x8_to_u64x8_neon(self) };
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            widen_u32x8_to_u64x8_scalar(self)
        }
    }

    #[inline(always)]
    fn shuffle(self, indices: Self) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            // vpermd performs a full 8-lane cross shuffle in one instruction.
            if is_x86_feature_detected!("avx2") {
                return unsafe { shuffle_u32x8_avx2(self, indices) };
            }
            return shuffle_u32x8_scalar(self, indices);
        }

        // On ARM, try SVE first (has native permute), fall back to scalar
        // (NEON TBL requires byte-level index conversion which adds overhead)
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("sve") {
                return unsafe { shuffle_u32x8_sve(self, indices) };
            }
            return shuffle_u32x8_scalar(self, indices);
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            shuffle_u32x8_scalar(self, indices)
        }
    }
}
277
impl WideUtilsExt for u32x4 {
    // NOTE: despite the trait method's name, the widened type here is u64x4.
    type Widened = u64x4;

    #[inline(always)]
    fn widen_to_u64x8(self) -> u64x4 {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                return unsafe { widen_u32x4_to_u64x4_avx2(self) };
            }
            return widen_u32x4_to_u64x4_scalar(self);
        }

        #[cfg(target_arch = "aarch64")]
        {
            // NEON is baseline on aarch64, so no runtime check is needed.
            return unsafe { widen_u32x4_to_u64x4_neon(self) };
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            widen_u32x4_to_u64x4_scalar(self)
        }
    }

    #[inline(always)]
    fn shuffle(self, indices: Self) -> Self {
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("sve") {
                return unsafe { shuffle_u32x4_sve(self, indices) };
            }
        }
        // Scalar fallback - faster than NEON TBL for u32 (no byte-index conversion).
        // Also used on x86_64: there is no dedicated SSE/AVX path at this width.
        shuffle_u32x4_scalar(self, indices)
    }
}
314
impl WideUtilsExt for u8x16 {
    type Widened = (); // u8 doesn't widen in the same way

    // Intentional no-op: widening is not defined for u8x16, so the associated
    // Widened type is unit and this returns immediately.
    #[inline(always)]
    fn widen_to_u64x8(self) -> () {
        // Not applicable for u8x16
    }

    #[inline(always)]
    fn shuffle(self, indices: Self) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            // pshufb does a full 16-entry byte table lookup.
            if is_x86_feature_detected!("ssse3") {
                return unsafe { shuffle_u8x16_ssse3(self, indices) };
            }
            return shuffle_u8x16_scalar(self, indices);
        }

        #[cfg(target_arch = "aarch64")]
        {
            // SVE has native tbl, but NEON TBL is already optimal for u8x16
            return unsafe { shuffle_u8x16_neon(self, indices) };
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            shuffle_u8x16_scalar(self, indices)
        }
    }
}
345
346// =============================================================================
347// FromBitmask Implementations
348// =============================================================================
349
impl FromBitmask<u64> for u64x8 {
    // Expand `mask` into 8 u64 lanes: bit i set -> lane i = u64::MAX, else 0.
    #[inline(always)]
    fn from_bitmask(mask: u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            // The AVX-512 helper uses a full 512-bit masked broadcast, so
            // checking avx512f alone is sufficient at this width.
            if is_x86_feature_detected!("avx512f") {
                return unsafe { u64x8_from_bitmask_avx512(mask) };
            } else if is_x86_feature_detected!("avx2") {
                return unsafe { u64x8_from_bitmask_avx2(mask) };
            }
            return u64x8_from_bitmask_scalar(mask);
        }

        #[cfg(target_arch = "aarch64")]
        {
            return unsafe { u64x8_from_bitmask_neon(mask) };
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            u64x8_from_bitmask_scalar(mask)
        }
    }
}
374
375impl FromBitmask<u32> for u32x8 {
376    #[inline(always)]
377    fn from_bitmask(mask: u8) -> Self {
378        #[cfg(target_arch = "x86_64")]
379        {
380            if is_x86_feature_detected!("avx512f") {
381                return unsafe { u32x8_from_bitmask_avx512(mask) };
382            } else if is_x86_feature_detected!("avx2") {
383                return unsafe { u32x8_from_bitmask_avx2(mask) };
384            }
385            return u32x8_from_bitmask_scalar(mask);
386        }
387
388        #[cfg(target_arch = "aarch64")]
389        {
390            return unsafe { u32x8_from_bitmask_neon(mask) };
391        }
392
393        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
394        {
395            u32x8_from_bitmask_scalar(mask)
396        }
397    }
398}
399
400
401// =============================================================================
402// x86_64 Implementations
403// =============================================================================
404
405#[cfg(target_arch = "x86_64")]
406use std::arch::x86_64::*;
407
408#[cfg(target_arch = "x86_64")]
409use std::arch::is_x86_feature_detected;
410
/// Split a 512-bit vector into its two 256-bit halves.
///
/// # Safety
/// Caller must ensure AVX-512F is available.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn split_u32x16_avx512(input: u32x16) -> (u32x8, u32x8) {
    unsafe {
        // SAFETY: u32x16/__m512i and u32x8/__m256i are same-size vector types.
        let raw = std::mem::transmute::<u32x16, __m512i>(input);
        // Low half is a register-view cast; the high half needs an extract.
        let lo = _mm512_castsi512_si256(raw);
        let hi = _mm512_extracti64x4_epi64(raw, 1);
        (
            std::mem::transmute::<__m256i, u32x8>(lo),
            std::mem::transmute::<__m256i, u32x8>(hi),
        )
    }
}
425
/// Split a 512-bit vector into its two 256-bit halves (u64 flavor).
///
/// # Safety
/// Caller must ensure AVX-512F is available.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn split_u64x8_avx512(input: u64x8) -> (u64x4, u64x4) {
    unsafe {
        // SAFETY: u64x8/__m512i and u64x4/__m256i are same-size vector types.
        let raw = std::mem::transmute::<u64x8, __m512i>(input);
        // Low half is a register-view cast; the high half needs an extract.
        let lo = _mm512_castsi512_si256(raw);
        let hi = _mm512_extracti64x4_epi64(raw, 1);
        (
            std::mem::transmute::<__m256i, u64x4>(lo),
            std::mem::transmute::<__m256i, u64x4>(hi),
        )
    }
}
440
/// Zero-extend 8 u32 lanes to u64 in a single AVX-512 conversion.
///
/// # Safety
/// Caller must ensure AVX-512F is available.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn widen_u32x8_to_u64x8_avx512(input: u32x8) -> u64x8 {
    unsafe {
        let raw = std::mem::transmute::<u32x8, __m256i>(input);
        // cvtepu32 = unsigned (zero) extension of every lane at once.
        let widened = _mm512_cvtepu32_epi64(raw);
        std::mem::transmute::<__m512i, u64x8>(widened)
    }
}
451
/// Zero-extend 8 u32 lanes to u64 using two AVX2 conversions.
///
/// # Safety
/// Caller must ensure AVX2 is available.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn widen_u32x8_to_u64x8_avx2(input: u32x8) -> u64x8 {
    unsafe {
        let raw = std::mem::transmute::<u32x8, __m256i>(input);
        // Pull out the two 128-bit halves (imm 0 selects the low lane).
        let low = _mm256_extracti128_si256(raw, 0);
        let high = _mm256_extracti128_si256(raw, 1);
        // Zero-extend each 4-lane half to four u64 lanes.
        let low_wide = _mm256_cvtepu32_epi64(low);
        let high_wide = _mm256_cvtepu32_epi64(high);
        // Recombine through arrays; the optimizer should lower this to moves.
        let low_array: [u64; 4] = std::mem::transmute(low_wide);
        let high_array: [u64; 4] = std::mem::transmute(high_wide);
        u64x8::from([
            low_array[0], low_array[1], low_array[2], low_array[3],
            high_array[0], high_array[1], high_array[2], high_array[3],
        ])
    }
}
470
/// Zero-extend 4 u32 lanes to u64 in one AVX2 conversion.
///
/// # Safety
/// Caller must ensure AVX2 is available.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn widen_u32x4_to_u64x4_avx2(input: u32x4) -> u64x4 {
    unsafe {
        let raw = std::mem::transmute::<u32x4, __m128i>(input);
        // cvtepu32 = unsigned (zero) extension of every lane at once.
        let widened = _mm256_cvtepu32_epi64(raw);
        std::mem::transmute::<__m256i, u64x4>(widened)
    }
}
481
/// Expand `mask` with a 512-bit masked broadcast: set bits become all-ones lanes.
///
/// # Safety
/// Caller must ensure AVX-512F is available.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn u64x8_from_bitmask_avx512(mask: u8) -> u64x8 {
    unsafe {
        // maskz_set1: lanes whose mask bit is set get -1 (all ones); the rest
        // are zeroed.
        let vec = _mm512_maskz_set1_epi64(mask, -1i64);
        std::mem::transmute::<__m512i, u64x8>(vec)
    }
}
491
492#[cfg(target_arch = "x86_64")]
493#[inline]
494#[target_feature(enable = "avx2")]
495unsafe fn u64x8_from_bitmask_avx2(mask: u8) -> u64x8 {
496    let mut values = [0u64; 8];
497    for i in 0..8 {
498        values[i] = if (mask >> i) & 1 != 0 { u64::MAX } else { 0 };
499    }
500    u64x8::from(values)
501}
502
503#[cfg(target_arch = "x86_64")]
504#[inline]
505#[target_feature(enable = "avx512f")]
506unsafe fn u32x8_from_bitmask_avx512(mask: u8) -> u32x8 {
507    unsafe {
508        let vec = _mm256_maskz_set1_epi32(mask, -1i32);
509        std::mem::transmute::<__m256i, u32x8>(vec)
510    }
511}
512
513#[cfg(target_arch = "x86_64")]
514#[inline]
515#[target_feature(enable = "avx2")]
516unsafe fn u32x8_from_bitmask_avx2(mask: u8) -> u32x8 {
517    let mut values = [0u32; 8];
518    for i in 0..8 {
519        values[i] = if (mask >> i) & 1 != 0 { u32::MAX } else { 0 };
520    }
521    u32x8::from(values)
522}
523
/// Full 8-lane u32 cross-shuffle via AVX2 `vpermd`
/// (the hardware uses only the low 3 bits of each index).
///
/// # Safety
/// Caller must ensure AVX2 is available.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn shuffle_u32x8_avx2(input: u32x8, indices: u32x8) -> u32x8 {
    unsafe {
        let raw = std::mem::transmute::<u32x8, __m256i>(input);
        let idx = std::mem::transmute::<u32x8, __m256i>(indices);
        let shuffled = _mm256_permutevar8x32_epi32(raw, idx);
        std::mem::transmute::<__m256i, u32x8>(shuffled)
    }
}
535
536#[cfg(target_arch = "x86_64")]
537#[inline]
538#[target_feature(enable = "ssse3")]
539unsafe fn shuffle_u8x16_ssse3(input: u8x16, indices: u8x16) -> u8x16 {
540    unsafe {
541        let raw = std::mem::transmute::<u8x16, __m128i>(input);
542        let idx = std::mem::transmute::<u8x16, __m128i>(indices);
543        let shuffled = _mm_shuffle_epi8(raw, idx);
544        std::mem::transmute::<__m128i, u8x16>(shuffled)
545    }
546}
547
548// =============================================================================
549// ARM NEON Implementations
550// =============================================================================
551
552#[cfg(target_arch = "aarch64")]
553use std::arch::aarch64::*;
554
/// Zero-extend 8 u32 lanes to u64 by widening each 4-lane half with `vmovl`.
///
/// # Safety
/// Caller must ensure NEON is available (baseline on aarch64).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn widen_u32x8_to_u64x8_neon(input: u32x8) -> u64x8 {
    let array = input.to_array();
    unsafe {
        // Pointers reference the 8-element local array, so the two 4-lane
        // loads (offsets 0 and 4) are in bounds.
        let low_input = vld1q_u32(array.as_ptr());
        let high_input = vld1q_u32(array.as_ptr().add(4));
        let (low_0, low_1) = widen_u32x4_to_u64x4_neon_raw(low_input);
        let (high_0, high_1) = widen_u32x4_to_u64x4_neon_raw(high_input);
        // Store the four 2-lane results contiguously in lane order.
        let mut result = [0u64; 8];
        vst1q_u64(result.as_mut_ptr(), low_0);
        vst1q_u64(result.as_mut_ptr().add(2), low_1);
        vst1q_u64(result.as_mut_ptr().add(4), high_0);
        vst1q_u64(result.as_mut_ptr().add(6), high_1);
        u64x8::from(result)
    }
}
572
/// Zero-extend 4 u32 lanes to u64 using NEON `vmovl`.
///
/// # Safety
/// Caller must ensure NEON is available (baseline on aarch64).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn widen_u32x4_to_u64x4_neon(input: u32x4) -> u64x4 {
    let array = input.to_array();
    unsafe {
        let neon_input = vld1q_u32(array.as_ptr());
        let (low, high) = widen_u32x4_to_u64x4_neon_raw(neon_input);
        // Store the two 2-lane halves contiguously in lane order.
        let mut result = [0u64; 4];
        vst1q_u64(result.as_mut_ptr(), low);
        vst1q_u64(result.as_mut_ptr().add(2), high);
        u64x4::from(result)
    }
}
586
/// Core NEON widening step: `vmovl_u32` zero-extends two u32 lanes at a time,
/// producing (lanes 0-1, lanes 2-3) as two u64x2 registers.
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "neon")]
unsafe fn widen_u32x4_to_u64x4_neon_raw(input: uint32x4_t) -> (uint64x2_t, uint64x2_t) {
    let low = vmovl_u32(vget_low_u32(input));
    let high = vmovl_u32(vget_high_u32(input));
    (low, high)
}
595
/// NEON-optimized u64x8 from bitmask using parallel bit extraction.
/// Uses NEON compare to expand 8 bits into 8 u64 values (0 or u64::MAX).
///
/// # Safety
/// Caller must ensure NEON is available (baseline on aarch64).
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "neon")]
unsafe fn u64x8_from_bitmask_neon(mask: u8) -> u64x8 {
    unsafe {
        // Bit pattern for testing each bit position
        static BIT_PATTERN: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128];

        // Broadcast the mask to all 8 lanes
        let mask_vec = vdup_n_u8(mask);

        // Load the bit pattern
        let bits = vld1_u8(BIT_PATTERN.as_ptr());

        // AND mask with bit pattern, then compare to isolate each bit
        let anded = vand_u8(mask_vec, bits);
        let cmp = vceq_u8(anded, bits); // 0xFF where bit is set, 0x00 otherwise

        // Use SIGNED widening to sign-extend 0xFF (-1) to 0xFFFF, 0xFFFFFFFF, etc.
        // Reinterpret the u8 comparison result as signed i8
        let cmp_signed = vreinterpret_s8_u8(cmp);

        // Signed widen: i8 -> i16 -> i32 -> i64
        // -1 (0xFF) becomes -1 (0xFFFF), then -1 (0xFFFFFFFF), then -1 (0xFFFFFFFFFFFFFFFF)
        let wide16 = vmovl_s8(cmp_signed);

        let wide32_lo = vmovl_s16(vget_low_s16(wide16));
        let wide32_hi = vmovl_s16(vget_high_s16(wide16));

        let wide64_0 = vmovl_s32(vget_low_s32(wide32_lo));
        let wide64_1 = vmovl_s32(vget_high_s32(wide32_lo));
        let wide64_2 = vmovl_s32(vget_low_s32(wide32_hi));
        let wide64_3 = vmovl_s32(vget_high_s32(wide32_hi));

        // Reinterpret back to unsigned and store the four u64x2 chunks in
        // lane order (offsets 0, 2, 4, 6 of the 8-element output).
        let mut result = [0u64; 8];
        vst1q_u64(result.as_mut_ptr(), vreinterpretq_u64_s64(wide64_0));
        vst1q_u64(result.as_mut_ptr().add(2), vreinterpretq_u64_s64(wide64_1));
        vst1q_u64(result.as_mut_ptr().add(4), vreinterpretq_u64_s64(wide64_2));
        vst1q_u64(result.as_mut_ptr().add(6), vreinterpretq_u64_s64(wide64_3));

        u64x8::from(result)
    }
}
642
/// NEON-optimized u32x8 from bitmask using parallel bit extraction.
/// Uses NEON compare to expand 8 bits into 8 u32 values (0 or u32::MAX).
///
/// # Safety
/// Caller must ensure NEON is available (baseline on aarch64).
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "neon")]
unsafe fn u32x8_from_bitmask_neon(mask: u8) -> u32x8 {
    unsafe {
        // Bit pattern for testing each bit position
        static BIT_PATTERN: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128];

        // Broadcast the mask to all 8 lanes
        let mask_vec = vdup_n_u8(mask);

        // Load the bit pattern
        let bits = vld1_u8(BIT_PATTERN.as_ptr());

        // AND mask with bit pattern, then compare to isolate each bit
        let anded = vand_u8(mask_vec, bits);
        let cmp = vceq_u8(anded, bits); // 0xFF where bit is set, 0x00 otherwise

        // Use SIGNED widening to sign-extend 0xFF (-1) to 0xFFFFFFFF
        let cmp_signed = vreinterpret_s8_u8(cmp);

        // Signed widen: i8 -> i16 -> i32
        let wide16 = vmovl_s8(cmp_signed);
        let wide32_lo = vmovl_s16(vget_low_s16(wide16));
        let wide32_hi = vmovl_s16(vget_high_s16(wide16));

        // Reinterpret back to unsigned and store the two u32x4 halves in
        // lane order (offsets 0 and 4 of the 8-element output).
        let mut result = [0u32; 8];
        vst1q_u32(result.as_mut_ptr(), vreinterpretq_u32_s32(wide32_lo));
        vst1q_u32(result.as_mut_ptr().add(4), vreinterpretq_u32_s32(wide32_hi));

        u32x8::from(result)
    }
}
679
/// NEON u8x16 shuffle using the TBL instruction; out-of-range indices
/// (>= 16) yield 0, matching the scalar fallback.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn shuffle_u8x16_neon(input: u8x16, indices: u8x16) -> u8x16 {
    let src_bytes = input.to_array();
    let sel_bytes = indices.to_array();
    unsafe {
        // vqtbl1q_u8 performs a full 16-entry byte table lookup in one shot.
        let table = vld1q_u8(src_bytes.as_ptr());
        let selector = vld1q_u8(sel_bytes.as_ptr());
        let gathered = vqtbl1q_u8(table, selector);
        let mut out_bytes = [0u8; 16];
        vst1q_u8(out_bytes.as_mut_ptr(), gathered);
        u8x16::from(out_bytes)
    }
}
695
696// NOTE: NEON TBL-based u32x4/u32x8 shuffle implementations were removed.
697// While NEON TBL is fast for u8x16 shuffles, using it for u32 shuffles requires
698// converting u32 indices to byte indices (4 bytes per element) via a loop,
699// which adds significant overhead. Scalar shuffle is faster for u32 types on ARM
700// unless SVE is available.
701
702// =============================================================================
703// SVE Implementations (ARM Scalable Vector Extension)
704// =============================================================================
705
/// SVE u32x4 shuffle using the TBL instruction via inline assembly.
/// SVE TBL does element-wise table lookup with native u32 indices (no byte
/// conversion needed, unlike NEON TBL).
///
/// Indices are masked with `& 3` in-register so the result matches
/// `shuffle_u32x4_scalar` even for out-of-range indices (plain SVE TBL would
/// return 0 for them instead of wrapping).
///
/// # Safety
/// Caller must ensure SVE is available.
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "sve")]
unsafe fn shuffle_u32x4_sve(input: u32x4, indices: u32x4) -> u32x4 {
    use std::arch::asm;
    let data_arr = input.to_array();
    let idx_arr = indices.to_array();
    let mut out = [0u32; 4];

    // SAFETY: all pointers reference properly sized local arrays. Every
    // vector/predicate register the asm writes is declared as a clobber —
    // the previous version omitted them, letting the compiler assume
    // z0-z2/p0 were preserved across the asm block (unsound).
    unsafe {
        asm!(
            "ptrue p0.s, vl4",                // Predicate covering 4 lanes
            "ld1w {{z0.s}}, p0/z, [{data}]",  // Load data (inactive lanes zeroed)
            "ld1w {{z1.s}}, p0/z, [{idx}]",   // Load indices
            "and z1.s, z1.s, #3",             // Match scalar semantics: idx & 3
            "tbl z2.s, {{z0.s}}, z1.s",       // Table lookup permute
            "st1w {{z2.s}}, p0, [{dst}]",     // Store result
            data = in(reg) data_arr.as_ptr(),
            idx = in(reg) idx_arr.as_ptr(),
            dst = in(reg) out.as_mut_ptr(),
            out("v0") _, out("v1") _, out("v2") _,
            out("p0") _,
            options(nostack)
        );
    }
    u32x4::from(out)
}
732
/// SVE u32x8 shuffle, processed as two 4-lane halves.
///
/// BUGFIX: any output lane may reference ANY of the 8 input lanes (matching
/// `shuffle_u32x8_scalar`, which uses `idx & 7`). The previous version assumed
/// low indices only referenced the low data half and high indices the high
/// half, so cross-half lookups (e.g. a full reverse shuffle) produced zeros.
/// Each half now performs two TBLs — one against each data half, with the
/// high lookup's indices shifted down by 4 — and ORs the results: SVE TBL
/// yields 0 for out-of-range indices, so exactly one lookup contributes per
/// lane. All clobbered vector/predicate registers are now also declared
/// (they were previously omitted, which was unsound).
///
/// # Safety
/// Caller must ensure SVE is available.
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "sve")]
unsafe fn shuffle_u32x8_sve(input: u32x8, indices: u32x8) -> u32x8 {
    use std::arch::asm;
    let data_arr = input.to_array();
    let idx_arr = indices.to_array();
    let mut out = [0u32; 8];

    // SAFETY: pointers reference properly sized local arrays; every register
    // the asm writes is listed as a clobber.
    unsafe {
        let mut half = 0usize;
        while half < 8 {
            asm!(
                "ptrue p0.s, vl4",
                "ld1w {{z0.s}}, p0/z, [{lo}]",   // data[0..4] (inactive lanes zeroed)
                "ld1w {{z1.s}}, p0/z, [{hi}]",   // data[4..8]
                "ld1w {{z2.s}}, p0/z, [{idx}]",  // indices for this output half
                "and z2.s, z2.s, #7",            // match scalar semantics: idx & 7
                "mov z3.s, #4",
                "sub z4.s, z2.s, z3.s",          // idx - 4 (wraps; OOB -> TBL gives 0)
                "tbl z5.s, {{z0.s}}, z2.s",      // lanes sourced from the low half
                "tbl z6.s, {{z1.s}}, z4.s",      // lanes sourced from the high half
                "orr z5.d, z5.d, z6.d",          // combine: at most one is nonzero
                "st1w {{z5.s}}, p0, [{dst}]",
                lo = in(reg) data_arr.as_ptr(),
                hi = in(reg) data_arr.as_ptr().add(4),
                idx = in(reg) idx_arr.as_ptr().add(half),
                dst = in(reg) out.as_mut_ptr().add(half),
                out("v0") _, out("v1") _, out("v2") _, out("v3") _,
                out("v4") _, out("v5") _, out("v6") _,
                out("p0") _,
                options(nostack)
            );
            half += 4;
        }
    }
    u32x8::from(out)
}
775
776// =============================================================================
777// Scalar/Portable Fallback Implementations
778// =============================================================================
779
780#[inline(always)]
781fn split_u32x16_cast(input: u32x16) -> (u32x8, u32x8) {
782    // True zero-copy: (u32x8, u32x8) has identical layout to u32x16
783    // Safety: Both are 64 bytes of contiguous u32 values
784    unsafe { std::mem::transmute(input) }
785}
786
787#[inline(always)]
788fn split_u64x8_cast(input: u64x8) -> (u64x4, u64x4) {
789    // True zero-copy: (u64x4, u64x4) has identical layout to u64x8
790    // Safety: Both are 64 bytes of contiguous u64 values
791    unsafe { std::mem::transmute(input) }
792}
793
794#[allow(dead_code)]
795#[inline]
796fn widen_u32x8_to_u64x8_scalar(input: u32x8) -> u64x8 {
797    let array = input.to_array();
798    u64x8::from(array.map(|x| x as u64))
799}
800
801#[allow(dead_code)]
802#[inline]
803fn widen_u32x4_to_u64x4_scalar(input: u32x4) -> u64x4 {
804    let array = input.to_array();
805    u64x4::from(array.map(|x| x as u64))
806}
807
808#[allow(dead_code)]
809#[inline]
810fn u64x8_from_bitmask_scalar(mask: u8) -> u64x8 {
811    let mut values = [0u64; 8];
812    for (i, value) in values.iter_mut().enumerate() {
813        *value = if (mask >> i) & 1 != 0 { u64::MAX } else { 0 };
814    }
815    u64x8::from(values)
816}
817
818#[allow(dead_code)]
819#[inline]
820fn u32x8_from_bitmask_scalar(mask: u8) -> u32x8 {
821    let mut values = [0u32; 8];
822    for (i, value) in values.iter_mut().enumerate() {
823        *value = if (mask >> i) & 1 != 0 { u32::MAX } else { 0 };
824    }
825    u32x8::from(values)
826}
827
828#[allow(dead_code)]
829#[inline]
830fn shuffle_u32x8_scalar(input: u32x8, indices: u32x8) -> u32x8 {
831    let arr = input.to_array();
832    let idx = indices.to_array();
833    let mut result = [0u32; 8];
834    for i in 0..8 {
835        result[i] = arr[(idx[i] & 7) as usize];
836    }
837    u32x8::from(result)
838}
839
840#[allow(dead_code)]
841#[inline]
842fn shuffle_u32x4_scalar(input: u32x4, indices: u32x4) -> u32x4 {
843    let arr = input.to_array();
844    let idx = indices.to_array();
845    let mut result = [0u32; 4];
846    for i in 0..4 {
847        result[i] = arr[(idx[i] & 3) as usize];
848    }
849    u32x4::from(result)
850}
851
852#[allow(dead_code)]
853#[inline]
854fn shuffle_u8x16_scalar(input: u8x16, indices: u8x16) -> u8x16 {
855    let arr = input.to_array();
856    let idx = indices.to_array();
857    let mut result = [0u8; 16];
858    for i in 0..16 {
859        let index = idx[i] as usize;
860        result[i] = if index < 16 { arr[index] } else { 0 };
861    }
862    u8x16::from(result)
863}
864
865// =============================================================================
866// Scalar Mul/Div Implementations
867// =============================================================================
868
869
870// =============================================================================
871// Tests
872// =============================================================================
873
874#[cfg(test)]
875mod tests {
876    use super::*;
877
    #[test]
    fn test_u32x8_widening() {
        // Each u32 lane should be zero-extended into the matching u64 lane.
        let input = u32x8::from([1, 2, 3, 4, 5, 6, 7, 8]);
        let widened: u64x8 = input.widen_to_u64x8();
        assert_eq!(widened.to_array(), [1u64, 2, 3, 4, 5, 6, 7, 8]);
    }
884
    #[test]
    fn test_u32x4_widening() {
        // Despite the method name, u32x4's widened type is u64x4.
        let input = u32x4::from([1, 2, 3, 4]);
        let widened: u64x4 = input.widen_to_u64x8();
        assert_eq!(widened.to_array(), [1u64, 2, 3, 4]);
    }
891
    #[test]
    fn test_u64x8_from_bitmask() {
        // Alternating bits: odd lanes all-ones, even lanes zero (bit i -> lane i).
        let mask = 0b10101010u8;
        let mask_vec: u64x8 = u64x8::from_bitmask(mask);
        let expected = [0u64, u64::MAX, 0u64, u64::MAX, 0u64, u64::MAX, 0u64, u64::MAX];
        assert_eq!(mask_vec.to_array(), expected);
    }
899
    #[test]
    fn test_u32x8_from_bitmask() {
        // Bits 0,1,6,7 set -> the corresponding lanes become all-ones.
        let mask = 0b11000011u8;
        let mask_vec: u32x8 = u32x8::from_bitmask(mask);
        let expected = [u32::MAX, u32::MAX, 0u32, 0u32, 0u32, 0u32, u32::MAX, u32::MAX];
        assert_eq!(mask_vec.to_array(), expected);
    }
907
    #[test]
    fn test_edge_cases() {
        // Boundary masks for from_bitmask: no bits set and all bits set.
        let mask_zero = 0b00000000u8;
        let vec_zero: u64x8 = u64x8::from_bitmask(mask_zero);
        assert_eq!(vec_zero.to_array(), [0u64; 8]);

        let mask_all = 0b11111111u8;
        let vec_all: u64x8 = u64x8::from_bitmask(mask_all);
        assert_eq!(vec_all.to_array(), [u64::MAX; 8]);
    }
918
919    #[test]
920    fn test_shuffle_u32x8() {
921        let input = u32x8::from([10, 20, 30, 40, 50, 60, 70, 80]);
922
923        // Identity shuffle
924        let indices = u32x8::from([0, 1, 2, 3, 4, 5, 6, 7]);
925        let result = input.shuffle(indices);
926        assert_eq!(result.to_array(), [10, 20, 30, 40, 50, 60, 70, 80]);
927
928        // Reverse shuffle
929        let indices = u32x8::from([7, 6, 5, 4, 3, 2, 1, 0]);
930        let result = input.shuffle(indices);
931        assert_eq!(result.to_array(), [80, 70, 60, 50, 40, 30, 20, 10]);
932
933        // Compress-like shuffle
934        let indices = u32x8::from([1, 4, 5, 7, 0, 0, 0, 0]);
935        let result = input.shuffle(indices);
936        assert_eq!(result.to_array()[0..4], [20, 50, 60, 80]);
937    }
938
939    #[test]
940    fn test_shuffle_u8x16() {
941        let input = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
942
943        // Identity
944        let indices = u8x16::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
945        let result = input.shuffle(indices);
946        assert_eq!(result.to_array(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
947
948        // Reverse
949        let indices = u8x16::from([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
950        let result = input.shuffle(indices);
951        assert_eq!(result.to_array(), [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
952    }
953
954    #[test]
955    fn test_simd_compress_indices_table() {
956        // Mask 0b10110010 = bits 1,4,5,7 set
957        let mask = 0b10110010u8;
958        let indices = get_compress_indices_u32x8(mask);
959        let arr = indices.to_array();
960        assert_eq!(arr[0], 1);
961        assert_eq!(arr[1], 4);
962        assert_eq!(arr[2], 5);
963        assert_eq!(arr[3], 7);
964
965        // All set = identity
966        let indices = get_compress_indices_u32x8(0xFF);
967        assert_eq!(indices.to_array(), [0, 1, 2, 3, 4, 5, 6, 7]);
968
969        // Test raw array access
970        let raw = SHUFFLE_COMPRESS_IDX_U32X8[0b10110010];
971        assert_eq!(raw[0], 1);
972        assert_eq!(raw[1], 4);
973    }
974
975    #[test]
976    fn test_simd_split_u32x16() {
977        let input = u32x16::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
978        let (lo, hi) = input.split_low_high();
979
980        assert_eq!(lo.to_array(), [1, 2, 3, 4, 5, 6, 7, 8]);
981        assert_eq!(hi.to_array(), [9, 10, 11, 12, 13, 14, 15, 16]);
982    }
983
984    #[test]
985    fn test_simd_split_u64x8() {
986        let input = u64x8::from([1, 2, 3, 4, 5, 6, 7, 8]);
987        let (lo, hi) = input.split_low_high();
988
989        assert_eq!(lo.to_array(), [1, 2, 3, 4]);
990        assert_eq!(hi.to_array(), [5, 6, 7, 8]);
991    }
992
993    #[test]
994    fn test_shuffle_u32x4() {
995        let input = u32x4::from([10, 20, 30, 40]);
996
997        // Identity
998        let indices = u32x4::from([0, 1, 2, 3]);
999        let result = input.shuffle(indices);
1000        assert_eq!(result.to_array(), [10, 20, 30, 40]);
1001
1002        // Reverse
1003        let indices = u32x4::from([3, 2, 1, 0]);
1004        let result = input.shuffle(indices);
1005        assert_eq!(result.to_array(), [40, 30, 20, 10]);
1006
1007        // Broadcast first element
1008        let indices = u32x4::from([0, 0, 0, 0]);
1009        let result = input.shuffle(indices);
1010        assert_eq!(result.to_array(), [10, 10, 10, 10]);
1011    }
1012
1013    #[test]
1014    fn test_double_u8x16() {
1015        let a = u8x16::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
1016        let result = a.double();
1017        assert_eq!(
1018            result.to_array(),
1019            [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32]
1020        );
1021    }
1022
1023    #[test]
1024    fn test_double_triple_for_x8() {
1025        // x.double().double().double() = x * 8
1026        let a = u8x16::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
1027        let result = a.double().double().double();
1028        assert_eq!(
1029            result.to_array(),
1030            [8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128]
1031        );
1032    }
1033
1034    #[test]
1035    fn test_double_overflow() {
1036        // Test wrapping: 128 * 2 = 256 wraps to 0
1037        let a = u8x16::splat(128);
1038        let result = a.double();
1039        assert_eq!(result.to_array(), [0u8; 16]);
1040
1041        // 200 * 2 = 400 wraps to 144
1042        let a = u8x16::splat(200);
1043        let result = a.double();
1044        assert_eq!(result.to_array(), [144u8; 16]); // 400 & 0xFF = 144
1045    }
1046}