Skip to main content

simdprune/
lib.rs

1// Copyright 2021 Daniel Philip Watson
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Pruning elements in SIMD vectors
16//!
17//! This crate is a port of Daniel Lemire's C library [simdprune](https://github.com/lemire/simdprune/).
18//!
//! The mask "marks" values in the input for deletion: if
//! the mask is odd, then the first value is marked for deletion.
//! Each function produces a new vector that starts with all
//! values that have not been deleted. (The one exception is
//! `pext_prune256_epi32`, whose 1-bits mark the values to *keep*.)
23//!
24//! Passing a mask of 0 would simply copy the provided vector.
25//!
//! Note that this is the opposite of the mask behavior of the AVX-512 VCOMPRESS/VPCOMPRESS instructions.
//! If you have AVX-512, much of what this crate does can be performed with those instructions.
28//!
29//! # Examples
30//!
31//! See [`prune_epi32`].
32//!
33//! # Features
34//!
35//! All features below are enabled by default.
36//!
37//! * **std** - Enables the standard library. Disabling this enables the `no_std` crate attribute.
38//! * **large_tables** - Enables functions like [`prune_epi8`] which require large tables (>1MB).
39//! Disabling this may speed up compilation.
40
41#![cfg_attr(docsrs, feature(doc_cfg))]
42#![cfg_attr(not(feature = "std"), no_std)]
43
44#[cfg(feature = "large_tables")]
45mod large_tables;
46mod tables;
47
48#[cfg(target_arch = "x86")]
49use core::arch::x86::*;
50#[cfg(target_arch = "x86_64")]
51use core::arch::x86_64::*;
52
53#[cfg(feature = "large_tables")]
54use large_tables::mask128_epi8;
55use tables::*;
56
57/// Prune 8-bit values.
58///
59/// Values corresponding to a 1-bit in the mask are removed from output
60///
61/// The table used for this operation occupies 1 MB.
62///
63/// The last value not deleted is used to pad the result.
64///
65/// Requires the `large_tables` feature (enabled by default).
66///
67/// Trick: by leaving the highest bit (`1 << 15`) to zero whether
68/// you want to delete the last value or not, you can end up using
69/// only the first half of the table (which limits cache usage).
70///
71/// # Panics
72/// Panics if `mask` is not in `[0, 1 << 16)`.
73/// # Examples
74/// See [`prune_epi32`].
75#[target_feature(enable = "ssse3")]
76#[cfg(feature = "large_tables")]
77#[cfg_attr(docsrs, doc(cfg(feature = "large_tables")))]
78#[inline]
79pub unsafe fn prune_epi8(x: __m128i, mask: i32) -> __m128i {
80    let ptr = mask128_epi8[16 * mask as usize..].as_ptr().cast();
81    _mm_shuffle_epi8(x, _mm_loadu_si128(ptr))
82}
83
/// Shifts the 128-bit value `x` left by `count` bytes, shifting in zeros.
///
/// `_mm_slli_si128` needs the byte count as a compile-time immediate, so the
/// shift is assembled from two 64-bit lane shifts whose counts are passed in a
/// register. Callers in this file pass `count` in `[0, 8]`; outside that range
/// the result is all zeros rather than a true 128-bit shift.
#[inline]
unsafe fn left_shift_bytes(x: __m128i, count: i32) -> __m128i {
    let bits = i64::from(count) * 8;
    // `_mm_set_epi64x(0, v)` builds the same count vector as
    // `_mm_cvtsi64_si128(v)` (low 64 bits = v, high bits zero), but unlike the
    // latter it is also available on 32-bit x86.
    // Bytes that stay inside their own 64-bit lane:
    let within_lanes = _mm_sll_epi64(x, _mm_set_epi64x(0, bits));
    // Bytes that cross from the low lane into the high lane: move the low lane
    // up (the low lane becomes zero) and keep only its top `count` bytes.
    let carried = _mm_srl_epi64(
        _mm_unpacklo_epi64(_mm_setzero_si128(), x),
        _mm_set_epi64x(0, 64 - bits),
    );
    _mm_or_si128(within_lanes, carried)
}
94
95/// Prune 8-bit values. Like [`prune_epi8`] but uses a 2kB table.
96///
97/// Values corresponding to a 1-bit in the mask are removed from output
98///
99/// Note that this will be faster if you enable the `popcnt` instruction set feature flag,
100/// available on SSE4.2 and later.
101///
102/// # Panics
103/// Panics if `mask` is not in `[0, 1 << 16)`.
104/// # Examples
105/// See [`prune_epi32`].
106#[target_feature(enable = "ssse3")]
107#[inline]
108pub unsafe fn thinprune_epi8(x: __m128i, mask: i32) -> __m128i {
109    let mask1 = mask & 0xFF;
110    let pop = 8 - mask1.count_ones();
111    let mask2 = mask as u32 >> 8; // we want a logical shift here
112    let m1 = _mm_loadl_epi64(thintable_epi8[mask1 as usize..].as_ptr().cast());
113    let m2 = _mm_loadl_epi64(thintable_epi8[mask2 as usize..].as_ptr().cast());
114    let m2add = _mm_add_epi8(m2, _mm_set1_epi8(8));
115    let m2shifted = left_shift_bytes(m2add, pop as i32);
116    let shufmask = _mm_or_si128(m2shifted, m1);
117    _mm_shuffle_epi8(x, shufmask)
118}
119
120/// Prune 8-bit values. Like [`prune_epi8`] but uses a <1kB table.
121///
122/// Values corresponding to a 1-bit in the mask are removed from output
123///
124/// credit: @animetosho
125///
126/// # Panics
127/// Panics if `mask` is not in `[0, 1 << 16)`.
128/// # Examples
129/// See [`prune_epi32`].
130#[target_feature(enable = "ssse3")]
131#[inline]
132pub unsafe fn skinnyprune_epi8(x: __m128i, mask: i32) -> __m128i {
133    let mask1 = mask & 0xFF;
134    // we want a logical shift here
135    let mask2 = mask as u32 >> 8;
136    // reference impl uses _mm_loadh_pi but since Rust removed __m64 support,
137    // we use _mm_loadh_pd here.
138    let ptr1 = thintable_epi8[mask1 as usize..].as_ptr().cast();
139    let ptr2 = thintable_epi8[mask2 as usize..].as_ptr().cast();
140    let mut shufmask =
141        _mm_castpd_si128(_mm_loadh_pd(_mm_castsi128_pd(_mm_loadl_epi64(ptr1)), ptr2));
142    shufmask = _mm_add_epi8(shufmask, _mm_set_epi32(0x0808_0808, 0x0808_0808, 0, 0));
143    let pruned = _mm_shuffle_epi8(x, shufmask);
144    let popx2 = BitsSetTable256mul2[mask1 as usize];
145    let compactmask = _mm_loadu_si128(pshufb_combine_table[popx2 as usize * 8..].as_ptr().cast());
146    _mm_shuffle_epi8(pruned, compactmask)
147}
148
149/// Prune 8-bit values.
150///
151/// Values corresponding to a 1-bit in the mask are removed from output
152///
153/// The table used for this operation occupies 4 kB.
154///
155/// The last value not deleted is used to pad the result.
156///
157/// Trick: by leaving the highest bit (`1 << 7`) to zero whether
158/// you want to delete the last value or not, you can end up using
159/// only the first half of the table (which limits cache usage).
160///
161/// # Panics
162/// Panics if `mask` is not in `[0, 1 << 8)`.
163/// # Examples
164/// See [`prune_epi32`].
165#[target_feature(enable = "ssse3")]
166#[inline]
167pub unsafe fn prune_epi16(x: __m128i, mask: i32) -> __m128i {
168    let ptr = mask128_epi16[16 * mask as usize..].as_ptr().cast();
169    _mm_shuffle_epi8(x, _mm_loadu_si128(ptr))
170}
171
172/// Prune 32-bit integer values.
173///
174/// Values corresponding to a 1-bit in the mask are removed from output
175///
176/// # Panics
177/// Panics if `mask` is not in `[0, 1 << 4)`.
178///
179/// # Examples
180///
181/// ```
182/// # #[cfg(target_arch = "x86")] use core::arch::x86::*;
183/// # #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*;
184/// use simdprune::prune_epi32;
185///
186/// unsafe {
187///     let input = _mm_set_epi32(3, 2, 1, 0);
188///     let mask = 0b1010;
189///     let pruned = prune_epi32(input, mask);
190///     let mut buf = [0_u32; 4];
191///     _mm_storeu_si128(buf.as_mut_ptr().cast(), pruned);
192///     assert_eq!(&buf[..4 - mask.count_ones() as usize], [0, 2]);
193/// }
194#[target_feature(enable = "ssse3")]
195#[inline]
196pub unsafe fn prune_epi32(x: __m128i, mask: i32) -> __m128i {
197    let ptr = mask128_epi32[16 * mask as usize..].as_ptr().cast();
198    _mm_shuffle_epi8(x, _mm_loadu_si128(ptr))
199}
200
201/// Prune 32-bit floating-point values.
202///
203/// Values corresponding to a 1-bit in the mask are removed from output
204///
205/// # Panics
206/// Panics if `mask` is not in `[0, 1 << 4)`.
207/// # Examples
208/// See [`prune_epi32`].
209#[inline]
210#[target_feature(enable = "ssse3")]
211pub unsafe fn prune_ps(x: __m128, mask: i32) -> __m128 {
212    _mm_castsi128_ps(prune_epi32(_mm_castps_si128(x), mask))
213}
214
215/// Prune 32-bit integer values.
216///
217/// Values corresponding to a 1-bit in the mask are removed from output
218///
219/// # Panics
220/// Panics if `mask` is not in `[0, 1 << 8)`.
221/// # Examples
222/// See [`prune_epi32`].
223#[target_feature(enable = "avx2")]
224#[inline]
225pub unsafe fn prune256_epi32(x: __m256i, mask: i32) -> __m256i {
226    let ptr = mask256_epi32[8 * mask as usize..].as_ptr().cast();
227    _mm256_permutevar8x32_epi32(x, _mm256_loadu_si256(ptr))
228}
229
230/// Prune 32-bit floating-point values.
231///
232/// Values corresponding to a 1-bit in the mask are removed from output
233///
234/// # Panics
235/// Panics if `mask` is not in `[0, 1 << 8)`.
236/// # Examples
237/// See [`prune_epi32`].
238#[inline]
239#[target_feature(enable = "avx2")]
240pub unsafe fn prune256_ps(x: __m256, mask: i32) -> __m256 {
241    let ptr = mask256_epi32[8 * mask as usize..].as_ptr().cast();
242    _mm256_permutevar8x32_ps(x, _mm256_loadu_si256(ptr))
243}
244
/// Prune 32-bit integer values using BMI2 `pdep`/`pext` instead of a lookup table.
///
/// Uses 64-bit `pdep`/`pext` to save a step in unpacking.
///
/// source:
/// <http://stackoverflow.com/questions/36932240/avx2-what-is-the-most-efficient-way-to-pack-left-based-on-a-mask>
///
/// **The mask convention differs from the rest of this crate.** Values
/// corresponding to a 1-bit in the mask are *kept* and packed to the front in
/// order. Values corresponding to a 0-bit are removed. Lanes after the kept
/// values are filled with a copy of lane 0 of `src`. To remove the values
/// marked by a crate-style deletion mask `m`, pass `!m & 0xFF`.
///
/// ***Note that `_pdep_u64` is very slow on AMD Ryzen.***
///
/// Only available on `x86_64`, because `_pdep_u64`, `_pext_u64` and
/// `_mm_cvtsi64_si128` are 64-bit-only intrinsics.
///
/// # Panics
/// Panics if `mask` is not in `[0, 1 << 8)`.
///
/// # Safety
/// The CPU executing this function must support AVX2 and BMI2.
///
/// # Examples
///
/// ```
/// # #[cfg(target_arch = "x86_64")]
/// # {
/// use core::arch::x86_64::*;
/// use simdprune::pext_prune256_epi32;
///
/// if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("bmi2") {
///     unsafe {
///         let input = _mm256_setr_epi32(10, 11, 12, 13, 14, 15, 16, 17);
///         // 1-bits select the lanes to keep: lanes 1 and 3.
///         let pruned = pext_prune256_epi32(input, 0b1010);
///         let mut buf = [0_i32; 8];
///         _mm256_storeu_si256(buf.as_mut_ptr().cast(), pruned);
///         assert_eq!(buf, [11, 13, 10, 10, 10, 10, 10, 10]);
///     }
/// }
/// # }
/// ```
#[target_feature(enable = "avx2,bmi2")]
#[cfg(target_arch = "x86_64")]
#[cfg_attr(docsrs, doc(cfg(target_arch = "x86_64")))]
#[inline]
pub unsafe fn pext_prune256_epi32(src: __m256i, mask: u64) -> __m256i {
    assert!(mask < 1 << 8);
    // Unpack each mask bit into the low bit of its own byte, then widen every
    // set bit to a full 0xFF byte selector. Cannot overflow: the largest value
    // is 0x0101_0101_0101_0101 * 0xFF == u64::MAX.
    let expanded_mask = _pdep_u64(mask, 0x0101_0101_0101_0101) * 0xFF;
    // Lane indices 0..=7, one per byte.
    let identity_indices = 0x0706_0504_0302_0100;
    // Gather the indices of the selected lanes into the low bytes; the
    // remaining high bytes become 0, i.e. "take lane 0".
    let wanted_indices = _pext_u64(identity_indices, expanded_mask);
    let bytevec = _mm_cvtsi64_si128(wanted_indices as i64);
    // Zero-extend the eight byte indices to eight 32-bit permute indices.
    let shufmask = _mm256_cvtepu8_epi32(bytevec);
    _mm256_permutevar8x32_epi32(src, shufmask)
}