sonic_simd/
neon.rs

use std::arch::aarch64::*;

use super::{bits::NeonBits, Mask, Simd};

/// A 128-bit vector of 16 unsigned byte lanes, backed by a NEON `uint8x16_t`.
#[derive(Debug)]
#[repr(transparent)]
pub struct Simd128u(uint8x16_t);

/// A 128-bit vector of 16 signed byte lanes, backed by a NEON `int8x16_t`.
#[derive(Debug)]
#[repr(transparent)]
pub struct Simd128i(int8x16_t);

impl Simd for Simd128u {
    const LANES: usize = 16;
    type Mask = Mask128;
    type Element = u8;

    /// # Safety
    /// `ptr` must be valid for an unaligned 16-byte read.
    #[inline(always)]
    unsafe fn loadu(ptr: *const u8) -> Self {
        Self(vld1q_u8(ptr))
    }

    /// # Safety
    /// `ptr` must be valid for an unaligned 16-byte write.
    #[inline(always)]
    unsafe fn storeu(&self, ptr: *mut u8) {
        vst1q_u8(ptr, self.0);
    }

    /// Lane-wise `self == rhs`; matching lanes become 0xFF, the rest 0x00.
    #[inline(always)]
    fn eq(&self, rhs: &Self) -> Self::Mask {
        unsafe { Mask128(vceqq_u8(self.0, rhs.0)) }
    }

    /// Broadcasts `ch` into all 16 lanes.
    #[inline(always)]
    fn splat(ch: u8) -> Self {
        unsafe { Self(vdupq_n_u8(ch)) }
    }

    /// Lane-wise unsigned `self <= rhs`.
    #[inline(always)]
    fn le(&self, rhs: &Self) -> Self::Mask {
        unsafe { Mask128(vcleq_u8(self.0, rhs.0)) }
    }

    /// Lane-wise unsigned `self > rhs`.
    #[inline(always)]
    fn gt(&self, rhs: &Self) -> Self::Mask {
        unsafe { Mask128(vcgtq_u8(self.0, rhs.0)) }
    }
}
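
// A minimal sketch of how `Simd128u` composes for byte scanning; this test
// module and its buffer contents are illustrative, not part of the crate's
// test suite.
#[cfg(all(test, target_arch = "aarch64"))]
mod simd128u_example {
    use super::*;

    #[test]
    fn eq_marks_matching_lanes() {
        let buf = *b"abc\"def\"ghijklmn";
        unsafe {
            let v = Simd128u::loadu(buf.as_ptr());
            // Compare every lane against a splatted `"` needle.
            let mask = v.eq(&Simd128u::splat(b'"'));
            let mut lanes = [0u8; 16];
            vst1q_u8(lanes.as_mut_ptr(), mask.0);
            for (i, &lane) in lanes.iter().enumerate() {
                // Matching lanes are 0xFF, all others 0x00.
                assert_eq!(lane, if buf[i] == b'"' { 0xFF } else { 0x00 });
            }
        }
    }
}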

impl Simd for Simd128i {
    const LANES: usize = 16;
    type Mask = Mask128;
    type Element = i8;

    /// # Safety
    /// `ptr` must be valid for an unaligned 16-byte read.
    #[inline(always)]
    unsafe fn loadu(ptr: *const u8) -> Self {
        Self(vld1q_s8(ptr as *const i8))
    }

    /// # Safety
    /// `ptr` must be valid for an unaligned 16-byte write.
    #[inline(always)]
    unsafe fn storeu(&self, ptr: *mut u8) {
        vst1q_s8(ptr as *mut i8, self.0);
    }

    /// Lane-wise `self == rhs`; matching lanes become 0xFF, the rest 0x00.
    #[inline(always)]
    fn eq(&self, rhs: &Self) -> Self::Mask {
        unsafe { Mask128(vceqq_s8(self.0, rhs.0)) }
    }

    /// Broadcasts `elem` into all 16 lanes.
    #[inline(always)]
    fn splat(elem: i8) -> Self {
        unsafe { Self(vdupq_n_s8(elem)) }
    }

    /// Lane-wise signed `self <= rhs`.
    #[inline(always)]
    fn le(&self, rhs: &Self) -> Self::Mask {
        unsafe { Mask128(vcleq_s8(self.0, rhs.0)) }
    }

    /// Lane-wise signed `self > rhs`.
    #[inline(always)]
    fn gt(&self, rhs: &Self) -> Self::Mask {
        unsafe { Mask128(vcgtq_s8(self.0, rhs.0)) }
    }
}
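
// Hedged sketch of what the signed variant is useful for: bytes >= 0x80 are
// negative as `i8`, so a signed `gt` against -1 flags exactly the ASCII
// lanes. This test module is hypothetical, not part of the crate.
#[cfg(all(test, target_arch = "aarch64"))]
mod simd128i_example {
    use super::*;

    #[test]
    fn signed_gt_flags_ascii() {
        let mut buf = [b'x'; 16];
        buf[5] = 0xC3; // a non-ASCII byte (e.g. a UTF-8 lead byte)
        unsafe {
            let v = Simd128i::loadu(buf.as_ptr());
            // As i8, values >= 0 are exactly the ASCII range 0x00..=0x7F.
            let ascii = v.gt(&Simd128i::splat(-1));
            let mut lanes = [0u8; 16];
            vst1q_u8(lanes.as_mut_ptr(), ascii.0);
            for (i, &lane) in lanes.iter().enumerate() {
                assert_eq!(lane != 0, buf[i] < 0x80);
            }
        }
    }
}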

/// Per-lane bit-select table: lane `i` holds `1 << (i % 8)`, so ANDing a
/// 0x00/0xFF comparison mask with it leaves exactly one distinct bit per
/// lane within each 8-lane half.
pub(crate) const BIT_MASK_TAB: [u8; 16] = [
    0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
    0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
];

#[derive(Debug)]
#[repr(transparent)]
pub struct Mask128(pub(crate) uint8x16_t);

impl Mask for Mask128 {
    type BitMask = NeonBits;
    type Element = u8;

    /// Compresses the 16-byte mask (0x00 or 0xFF per lane) into a 64-bit
    /// value with one nibble per lane (0x0 or 0xF), via the `shrn`-by-4
    /// narrowing trick.
    /// Reference: <https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon>
    #[inline(always)]
    fn bitmask(self) -> Self::BitMask {
        unsafe {
            let v16 = vreinterpretq_u16_u8(self.0);
            let sr4 = vshrn_n_u16::<4>(v16);
            let v64 = vreinterpret_u64_u8(sr4);
            NeonBits::new(vget_lane_u64::<0>(v64))
        }
    }

    /// All lanes 0xFF if `b` is true, otherwise all lanes 0x00.
    #[inline(always)]
    fn splat(b: bool) -> Self {
        let v = if b { u8::MAX } else { 0 };
        unsafe { Self(vdupq_n_u8(v)) }
    }
}
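
// Worked example of the narrowing trick (hypothetical test, not part of the
// crate): the 16 mask bytes are read as eight u16 lanes, each shifted right
// by 4 and narrowed back to u8, so every input lane collapses to one nibble
// of the final u64. With only lanes 0 and 2 set, the result is 0x0F0F.
#[cfg(all(test, target_arch = "aarch64"))]
mod bitmask_example {
    use super::*;

    #[test]
    fn one_nibble_per_lane() {
        let mut lanes = [0u8; 16];
        lanes[0] = 0xFF;
        lanes[2] = 0xFF;
        unsafe {
            let mask = vld1q_u8(lanes.as_ptr());
            let sr4 = vshrn_n_u16::<4>(vreinterpretq_u16_u8(mask));
            let bits = vget_lane_u64::<0>(vreinterpret_u64_u8(sr4));
            // Nibble i of the result corresponds to lane i of the mask.
            assert_eq!(bits, 0x0F0F);
        }
    }
}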

// Bitwise AND for Mask128
impl std::ops::BitAnd<Mask128> for Mask128 {
    type Output = Self;

    #[inline(always)]
    fn bitand(self, rhs: Mask128) -> Self::Output {
        unsafe { Self(vandq_u8(self.0, rhs.0)) }
    }
}

// Bitwise OR for Mask128
impl std::ops::BitOr<Mask128> for Mask128 {
    type Output = Self;

    #[inline(always)]
    fn bitor(self, rhs: Mask128) -> Self::Output {
        unsafe { Self(vorrq_u8(self.0, rhs.0)) }
    }
}

// Bitwise OR assignment for Mask128
impl std::ops::BitOrAssign<Mask128> for Mask128 {
    #[inline(always)]
    fn bitor_assign(&mut self, rhs: Mask128) {
        unsafe {
            self.0 = vorrq_u8(self.0, rhs.0);
        }
    }
}
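
// Sketch of how the mask combinators compose in practice; this test module
// and its buffer are illustrative, not part of the crate. Two `eq` masks are
// ORed to flag bytes matching either needle.
#[cfg(all(test, target_arch = "aarch64"))]
mod mask_ops_example {
    use super::*;

    #[test]
    fn or_combines_needles() {
        let buf = *b"a\\b\"cdefghijklmn";
        unsafe {
            let v = Simd128u::loadu(buf.as_ptr());
            let special = v.eq(&Simd128u::splat(b'"')) | v.eq(&Simd128u::splat(b'\\'));
            let mut lanes = [0u8; 16];
            vst1q_u8(lanes.as_mut_ptr(), special.0);
            for (i, &lane) in lanes.iter().enumerate() {
                assert_eq!(lane != 0, buf[i] == b'"' || buf[i] == b'\\');
            }
        }
    }
}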

/// Combines four 16-lane byte masks into a 64-bit bitmask with one bit per
/// lane: `v0` fills bits 0-15, `v1` bits 16-31, `v2` bits 32-47 and `v3`
/// bits 48-63. Each input lane is expected to be 0x00 or 0xFF.
#[inline(always)]
pub unsafe fn to_bitmask64(v0: uint8x16_t, v1: uint8x16_t, v2: uint8x16_t, v3: uint8x16_t) -> u64 {
    let bit_mask = std::mem::transmute::<[u8; 16], uint8x16_t>(BIT_MASK_TAB);

    // Keep one distinct bit per lane (bit i % 8 in lane i).
    let t0 = vandq_u8(v0, bit_mask);
    let t1 = vandq_u8(v1, bit_mask);
    let t2 = vandq_u8(v2, bit_mask);
    let t3 = vandq_u8(v3, bit_mask);

    // Three rounds of pairwise adds act as ORs (the per-lane bits are
    // disjoint) and funnel all 64 lane bits into the low eight bytes.
    let pair0 = vpaddq_u8(t0, t1);
    let pair1 = vpaddq_u8(t2, t3);
    let quad = vpaddq_u8(pair0, pair1);
    let octa = vpaddq_u8(quad, quad);

    vgetq_lane_u64::<0>(vreinterpretq_u64_u8(octa))
}
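
// Illustrative check of the pairwise-add reduction above (hypothetical test,
// not part of the crate): bit k of the result reflects lane (k % 16) of
// input vector k / 16.
#[cfg(all(test, target_arch = "aarch64"))]
mod to_bitmask64_example {
    use super::*;

    #[test]
    fn one_bit_per_lane() {
        unsafe {
            let ones = vdupq_n_u8(0xFF);
            let zeros = vdupq_n_u8(0x00);
            // v0 and v3 all set: bits 0-15 and 48-63 of the result.
            let bits = to_bitmask64(ones, zeros, zeros, ones);
            assert_eq!(bits, 0xFFFF_0000_0000_FFFF);
        }
    }
}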

/// Two-vector variant of [`to_bitmask64`]: `v0` fills bits 0-15 and `v1`
/// bits 16-31 of the returned mask.
#[inline(always)]
pub(crate) unsafe fn to_bitmask32(v0: uint8x16_t, v1: uint8x16_t) -> u32 {
    let bit_mask = std::mem::transmute::<[u8; 16], uint8x16_t>(BIT_MASK_TAB);

    let t0 = vandq_u8(v0, bit_mask);
    let t1 = vandq_u8(v1, bit_mask);

    let pair = vpaddq_u8(t0, t1);
    let quad = vpaddq_u8(pair, pair);
    let octa = vpaddq_u8(quad, quad);

    vgetq_lane_u32::<0>(vreinterpretq_u32_u8(octa))
}
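
// Companion check for the two-vector reduction (hypothetical test, not part
// of the crate).
#[cfg(all(test, target_arch = "aarch64"))]
mod to_bitmask32_example {
    use super::*;

    #[test]
    fn two_vectors_one_bit_per_lane() {
        unsafe {
            let ones = vdupq_n_u8(0xFF);
            let zeros = vdupq_n_u8(0x00);
            let bits = to_bitmask32(zeros, ones);
            assert_eq!(bits, 0xFFFF_0000);
        }
    }
}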