// vers_vecs/bit_vec/fast_rs_vec/select.rs
//
// The select queries (`select0` / `select1`) of `RsVec` live in this file to keep the code organized.

3use crate::bit_vec::fast_rs_vec::{BLOCK_SIZE, SELECT_BLOCK_SIZE, SUPER_BLOCK_SIZE};
4use crate::bit_vec::WORD_SIZE;
5use crate::util::pdep::Pdep;
6use crate::util::unroll;
/// How many blocks make up one superblock (2^4, matching the four halving steps of the unrolled
/// block search). The select helpers are written for exactly this many blocks and check it
/// against `SUPER_BLOCK_SIZE / BLOCK_SIZE`, so the block sizes cannot change unnoticed. It also
/// serves as the lane count / bounds check of the AVX-512 block search.
const BLOCKS_PER_SUPERBLOCK: usize = 1 << 4;
12impl super::RsVec {
13    /// Return the position of the 0-bit with the given rank. See `rank0`.
14    /// The following holds for all `pos` with 0-bits:
15    /// ``select0(rank0(pos)) == pos``
16    ///
17    /// If the rank is larger than the number of 0-bits in the vector, the vector length is returned.
18    #[must_use]
19    #[allow(clippy::assertions_on_constants)]
20    pub fn select0(&self, mut rank: usize) -> usize {
21        if rank >= self.rank0 {
22            return self.len;
23        }
24
25        let mut super_block = self.select_blocks[rank / SELECT_BLOCK_SIZE].index_0;
26
27        if self.super_blocks.len() > (super_block + 1)
28            && self.super_blocks[super_block + 1].zeros <= rank
29        {
30            super_block = self.search_super_block0(super_block, rank);
31        }
32
33        rank -= self.super_blocks[super_block].zeros;
34
35        let mut block_index = super_block * (SUPER_BLOCK_SIZE / BLOCK_SIZE);
36        self.search_block0(rank, &mut block_index);
37
38        rank -= self.blocks[block_index].zeros as usize;
39
40        self.search_word_in_block0(rank, block_index)
41    }
42
43    /// Search for the block in a superblock that contains the rank. This function is only used
44    /// internally and is not part of the public API.
45    /// The function uses SIMD instructions if available, otherwise it falls back to a naive
46    /// implementation.
47    ///
48    /// It loads the entire block into a SIMD register and compares the rank to the number of zeros
49    /// in the block. The resulting mask is popcounted to find how many blocks from the block boundary
50    /// the rank is.
51    #[cfg(all(
52        feature = "simd",
53        target_arch = "x86_64",
54        target_feature = "avx",
55        target_feature = "avx512vl",
56        target_feature = "avx512bw",
57    ))]
58    #[inline(always)]
59    pub(super) fn search_block0(&self, rank: usize, block_index: &mut usize) {
60        use std::arch::x86_64::{_mm256_cmpgt_epu16_mask, _mm256_loadu_epi16, _mm256_set1_epi16};
61
62        if self.blocks.len() > *block_index + (SUPER_BLOCK_SIZE / BLOCK_SIZE) {
63            debug_assert!(
64                SUPER_BLOCK_SIZE / BLOCK_SIZE == BLOCKS_PER_SUPERBLOCK,
65                "change unroll constant to {}",
66                64 - (SUPER_BLOCK_SIZE / BLOCK_SIZE).leading_zeros() - 1
67            );
68
69            unsafe {
70                let blocks = _mm256_loadu_epi16(self.blocks[*block_index..].as_ptr() as *const i16);
71                let ranks = _mm256_set1_epi16(rank as i16);
72                let mask = _mm256_cmpgt_epu16_mask(blocks, ranks);
73
74                debug_assert!(
75                    mask.count_zeros() > 0,
76                    "first block should always be zero, but still claims to be greater than rank"
77                );
78                *block_index += mask.count_zeros() as usize - 1;
79            }
80        } else {
81            self.search_block0_naive(rank, block_index)
82        }
83    }
84
85    /// Search for the block in a superblock that contains the rank. This function is only used
86    /// internally and is not part of the public API.
87    /// It compares blocks in a loop-unrolled binary search to find the block that contains the rank.
88    #[cfg(not(all(
89        feature = "simd",
90        target_arch = "x86_64",
91        target_feature = "avx",
92        target_feature = "avx512vl",
93        target_feature = "avx512bw",
94    )))]
95    #[inline(always)]
96    pub(super) fn search_block0(&self, rank: usize, block_index: &mut usize) {
97        self.search_block0_naive(rank, block_index);
98    }
99
100    #[inline(always)]
101    fn search_block0_naive(&self, rank: usize, block_index: &mut usize) {
102        // full binary search for block that contains the rank, manually loop-unrolled, because
103        // LLVM doesn't do it for us, but it gains just under 20% performance
104
105        // this code relies on the fact that BLOCKS_PER_SUPERBLOCK blocks are in one superblock
106        debug_assert!(
107            SUPER_BLOCK_SIZE / BLOCK_SIZE == BLOCKS_PER_SUPERBLOCK,
108            "change unroll constant to {}",
109            64 - (SUPER_BLOCK_SIZE / BLOCK_SIZE).leading_zeros() - 1
110        );
111        unroll!(4,
112            |boundary = { (SUPER_BLOCK_SIZE / BLOCK_SIZE) / 2}|
113                // do not use select_unpredictable here, it degrades performance
114                if self.blocks.len() > *block_index + boundary && rank >= self.blocks[*block_index + boundary].zeros as usize {
115                    *block_index += boundary;
116                },
117            boundary /= 2);
118    }
119
120    /// Search for the word in the block that contains the rank, return the index of the rank-th
121    /// zero bit in the word.
122    /// This function is called by the ``select0``, ``iter::select_next_0`` and ``iter::select_next_0_back`` functions.
123    ///
124    /// # Arguments
125    /// * `rank` - the rank to search for, relative to the block
126    /// * `block_index` - the index of the block to search in, this is the block in the blocks
127    ///   vector that contains the rank
128    #[inline(always)]
129    pub(super) fn search_word_in_block0(&self, mut rank: usize, block_index: usize) -> usize {
130        // linear search for word that contains the rank. Binary search is not possible here,
131        // because we don't have accumulated popcounts for the words. We use pdep to find the
132        // position of the rank-th zero bit in the word, if the word contains enough zeros, otherwise
133        // we subtract the number of ones in the word from the rank and continue with the next word.
134        let mut index_counter = 0;
135        debug_assert!(BLOCK_SIZE / WORD_SIZE == 8, "change unroll constant");
136        unroll!(7, |n = {0}| {
137                    let word = self.data[block_index * BLOCK_SIZE / WORD_SIZE + n];
138                    if (word.count_zeros() as usize) <= rank {
139                        rank -= word.count_zeros() as usize;
140                        index_counter += WORD_SIZE;
141                    } else {
142                        return block_index * BLOCK_SIZE
143                            + index_counter
144                            + (1 << rank).pdep(!word).trailing_zeros() as usize;
145                    }
146                }, n += 1);
147
148        // the last word must contain the rank-th zero bit, otherwise the rank is outside the
149        // block, and thus outside the bitvector
150        block_index * BLOCK_SIZE
151            + index_counter
152            + (1 << rank)
153                .pdep(!self.data[block_index * BLOCK_SIZE / WORD_SIZE + 7])
154                .trailing_zeros() as usize
155    }
156
157    /// Search for the superblock that contains the rank.
158    /// This function is called by the ``select0``, ``iter::select_next_0`` and ``iter::select_next_0_back`` functions.
159    ///
160    /// # Arguments
161    /// * `super_block` - the index of the superblock to start the search from, this is the
162    ///   superblock in the ``select_blocks`` vector that contains the rank
163    /// * `rank` - the rank to search for
164    #[inline(always)]
165    pub(super) fn search_super_block0(&self, mut super_block: usize, rank: usize) -> usize {
166        let mut upper_bound = self.select_blocks[rank / SELECT_BLOCK_SIZE + 1].index_0;
167
168        while upper_bound - super_block > 8 {
169            let middle = super_block + ((upper_bound - super_block) >> 1);
170            // using select_unpredictable does nothing here, likely because the search isn't hot
171            if self.super_blocks[middle].zeros <= rank {
172                super_block = middle;
173            } else {
174                upper_bound = middle;
175            }
176        }
177
178        // linear search for superblock that contains the rank
179        while self.super_blocks.len() > (super_block + 1)
180            && self.super_blocks[super_block + 1].zeros <= rank
181        {
182            super_block += 1;
183        }
184
185        super_block
186    }
187
188    /// Return the position of the 1-bit with the given rank. See `rank1`.
189    /// The following holds for all `pos` with 1-bits:
190    /// ``select1(rank1(pos)) == pos``
191    ///
192    /// If the rank is larger than the number of 1-bits in the bit-vector, the vector length is returned.
193    #[must_use]
194    #[allow(clippy::assertions_on_constants)]
195    pub fn select1(&self, mut rank: usize) -> usize {
196        if rank >= self.rank1 {
197            return self.len;
198        }
199
200        let mut super_block =
201            self.select_blocks[rank / crate::bit_vec::fast_rs_vec::SELECT_BLOCK_SIZE].index_1;
202
203        if self.super_blocks.len() > (super_block + 1)
204            && ((super_block + 1) * SUPER_BLOCK_SIZE - self.super_blocks[super_block + 1].zeros)
205                <= rank
206        {
207            super_block = self.search_super_block1(super_block, rank);
208        }
209
210        rank -= (super_block) * SUPER_BLOCK_SIZE - self.super_blocks[super_block].zeros;
211
212        // full binary search for block that contains the rank, manually loop-unrolled, because
213        // LLVM doesn't do it for us, but it gains just under 20% performance
214        let block_at_super_block = super_block * (SUPER_BLOCK_SIZE / BLOCK_SIZE);
215        let mut block_index = block_at_super_block;
216        self.search_block1(rank, block_at_super_block, &mut block_index);
217
218        rank -= (block_index - block_at_super_block) * BLOCK_SIZE
219            - self.blocks[block_index].zeros as usize;
220
221        self.search_word_in_block1(rank, block_index)
222    }
223
224    /// Search for the block in a superblock that contains the rank. This function is only used
225    /// internally and is not part of the public API.
226    /// The function uses SIMD instructions if available, otherwise it falls back to a naive
227    /// implementation.
228    ///
229    /// It loads the entire block into a SIMD register and compares the rank to the number of ones
230    /// in the block. The resulting mask is popcounted to find how many blocks from the block boundary
231    /// the rank is.
232    #[cfg(all(
233        feature = "simd",
234        target_arch = "x86_64",
235        target_feature = "avx",
236        target_feature = "avx2",
237        target_feature = "avx512vl",
238        target_feature = "avx512bw",
239    ))]
240    #[inline(always)]
241    pub(super) fn search_block1(
242        &self,
243        rank: usize,
244        block_at_super_block: usize,
245        block_index: &mut usize,
246    ) {
247        use std::arch::x86_64::{
248            _mm256_cmpgt_epu16_mask, _mm256_loadu_epi16, _mm256_set1_epi16, _mm256_set_epi16,
249            _mm256_sub_epi16,
250        };
251
252        if self.blocks.len() > *block_index + BLOCKS_PER_SUPERBLOCK {
253            debug_assert!(
254                SUPER_BLOCK_SIZE / BLOCK_SIZE == BLOCKS_PER_SUPERBLOCK,
255                "change unroll constant to {}",
256                64 - (SUPER_BLOCK_SIZE / BLOCK_SIZE).leading_zeros() - 1
257            );
258
259            unsafe {
260                let bit_nums = _mm256_set_epi16(
261                    (15 * BLOCK_SIZE) as i16,
262                    (14 * BLOCK_SIZE) as i16,
263                    (13 * BLOCK_SIZE) as i16,
264                    (12 * BLOCK_SIZE) as i16,
265                    (11 * BLOCK_SIZE) as i16,
266                    (10 * BLOCK_SIZE) as i16,
267                    (9 * BLOCK_SIZE) as i16,
268                    (8 * BLOCK_SIZE) as i16,
269                    (7 * BLOCK_SIZE) as i16,
270                    (6 * BLOCK_SIZE) as i16,
271                    (5 * BLOCK_SIZE) as i16,
272                    (4 * BLOCK_SIZE) as i16,
273                    (3 * BLOCK_SIZE) as i16,
274                    (2 * BLOCK_SIZE) as i16,
275                    (1 * BLOCK_SIZE) as i16,
276                    (0 * BLOCK_SIZE) as i16,
277                );
278
279                let blocks = _mm256_loadu_epi16(self.blocks[*block_index..].as_ptr() as *const i16);
280                let ones = _mm256_sub_epi16(bit_nums, blocks);
281
282                let ranks = _mm256_set1_epi16(rank as i16);
283                let mask = _mm256_cmpgt_epu16_mask(ones, ranks);
284
285                debug_assert!(
286                    mask.count_zeros() > 0,
287                    "first block should always be zero, but still claims to be greater than rank"
288                );
289                *block_index += mask.count_zeros() as usize - 1;
290            }
291        } else {
292            self.search_block1_naive(rank, block_at_super_block, block_index)
293        }
294    }
295
296    /// Search for the block in a superblock that contains the rank. This function is only used
297    /// internally and is not part of the public API.
298    /// It compares blocks in a loop-unrolled binary search to find the block that contains the rank.
299    #[cfg(not(all(
300        feature = "simd",
301        target_arch = "x86_64",
302        target_feature = "avx",
303        target_feature = "avx2",
304        target_feature = "avx512vl",
305        target_feature = "avx512bw",
306    )))]
307    #[inline(always)]
308    pub(super) fn search_block1(
309        &self,
310        rank: usize,
311        block_at_super_block: usize,
312        block_index: &mut usize,
313    ) {
314        self.search_block1_naive(rank, block_at_super_block, block_index);
315    }
316
317    #[inline(always)]
318    fn search_block1_naive(
319        &self,
320        rank: usize,
321        block_at_super_block: usize,
322        block_index: &mut usize,
323    ) {
324        // full binary search for block that contains the rank, manually loop-unrolled, because
325        // LLVM doesn't do it for us, but it gains just under 20% performance
326
327        // this code relies on the fact that BLOCKS_PER_SUPERBLOCK blocks are in one superblock
328        debug_assert!(
329            SUPER_BLOCK_SIZE / BLOCK_SIZE == BLOCKS_PER_SUPERBLOCK,
330            "change unroll constant to {}",
331            64 - (SUPER_BLOCK_SIZE / BLOCK_SIZE).leading_zeros() - 1
332        );
333        unroll!(4,
334            |boundary = { (SUPER_BLOCK_SIZE / BLOCK_SIZE) / 2}|
335                // do not use select_unpredictable here, it degrades performance
336                if self.blocks.len() > *block_index + boundary && rank >= (*block_index + boundary - block_at_super_block) * BLOCK_SIZE - self.blocks[*block_index + boundary].zeros as usize {
337                    *block_index += boundary;
338                },
339            boundary /= 2);
340    }
341
342    /// Search for the word in the block that contains the rank, return the index of the rank-th
343    /// zero bit in the word.
344    /// This function is called by the ``select1``, ``iter::select_next_1`` and ``iter::select_next_1_back`` functions.
345    ///
346    /// # Arguments
347    /// * `rank` - the rank to search for, relative to the block
348    /// * `block_index` - the index of the block to search in, this is the block in the blocks
349    ///   vector that contains the rank
350    #[inline(always)]
351    pub(super) fn search_word_in_block1(&self, mut rank: usize, block_index: usize) -> usize {
352        // linear search for word that contains the rank. Binary search is not possible here,
353        // because we don't have accumulated popcounts for the words. We use pdep to find the
354        // position of the rank-th zero bit in the word, if the word contains enough zeros, otherwise
355        // we subtract the number of ones in the word from the rank and continue with the next word.
356        let mut index_counter = 0;
357        debug_assert!(BLOCK_SIZE / WORD_SIZE == 8, "change unroll constant");
358        unroll!(7, |n = {0}| {
359            let word = self.data[block_index * BLOCK_SIZE / WORD_SIZE + n];
360            if (word.count_ones() as usize) <= rank {
361                rank -= word.count_ones() as usize;
362                index_counter += WORD_SIZE;
363            } else {
364                return block_index * BLOCK_SIZE
365                    + index_counter
366                    + (1 << rank).pdep(word).trailing_zeros() as usize;
367            }
368        }, n += 1);
369
370        // the last word must contain the rank-th zero bit, otherwise the rank is outside of the
371        // block, and thus outside of the bitvector
372        block_index * BLOCK_SIZE
373            + index_counter
374            + (1 << rank)
375                .pdep(self.data[block_index * BLOCK_SIZE / WORD_SIZE + 7])
376                .trailing_zeros() as usize
377    }
378
379    /// Search for the superblock that contains the rank.
380    /// This function is called by the ``select1``, ``iter::select_next_1`` and ``iter::select_next_1_back`` functions.
381    ///
382    /// # Arguments
383    /// * `super_block` - the index of the superblock to start the search from, this is the
384    ///   superblock in the ``select_blocks`` vector that contains the rank
385    /// * `rank` - the rank to search for
386    #[inline(always)]
387    pub(super) fn search_super_block1(&self, mut super_block: usize, rank: usize) -> usize {
388        let mut upper_bound = self.select_blocks[rank / SELECT_BLOCK_SIZE + 1].index_1;
389
390        // binary search for superblock that contains the rank
391        while upper_bound - super_block > 8 {
392            let middle = super_block + ((upper_bound - super_block) >> 1);
393            // using select_unpredictable does nothing here, likely because the search isn't hot
394            if ((middle + 1) * SUPER_BLOCK_SIZE - self.super_blocks[middle].zeros) <= rank {
395                super_block = middle;
396            } else {
397                upper_bound = middle;
398            }
399        }
400        // linear search for superblock that contains the rank
401        while self.super_blocks.len() > (super_block + 1)
402            && ((super_block + 1) * SUPER_BLOCK_SIZE - self.super_blocks[super_block + 1].zeros)
403                <= rank
404        {
405            super_block += 1;
406        }
407
408        super_block
409    }
410}