// v_escape_base/vector.rs

// Adapted from https://github.com/BurntSushi/memchr/blob/master/src/vector.rs
/// A trait for describing vector operations used by vectorized searchers.
///
/// The trait is highly constrained to low level vector operations needed.
/// In general, it was invented mostly to be generic over x86's __m128i and
/// __m256i types. At time of writing, it also supports wasm and aarch64
/// 128-bit vector types as well.
///
/// # Safety
///
/// All methods are not safe since they are intended to be implemented using
/// vendor intrinsics, which are also not safe. Callers must ensure that the
/// appropriate target features are enabled in the calling function, and that
/// the current CPU supports them. All implementations should avoid marking the
/// routines with #[target_feature] and instead mark them as #[inline(always)]
/// to ensure they get appropriately inlined. (inline(always) cannot be used
/// with target_feature.)
pub trait Vector: Copy + core::fmt::Debug {
    /// The number of bytes in the vector. That is, this is the size of the
    /// vector in memory.
    const BYTES: usize;
    /// The bits that must be zero in order for a `*const u8` pointer to be
    /// correctly aligned to read vector values.
    const ALIGN: usize;

    /// The type of the value returned by `Vector::movemask`.
    ///
    /// This supports abstracting over the specific representation used in
    /// order to accommodate different representations in different ISAs.
    type Mask: MoveMask;

    /// Create a vector with 8-bit lanes with the given byte repeated into each
    /// lane.
    fn splat(byte: u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// must be aligned to the size of the vector.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data` and that `data` is aligned to a `BYTES` boundary.
    unsafe fn load_aligned(data: *const u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// does not need to be aligned.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data`.
    unsafe fn load_unaligned(data: *const u8) -> Self;

    /// Convert the vector to a mask.
    fn movemask(self) -> Self::Mask;

    /// Compare two vectors for equality.
    fn cmpeq(self, vector2: Self) -> Self;

    /// Bitwise OR of two vectors.
    fn or(self, vector2: Self) -> Self;

    /// Add two vectors.
    fn add(self, vector2: Self) -> Self;

    /// Compare two vectors for greater than.
    fn gt(self, vector2: Self) -> Self;

    /// Returns true if and only if `Self::movemask` would return a mask that
    /// contains at least one non-zero bit.
    #[inline(always)]
    fn movemask_will_have_non_zero(self) -> bool {
        self.movemask().has_non_zero()
    }
}

/// A trait that abstracts over a vector-to-scalar operation called
/// "move mask."
///
/// On x86-64, this is `_mm_movemask_epi8` for SSE2 and `_mm256_movemask_epi8`
/// for AVX2. It takes a vector of `u8` lanes and returns a scalar where the
/// `i`th bit is set if and only if the most significant bit in the `i`th lane
/// of the vector is set. The simd128 ISA for wasm32 also supports this
/// exact same operation natively.
///
/// ... But aarch64 doesn't. So we have to fake it with more instructions and
/// a slightly different representation. We could do extra work to unify the
/// representations, but that would add costs in the hot path for `memchr` and
/// `packedpair`. So instead, we abstract over the specific representation with
/// this trait and define the operations we actually need.
pub trait MoveMask: Copy + core::fmt::Debug {
    /// Returns true if and only if this mask has a non-zero bit anywhere.
    fn has_non_zero(self) -> bool;

    /// Returns the mask shifted to the right by the given number of positions.
    fn shr(self, rhs: u32) -> Self;

    /// Returns a mask that is equivalent to `self` but with the least
    /// significant 1-bit set to 0.
    fn clear_least_significant_bit(self) -> Self;

    /// Returns the offset of the first non-zero lane this mask represents.
    fn first_offset(self) -> usize;
}
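
// Illustrative sketch (not part of the original crate): one way a caller could
// combine `Vector` and `MoveMask` to find the first occurrence of a byte. The
// name `find_byte_sketch` is hypothetical, and the trailing partial chunk is
// ignored here to keep the example short; real searchers handle it as well.
#[allow(dead_code)]
#[inline(always)]
unsafe fn find_byte_sketch<V: Vector>(needle: u8, haystack: &[u8]) -> Option<usize> {
    let needles = V::splat(needle);
    let mut i = 0;
    while i + V::BYTES <= haystack.len() {
        // SAFETY: the loop condition guarantees at least `V::BYTES` readable
        // bytes starting at `haystack.as_ptr().add(i)`.
        let chunk = unsafe { V::load_unaligned(haystack.as_ptr().add(i)) };
        let eq = chunk.cmpeq(needles);
        if eq.movemask_will_have_non_zero() {
            return Some(i + eq.movemask().first_offset());
        }
        i += V::BYTES;
    }
    None
}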

/// This is a "sensible" movemask implementation where each bit represents
/// whether the most significant bit is set in each corresponding lane of a
/// vector. This is used on x86-64 and wasm, but such a mask is more expensive
/// to get on aarch64 so we use something a little different.
///
/// We call this "sensible" because this is what we get using native sse/avx
/// movemask instructions. But neon has no such native equivalent.
#[cfg(any(
    target_arch = "x86_64",
    all(target_arch = "wasm32", target_feature = "simd128")
))]
#[derive(Clone, Copy)]
pub struct SensibleMoveMask(u32);

#[cfg(any(
    target_arch = "x86_64",
    all(target_arch = "wasm32", target_feature = "simd128")
))]
impl core::fmt::Debug for SensibleMoveMask {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{:b}", self.0)
    }
}

#[cfg(any(
    target_arch = "x86_64",
    all(target_arch = "wasm32", target_feature = "simd128")
))]
impl SensibleMoveMask {
    /// Get the mask in a form suitable for computing offsets.
    ///
    /// Basically, this normalizes to little endian. On big endian, this swaps
    /// the bytes.
    // TODO: Endianness does NOT affect the result of bitwise operations
    // (like <<, >>, &) or methods like `.trailing_zeros()` on the integer
    // returned by a SIMD movemask. The bit order in the movemask result is
    // defined by the SIMD instruction set (e.g., bit 0 corresponds to lane
    // 0, bit 1 to lane 1, etc.), regardless of how bytes are stored in
    // memory. So shifting the mask or counting trailing zeros is safe and
    // portable.
    #[inline(always)]
    fn get_for_offset(self) -> u32 {
        #[cfg(target_endian = "big")]
        {
            self.0.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            self.0
        }
    }
}

#[cfg(any(
    target_arch = "x86_64",
    all(target_arch = "wasm32", target_feature = "simd128")
))]
impl MoveMask for SensibleMoveMask {
    #[inline(always)]
    fn has_non_zero(self) -> bool {
        self.0 != 0
    }

    #[inline(always)]
    fn clear_least_significant_bit(self) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & (self.0 - 1))
    }

    #[inline(always)]
    fn first_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte
        // is at a higher address. That means the least significant bit that
        // is set corresponds to the position of our first matching byte.
        // That position corresponds to the number of zeros after the least
        // significant bit.
        self.get_for_offset().trailing_zeros() as usize
    }

    #[inline(always)]
    fn shr(self, rhs: u32) -> Self {
        // Endianness is not relevant here because the mask always uses
        // first_offset to compute the offset.
        SensibleMoveMask(self.0.wrapping_shr(rhs))
    }
}
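
// A worked example (illustrative, not from the original source) of how these
// operations compose. If lanes 1 and 4 of a 16-byte vector matched, an x86-64
// `movemask` yields 0b1_0010: `first_offset` counts the trailing zeros, and
// clearing the least significant set bit exposes the next match.
#[cfg(all(test, target_arch = "x86_64"))]
#[test]
fn sensible_move_mask_offsets_sketch() {
    let mask = SensibleMoveMask(0b1_0010);
    assert_eq!(mask.first_offset(), 1);
    assert_eq!(mask.clear_least_significant_bit().first_offset(), 4);
    assert!(mask.shr(1).has_non_zero());
}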

/// Noop implementation for types that don't support vectorization.
impl Vector for () {
    const BYTES: usize = 0;

    const ALIGN: usize = 0;

    type Mask = ();

    #[inline(always)]
    fn splat(_byte: u8) -> Self {
        unreachable!()
    }

    #[inline(always)]
    unsafe fn load_aligned(_data: *const u8) -> Self {
        unreachable!()
    }

    #[inline(always)]
    unsafe fn load_unaligned(_data: *const u8) -> Self {
        unreachable!()
    }

    #[inline(always)]
    fn movemask(self) -> Self::Mask {
        unreachable!()
    }

    #[inline(always)]
    fn cmpeq(self, _vector2: Self) -> Self {
        unreachable!()
    }

    #[inline(always)]
    fn or(self, _vector2: Self) -> Self {
        unreachable!()
    }

    #[inline(always)]
    fn add(self, _vector2: Self) -> Self {
        unreachable!()
    }

    #[inline(always)]
    fn gt(self, _vector2: Self) -> Self {
        unreachable!()
    }
}

/// Noop implementation for types that don't support vectorization.
impl MoveMask for () {
    #[inline(always)]
    fn has_non_zero(self) -> bool {
        unreachable!()
    }

    #[inline(always)]
    fn shr(self, _rhs: u32) -> Self {
        unreachable!()
    }

    #[inline(always)]
    fn clear_least_significant_bit(self) -> Self {
        unreachable!()
    }

    #[inline(always)]
    fn first_offset(self) -> usize {
        unreachable!()
    }
}

#[cfg(target_arch = "x86_64")]
mod x86sse2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m128i {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        fn splat(byte: u8) -> Self {
            unsafe { _mm_set1_epi8(byte as i8) }
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> Self {
            unsafe { _mm_load_si128(data as *const __m128i) }
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> Self {
            unsafe { _mm_loadu_si128(data as *const __m128i) }
        }

        #[inline(always)]
        fn movemask(self) -> Self::Mask {
            SensibleMoveMask(unsafe { _mm_movemask_epi8(self) } as u32)
        }

        #[inline(always)]
        fn cmpeq(self, vector2: Self) -> Self {
            unsafe { _mm_cmpeq_epi8(self, vector2) }
        }

        #[inline(always)]
        fn or(self, vector2: Self) -> Self {
            unsafe { _mm_or_si128(self, vector2) }
        }

        #[inline(always)]
        fn add(self, vector2: Self) -> Self {
            unsafe { _mm_add_epi8(self, vector2) }
        }

        #[inline(always)]
        fn gt(self, vector2: Self) -> Self {
            unsafe { _mm_cmpgt_epi8(self, vector2) }
        }
    }
}

#[cfg(target_arch = "x86_64")]
mod x86avx2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m256i {
        const BYTES: usize = 32;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        fn splat(byte: u8) -> Self {
            unsafe { _mm256_set1_epi8(byte as i8) }
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> Self {
            unsafe { _mm256_load_si256(data as *const __m256i) }
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> Self {
            unsafe { _mm256_loadu_si256(data as *const __m256i) }
        }

        #[inline(always)]
        fn movemask(self) -> Self::Mask {
            SensibleMoveMask(unsafe { _mm256_movemask_epi8(self) } as u32)
        }

        #[inline(always)]
        fn cmpeq(self, vector2: Self) -> Self {
            unsafe { _mm256_cmpeq_epi8(self, vector2) }
        }

        #[inline(always)]
        fn or(self, vector2: Self) -> Self {
            unsafe { _mm256_or_si256(self, vector2) }
        }

        #[inline(always)]
        fn add(self, vector2: Self) -> Self {
            unsafe { _mm256_add_epi8(self, vector2) }
        }

        #[inline(always)]
        fn gt(self, vector2: Self) -> Self {
            unsafe { _mm256_cmpgt_epi8(self, vector2) }
        }
    }
}
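
// A minimal sketch (not part of the original crate) of the calling convention
// described in the `Vector` safety docs: the `#[inline(always)]` trait methods
// get inlined into a `#[target_feature]`-annotated wrapper, and the caller of
// that wrapper must first verify CPU support (for example with
// `is_x86_feature_detected!("avx2")` when `std` is available). The module and
// function names here are hypothetical.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
mod avx2_wrapper_sketch {
    use core::arch::x86_64::__m256i;

    /// # Safety
    ///
    /// Callers must ensure the current CPU supports AVX2.
    #[target_feature(enable = "avx2")]
    pub unsafe fn find_byte_avx2(needle: u8, haystack: &[u8]) -> Option<usize> {
        // SAFETY: AVX2 is enabled for this function, which is what the
        // `__m256i` implementation of `Vector` requires.
        unsafe { super::find_byte_sketch::<__m256i>(needle, haystack) }
    }
}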

#[cfg(target_arch = "aarch64")]
mod aarch64neon {
    use core::arch::aarch64::*;

    use super::{MoveMask, Vector};

    impl Vector for int8x16_t {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = NeonMoveMask;

        #[inline(always)]
        fn splat(byte: u8) -> Self {
            unsafe { vdupq_n_s8(byte as i8) }
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> Self {
            // I've tried `data.cast::<uint8x16_t>().read()` instead, but
            // couldn't observe any benchmark differences.
            unsafe { Self::load_unaligned(data) }
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> Self {
            unsafe { vld1q_s8(data as *const i8) }
        }

        #[inline(always)]
        fn movemask(self) -> NeonMoveMask {
            let asu16s = unsafe { vreinterpretq_u16_s8(self) };
            let mask = unsafe { vshrn_n_u16(asu16s, 4) };
            let asu64 = unsafe { vreinterpret_u64_u8(mask) };
            let scalar64 = unsafe { vget_lane_u64(asu64, 0) };
            NeonMoveMask(scalar64 & 0x8888888888888888)
        }

        #[inline(always)]
        fn cmpeq(self, vector2: Self) -> Self {
            unsafe { vreinterpretq_s8_u8(vceqq_s8(self, vector2)) }
        }

        #[inline(always)]
        fn or(self, vector2: Self) -> Self {
            unsafe { vorrq_s8(self, vector2) }
        }

        /// This is the only interesting implementation of this routine.
        /// Basically, instead of doing the "shift right narrow" dance, we use
        /// adjacent folding max to determine whether there are any non-zero
        /// bytes in our mask. If there are, *then* we'll do the "shift right
        /// narrow" dance. In benchmarks, this does lead to slightly better
        /// throughput, but the win doesn't appear huge.
        #[inline(always)]
        fn movemask_will_have_non_zero(self) -> bool {
            let self_ = unsafe { vreinterpretq_u8_s8(self) };
            let low = unsafe { vreinterpretq_u64_u8(vpmaxq_u8(self_, self_)) };
            unsafe { vgetq_lane_u64(low, 0) != 0 }
        }

        #[inline(always)]
        fn add(self, vector2: Self) -> Self {
            unsafe { vaddq_s8(self, vector2) }
        }

        #[inline(always)]
        fn gt(self, vector2: Self) -> Self {
            unsafe { vreinterpretq_s8_u8(vcgtq_s8(self, vector2)) }
        }
    }

    /// Neon doesn't have a `movemask` that works like the one in x86-64, so we
    /// wind up using a different method[1]. The different method also produces
    /// a mask, but 4 bits are set in the neon case instead of a single bit set
    /// in the x86-64 case. We do an extra step to zero out 3 of the 4 bits,
    /// but we still wind up with at least 3 zeroes between each set bit. This
    /// generally means that we need to do some division by 4 before extracting
    /// offsets.
    ///
    /// In fact, the existence of this type is the entire reason that we have
    /// the `MoveMask` trait in the first place. This basically lets us keep
    /// the different representations of masks without being forced to unify
    /// them into a single representation, which could result in extra and
    /// unnecessary work.
    ///
    /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
    #[derive(Clone, Copy)]
    pub struct NeonMoveMask(u64);

    impl core::fmt::Debug for NeonMoveMask {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            write!(f, "{:b}", self.0)
        }
    }

    impl NeonMoveMask {
        /// Get the mask in a form suitable for computing offsets.
        ///
        /// Basically, this normalizes to little endian. On big endian, this
        /// swaps the bytes.
        // TODO: Endianness does NOT affect the result of bitwise operations
        // (like <<, >>, &) or methods like `.trailing_zeros()` on the integer
        // returned by a SIMD movemask. The bit order in the movemask result is
        // defined by the SIMD instruction set (e.g., bit 0 corresponds to lane
        // 0, bit 1 to lane 1, etc.), regardless of how bytes are stored in
        // memory. So shifting the mask or counting trailing zeros is safe and
        // portable.
        #[inline(always)]
        fn get_for_offset(self) -> u64 {
            #[cfg(target_endian = "big")]
            {
                self.0.swap_bytes()
            }
            #[cfg(target_endian = "little")]
            {
                self.0
            }
        }
    }

    impl MoveMask for NeonMoveMask {
        #[inline(always)]
        fn has_non_zero(self) -> bool {
            self.0 != 0
        }

        #[inline(always)]
        fn shr(self, rhs: u32) -> Self {
            // The mask is 64 bits instead of 16 bits (for a 128-bit vector),
            // so every position has 4 bits. We need to multiply the shift
            // amount by 4 to shift the bits correctly.
            // Endianness is not relevant here because the mask always uses
            // first_offset to compute the offset, and shift operations work on
            // the value rather than its in-memory byte order.
            NeonMoveMask(self.0.wrapping_shr(rhs << 2))
        }

        #[inline(always)]
        fn clear_least_significant_bit(self) -> NeonMoveMask {
            NeonMoveMask(self.0 & (self.0 - 1))
        }

        #[inline(always)]
        fn first_offset(self) -> usize {
            // We are dealing with little endian here (and if we aren't,
            // we swap the bytes so we are in practice), where the most
            // significant byte is at a higher address. That means the least
            // significant bit that is set corresponds to the position of our
            // first matching byte. That position corresponds to the number of
            // zeros after the least significant bit.
            //
            // Note that unlike `SensibleMoveMask`, this mask has its bits
            // spread out over 64 bits instead of 16 bits (for a 128 bit
            // vector). Namely, whereas x86-64 will turn
            //
            //   0x00 0xFF 0x00 0x00 0xFF
            //
            // into 10010, our neon approach will turn it into
            //
            //   10000000000010000000
            //
            // And this happens because neon doesn't have a native `movemask`
            // instruction, so we kind of fake it[1]. Thus, we divide the
            // number of trailing zeros by 4 to get the "real" offset.
            //
            // [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
            (self.get_for_offset().trailing_zeros() >> 2) as usize
        }
    }
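
    // A worked sketch (illustrative, not from the original source): with 4
    // mask bits per lane, a match in lane 1 (a 0xFF compare result) leaves bit
    // 4*1 + 3 = 7 set after the `& 0x8888888888888888` step, so `first_offset`
    // divides the trailing-zero count by 4 to recover the lane index.
    #[cfg(test)]
    #[test]
    fn neon_move_mask_offset_sketch() {
        // Lane 1 matched: bit 7 is set. trailing_zeros() == 7 and 7 >> 2 == 1.
        let mask = NeonMoveMask(0x80);
        assert_eq!(mask.first_offset(), 1);
        // Shifting right by one lane position moves the match to lane 0.
        assert_eq!(mask.shr(1).first_offset(), 0);
    }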
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
mod wasm_simd128 {
    use core::arch::wasm32::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for v128 {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        fn splat(byte: u8) -> Self {
            u8x16_splat(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> Self {
            unsafe { *data.cast() }
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> Self {
            unsafe { v128_load(data.cast()) }
        }

        #[inline(always)]
        fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(u8x16_bitmask(self).into())
        }

        #[inline(always)]
        fn cmpeq(self, vector2: Self) -> Self {
            i8x16_eq(self, vector2)
        }

        #[inline(always)]
        fn or(self, vector2: Self) -> Self {
            v128_or(self, vector2)
        }

        #[inline(always)]
        fn add(self, vector2: Self) -> Self {
            i8x16_add(self, vector2)
        }

        #[inline(always)]
        fn gt(self, vector2: Self) -> Self {
            i8x16_gt(self, vector2)
        }
    }
}