// vers_vecs/bit_vec/fast_rs_vec/mod.rs

1//! A fast succinct bit vector implementation with rank and select queries. Rank computes in
2//! constant-time, select on average in constant-time, with a logarithmic worst case.
3
4use std::mem::size_of;
5
6#[cfg(all(
7    feature = "simd",
8    target_arch = "x86_64",
9    target_feature = "avx",
10    target_feature = "avx2",
11    target_feature = "avx512f",
12    target_feature = "avx512bw",
13))]
14pub use bitset::*;
15pub use iter::*;
16
17use crate::util::impl_vector_iterator;
18use crate::BitVec;
19
20use super::WORD_SIZE;
21
/// Size of a block in the bitvector, in bits.
const BLOCK_SIZE: usize = 512;

/// Size of a super block in the bitvector, in bits. Super-blocks exist to decrease the memory
/// overhead of block descriptors.
/// Increasing or decreasing the super block size has negligible effect on performance of rank
/// instruction. This means we want to make the super block size as large as possible, as long as
/// the zero-counter in normal blocks still fits in a reasonable amount of bits. However, this has
/// impact on the performance of select queries. The larger the super block size, the deeper will
/// a binary search be. We found 2^13 to be a good compromise between memory overhead and
/// performance.
/// Note: this must stay below 2^16 so a block's zero counter fits in the `u16` of
/// `BlockDescriptor`.
const SUPER_BLOCK_SIZE: usize = 1 << 13;

/// Size of a select block. The select block is used to speed up select queries. The select block
/// contains the indices of every `SELECT_BLOCK_SIZE`'th 1-bit and 0-bit in the bitvector.
/// The smaller this block-size, the faster are select queries, but the more memory is used.
const SELECT_BLOCK_SIZE: usize = 1 << 13;
39
/// Meta-data for a block. The `zeros` field stores the number of zeros up to the block,
/// beginning from the last super-block boundary. This means the first block in a super-block
/// always stores the number zero, which serves as a sentinel value to avoid special-casing the
/// first block in a super-block (which would be a performance hit due branch prediction failures).
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemSize, mem_dbg::MemDbg))]
#[cfg_attr(feature = "mem_dbg", mem_size_flat)]
struct BlockDescriptor {
    // number of 0-bits from the enclosing super-block's start up to (excluding) this block;
    // fits in a u16 because a super-block is shorter than 2^16 bits
    zeros: u16,
}
51
/// Meta-data for a super-block. The `zeros` field stores the number of zeros up to this super-block.
/// This allows the `BlockDescriptor` to store the number of zeros in a much smaller
/// space. The `zeros` field is the number of zeros up to the super-block.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemSize, mem_dbg::MemDbg))]
#[cfg_attr(feature = "mem_dbg", mem_size_flat)]
struct SuperBlockDescriptor {
    // number of 0-bits in the vector before this super-block begins
    zeros: usize,
}
62
/// Meta-data for the select query. Each entry i in the select vector contains the indices to find
/// the i * `SELECT_BLOCK_SIZE`'th 0- and 1-bit in the bitvector. Those indices may be very far apart.
/// The indices do not point into the bit-vector, but into the super-block vector.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemSize, mem_dbg::MemDbg))]
#[cfg_attr(feature = "mem_dbg", mem_size_flat)]
struct SelectSuperBlockDescriptor {
    // index (into the super-block vector) of the super-block containing the
    // i * SELECT_BLOCK_SIZE'th 0-bit
    index_0: usize,
    // index (into the super-block vector) of the super-block containing the
    // i * SELECT_BLOCK_SIZE'th 1-bit
    index_1: usize,
}
74
/// A bitvector that supports constant-time rank and select queries and is optimized for fast queries.
/// The bitvector is stored as a vector of `u64`s. The bit-vector stores meta-data for constant-time
/// rank and select queries, which takes sub-linear additional space. The space overhead is
/// 28 bits per 512 bits of user data (~5.47%).
///
/// # Example
/// ```rust
/// use vers_vecs::{BitVec, RsVec};
///
/// let mut bit_vec = BitVec::new();
/// bit_vec.append_word(u64::MAX);
///
/// let rs_vec = RsVec::from_bit_vec(bit_vec);
/// assert_eq!(rs_vec.rank1(64), 64);
/// assert_eq!(rs_vec.select1(64), 64);
///```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemSize, mem_dbg::MemDbg))]
pub struct RsVec {
    // the raw bit data, packed 64 bits per limb; the last limb may contain padding bits
    data: Vec<u64>,
    // length of the vector in bits
    len: usize,
    // per-block zero counters, relative to the enclosing super-block
    blocks: Vec<BlockDescriptor>,
    // per-super-block zero counters, relative to the start of the vector
    super_blocks: Vec<SuperBlockDescriptor>,
    // super-block indices for every SELECT_BLOCK_SIZE'th 0- and 1-bit
    select_blocks: Vec<SelectSuperBlockDescriptor>,
    // total number of 0-bits in the vector
    pub(crate) rank0: usize,
    // total number of 1-bits in the vector
    pub(crate) rank1: usize,
}
103
104impl RsVec {
    /// Build an `RsVec` from a [`BitVec`]. This will consume the `BitVec`. Since `RsVec`s are
    /// immutable, this is the only way to construct an `RsVec`.
    ///
    /// # Example
    /// See the example for `RsVec`.
    ///
    /// [`BitVec`]: BitVec
    #[must_use]
    pub fn from_bit_vec(vec: BitVec) -> RsVec {
        // Construct the block descriptor meta data. Each block descriptor contains the number of
        // zeros in the super-block, up to but excluding the block.
        let mut blocks = Vec::with_capacity(vec.len() / BLOCK_SIZE + 1);
        let mut super_blocks = Vec::with_capacity(vec.len() / SUPER_BLOCK_SIZE + 1);
        let mut select_blocks = Vec::new();

        // sentinel value
        select_blocks.push(SelectSuperBlockDescriptor {
            index_0: 0,
            index_1: 0,
        });

        // running totals while scanning limb by limb: `total_zeros` counts zeros up to the
        // last super-block boundary, `current_zeros` counts zeros since that boundary
        let mut total_zeros: usize = 0;
        let mut current_zeros: usize = 0;
        // indices of the last select-block entries written, used after the loop to append or
        // patch the trailing dummy entries that bound the binary search
        let mut last_zero_select_block: usize = 0;
        let mut last_one_select_block: usize = 0;

        for (idx, &word) in vec.data.iter().enumerate() {
            // if we moved past a block boundary, append the block information for the previous
            // block and reset the counter if we moved past a super-block boundary.
            if idx % (BLOCK_SIZE / WORD_SIZE) == 0 {
                if idx % (SUPER_BLOCK_SIZE / WORD_SIZE) == 0 {
                    total_zeros += current_zeros;
                    current_zeros = 0;
                    super_blocks.push(SuperBlockDescriptor { zeros: total_zeros });
                }

                // this cannot overflow because a super block isn't 2^16 bits long
                #[allow(clippy::cast_possible_truncation)]
                blocks.push(BlockDescriptor {
                    zeros: current_zeros as u16,
                });
            }

            // count the zeros in the current word and add them to the counter
            // the last word may contain padding zeros, which should not be counted,
            // but since we do not append the last block descriptor, this is not a problem
            let mut new_zeros = word.count_zeros() as usize;

            // in the last block, remove remaining zeros of limb that aren't part of the vector
            if idx == vec.data.len() - 1 && !vec.len.is_multiple_of(WORD_SIZE) {
                // `word | mask` forces all valid bits to 1, so only padding zeros are counted
                let mask = (1 << (vec.len % WORD_SIZE)) - 1;
                new_zeros -= (word | mask).count_zeros() as usize;
            }

            // if this limb crosses a multiple of SELECT_BLOCK_SIZE zeros, record the current
            // super-block in the select directory (append a new entry or patch a placeholder)
            let all_zeros = total_zeros + current_zeros + new_zeros;
            if all_zeros / SELECT_BLOCK_SIZE > (total_zeros + current_zeros) / SELECT_BLOCK_SIZE {
                if all_zeros / SELECT_BLOCK_SIZE == select_blocks.len() {
                    select_blocks.push(SelectSuperBlockDescriptor {
                        index_0: super_blocks.len() - 1,
                        index_1: 0,
                    });
                } else {
                    select_blocks[all_zeros / SELECT_BLOCK_SIZE].index_0 = super_blocks.len() - 1;
                }

                last_zero_select_block += 1;
            }

            // same bookkeeping for 1-bits; the 1-count is derived from total bits minus zeros
            let total_bits = (idx + 1) * WORD_SIZE;
            let all_ones = total_bits - all_zeros;
            if all_ones / SELECT_BLOCK_SIZE
                > (idx * WORD_SIZE - total_zeros - current_zeros) / SELECT_BLOCK_SIZE
            {
                if all_ones / SELECT_BLOCK_SIZE == select_blocks.len() {
                    select_blocks.push(SelectSuperBlockDescriptor {
                        index_0: 0,
                        index_1: super_blocks.len() - 1,
                    });
                } else {
                    select_blocks[all_ones / SELECT_BLOCK_SIZE].index_1 = super_blocks.len() - 1;
                }

                last_one_select_block += 1;
            }

            current_zeros += new_zeros;
        }

        // insert dummy select blocks at the end that just report the same index like the last real
        // block, so the bound check for binary search doesn't overflow
        // this is technically the incorrect value, but since all valid queries will be smaller,
        // this will only tell select to stay in the current super block, which is correct.
        // we cannot use a real value here, because this would change the size of the super-block
        if last_zero_select_block == select_blocks.len() - 1 {
            select_blocks.push(SelectSuperBlockDescriptor {
                index_0: select_blocks[last_zero_select_block].index_0,
                index_1: 0,
            });
        } else {
            debug_assert!(select_blocks[last_zero_select_block + 1].index_0 == 0);
            select_blocks[last_zero_select_block + 1].index_0 =
                select_blocks[last_zero_select_block].index_0;
        }
        if last_one_select_block == select_blocks.len() - 1 {
            select_blocks.push(SelectSuperBlockDescriptor {
                index_0: 0,
                index_1: select_blocks[last_one_select_block].index_1,
            });
        } else {
            debug_assert!(select_blocks[last_one_select_block + 1].index_1 == 0);
            select_blocks[last_one_select_block + 1].index_1 =
                select_blocks[last_one_select_block].index_1;
        }

        // fold the zeros of the last (possibly partial) super-block into the grand total
        total_zeros += current_zeros;

        RsVec {
            data: vec.data,
            len: vec.len,
            blocks,
            super_blocks,
            select_blocks,
            rank0: total_zeros,
            rank1: vec.len - total_zeros,
        }
    }
231
    /// Return the 0-rank of the bit at the given position. The 0-rank is the number of
    /// 0-bits in the vector up to but excluding the bit at the given position. Calling this
    /// function with an index larger than the length of the bit-vector will report the total
    /// number of 0-bits in the bit-vector.
    ///
    /// # Parameters
    /// - `pos`: The position of the bit to return the rank of.
    #[must_use]
    pub fn rank0(&self, pos: usize) -> usize {
        // `zero == true` selects the 0-counting branch of the shared rank implementation
        self.rank(true, pos)
    }
243
    /// Return the 1-rank of the bit at the given position. The 1-rank is the number of
    /// 1-bits in the vector up to but excluding the bit at the given position. Calling this
    /// function with an index larger than the length of the bit-vector will report the total
    /// number of 1-bits in the bit-vector.
    ///
    /// # Parameters
    /// - `pos`: The position of the bit to return the rank of.
    #[must_use]
    pub fn rank1(&self, pos: usize) -> usize {
        // `zero == false` selects the 1-counting branch of the shared rank implementation
        self.rank(false, pos)
    }
255
    /// Shared implementation of `rank0` and `rank1`: counts 0-bits (if `zero` is true) or
    /// 1-bits (if false) strictly below `pos`. Since `zero` is a compile-time-known constant
    /// at every call site, inlining eliminates the dead branches.
    // I measured 5-10% improvement with this. I don't know why it's not inlined by default, the
    // branch elimination profits alone should make it worth it.
    #[allow(clippy::inline_always)]
    #[inline(always)]
    fn rank(&self, zero: bool, pos: usize) -> usize {
        #[allow(clippy::collapsible_else_if)]
        // readability and more obvious where dead branch elimination happens
        if zero {
            if pos >= self.len() {
                return self.rank0;
            }
        } else {
            if pos >= self.len() {
                return self.rank1;
            }
        }

        // positions of the limb, block, and super-block containing `pos`
        let index = pos / WORD_SIZE;
        let block_index = pos / BLOCK_SIZE;
        let super_block_index = pos / SUPER_BLOCK_SIZE;
        let mut rank = 0;

        // at first add the number of zeros/ones before the current super block
        rank += if zero {
            self.super_blocks[super_block_index].zeros
        } else {
            // ones = bits before the super-block minus zeros before it
            (super_block_index * SUPER_BLOCK_SIZE) - self.super_blocks[super_block_index].zeros
        };

        // then add the number of zeros/ones before the current block
        rank += if zero {
            self.blocks[block_index].zeros as usize
        } else {
            // ones = bits between super-block start and block start, minus zeros in that span
            ((block_index % (SUPER_BLOCK_SIZE / BLOCK_SIZE)) * BLOCK_SIZE)
                - self.blocks[block_index].zeros as usize
        };

        // naive popcount of blocks
        for &i in &self.data[(block_index * BLOCK_SIZE) / WORD_SIZE..index] {
            rank += if zero {
                i.count_zeros() as usize
            } else {
                i.count_ones() as usize
            };
        }

        // finally count the matching bits in the partial limb below `pos`
        // (the mask keeps only the `pos % WORD_SIZE` low bits)
        rank += if zero {
            (!self.data[index] & ((1 << (pos % WORD_SIZE)) - 1)).count_ones() as usize
        } else {
            (self.data[index] & ((1 << (pos % WORD_SIZE)) - 1)).count_ones() as usize
        };

        rank
    }
310
    /// Return the length of the vector, i.e. the number of bits it contains.
    /// The length is fixed at construction time, since `RsVec` is immutable.
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }
316
317    /// Return whether the vector is empty.
318    #[must_use]
319    pub fn is_empty(&self) -> bool {
320        self.len() == 0
321    }
322
323    /// Return the bit at the given position. The bit takes the least significant
324    /// bit of the returned u64 word.
325    /// If the position is larger than the length of the vector, `None` is returned.
326    #[must_use]
327    pub fn get(&self, pos: usize) -> Option<u64> {
328        if pos >= self.len() {
329            None
330        } else {
331            Some(self.get_unchecked(pos))
332        }
333    }
334
335    /// Return the bit at the given position. The bit takes the least significant
336    /// bit of the returned u64 word.
337    ///
338    /// # Panics
339    /// This function may panic if `pos >= self.len()` (alternatively, it may return garbage).
340    #[must_use]
341    pub fn get_unchecked(&self, pos: usize) -> u64 {
342        (self.data[pos / WORD_SIZE] >> (pos % WORD_SIZE)) & 1
343    }
344
345    /// Return multiple bits at the given position. The number of bits to return is given by `len`.
346    /// At most 64 bits can be returned.
347    /// If the position at the end of the query is larger than the length of the vector,
348    /// None is returned (even if the query partially overlaps with the vector).
349    /// If the length of the query is larger than 64, None is returned.
350    #[must_use]
351    pub fn get_bits(&self, pos: usize, len: usize) -> Option<u64> {
352        if len > WORD_SIZE {
353            return None;
354        }
355        if pos + len > self.len {
356            None
357        } else {
358            Some(self.get_bits_unchecked(pos, len))
359        }
360    }
361
    /// Return multiple bits at the given position. The number of bits to return is given by `len`.
    /// At most 64 bits can be returned.
    ///
    /// This function is always inlined, because it gains a lot from loop optimization and
    /// can utilize the processor pre-fetcher better if it is.
    ///
    /// # Errors
    /// If the length of the query is larger than 64, unpredictable data will be returned.
    /// Use [`get_bits`] to properly handle this case with an `Option`.
    ///
    /// # Panics
    /// If the position or interval is larger than the length of the vector,
    /// the function will either return unpredictable data, or panic.
    ///
    /// [`get_bits`]: #method.get_bits
    #[must_use]
    #[allow(clippy::comparison_chain)] // readability
    #[allow(clippy::cast_possible_truncation)] // parameter must be out of scope for this to happen
    pub fn get_bits_unchecked(&self, pos: usize, len: usize) -> u64 {
        debug_assert!(len <= WORD_SIZE);
        // shift the first limb so the requested bits start at bit 0
        let partial_word = self.data[pos / WORD_SIZE] >> (pos % WORD_SIZE);
        // mask trick: `1 << len` overflows for len == 64, so checked_shl yields None there,
        // and `0u64.wrapping_sub(1)` produces the all-ones mask
        if pos % WORD_SIZE + len <= WORD_SIZE {
            // the query fits entirely within one limb
            partial_word & 1u64.checked_shl(len as u32).unwrap_or(0).wrapping_sub(1)
        } else {
            // the query spans two limbs: OR in the low bits of the next limb, shifted up
            // past the bits taken from the first limb
            (partial_word | (self.data[pos / WORD_SIZE + 1] << (WORD_SIZE - pos % WORD_SIZE)))
                & 1u64.checked_shl(len as u32).unwrap_or(0).wrapping_sub(1)
        }
    }
390
391    /// Convert the `RsVec` into a [`BitVec`].
392    /// This consumes the `RsVec`, and discards all meta-data.
393    /// Since [`RsVec`]s are innately immutable, this conversion is the only way to modify the
394    /// underlying data.
395    ///
396    /// # Example
397    /// ```rust
398    /// use vers_vecs::{BitVec, RsVec};
399    ///
400    /// let mut bit_vec = BitVec::new();
401    /// bit_vec.append_word(u64::MAX);
402    ///
403    /// let rs_vec = RsVec::from_bit_vec(bit_vec);
404    /// assert_eq!(rs_vec.rank1(64), 64);
405    ///
406    /// let mut bit_vec = rs_vec.into_bit_vec();
407    /// bit_vec.flip_bit(32);
408    /// let rs_vec = RsVec::from_bit_vec(bit_vec);
409    /// assert_eq!(rs_vec.rank1(64), 63);
410    /// assert_eq!(rs_vec.select0(0), 32);
411    /// ```
412    #[must_use]
413    pub fn into_bit_vec(self) -> BitVec {
414        BitVec {
415            data: self.data,
416            len: self.len,
417        }
418    }
419
    /// Check if two `RsVec`s are equal. For sparse vectors (either sparsely filled with 1-bits or
    /// 0-bits), this is faster than comparing the vectors bit by bit.
    /// Choose the value of `ZERO` depending on which bits are more sparse.
    ///
    /// This method is faster than [`full_equals`] for sparse vectors beginning at roughly 1
    /// million bits. Above 4 million bits, this method becomes faster than full equality in general.
    ///
    /// # Parameters
    /// - `other`: The other `RsVec` to compare to.
    /// - `ZERO`: Whether to compare the sparse 0-bits (true) or the sparse 1-bits (false).
    ///
    /// # Returns
    /// `true` if the vectors' contents are equal, `false` otherwise.
    ///
    /// [`full_equals`]: RsVec::full_equals
    #[must_use]
    pub fn sparse_equals<const ZERO: bool>(&self, other: &Self) -> bool {
        if self.len() != other.len() {
            return false;
        }

        // fast pre-check: equal contents imply equal totals of 0- and 1-bits
        if self.rank0 != other.rank0 || self.rank1 != other.rank1 {
            return false;
        }

        // iterate the positions of the sparse bit kind in `self`
        // (presumably SelectIter<ZERO> yields the positions of ZERO-bits in order —
        // see the `iter` submodule)
        let iter: SelectIter<ZERO> = self.select_iter();

        for (rank, bit_index) in iter.enumerate() {
            // since rank is inlined, we get dead code elimination depending on ZERO
            // `other` must have the same bit value at this position, and the same number of
            // such bits before it — together with the equal totals this implies equality
            if (other.get_unchecked(bit_index) == 0) != ZERO || other.rank(ZERO, bit_index) != rank
            {
                return false;
            }
        }

        true
    }
457
458    /// Check if two `RsVec`s are equal. This compares limb by limb. This is usually faster than a
459    /// [`sparse_equals`] call for small vectors.
460    ///
461    /// # Parameters
462    /// - `other`: The other `RsVec` to compare to.
463    ///
464    /// # Returns
465    /// `true` if the vectors' contents are equal, `false` otherwise.
466    ///
467    /// [`sparse_equals`]: RsVec::sparse_equals
468    #[must_use]
469    pub fn full_equals(&self, other: &Self) -> bool {
470        if self.len() != other.len() {
471            return false;
472        }
473
474        if self.rank0 != other.rank0 || self.rank1 != other.rank1 {
475            return false;
476        }
477
478        if self.data[..self.len / 64]
479            .iter()
480            .zip(other.data[..other.len / 64].iter())
481            .any(|(a, b)| a != b)
482        {
483            return false;
484        }
485
486        // if last incomplete block exists, test it without junk data
487        if !self.len.is_multiple_of(WORD_SIZE)
488            && self.data[self.len / WORD_SIZE] & ((1 << (self.len % WORD_SIZE)) - 1)
489                != other.data[self.len / WORD_SIZE] & ((1 << (other.len % WORD_SIZE)) - 1)
490        {
491            return false;
492        }
493
494        true
495    }
496
497    /// Returns the number of bytes used on the heap for this vector. This does not include
498    /// allocated space that is not used (e.g. by the allocation behavior of `Vec`).
499    #[must_use]
500    pub fn heap_size(&self) -> usize {
501        self.data.len() * size_of::<u64>()
502            + self.blocks.len() * size_of::<BlockDescriptor>()
503            + self.super_blocks.len() * size_of::<SuperBlockDescriptor>()
504            + self.select_blocks.len() * size_of::<SelectSuperBlockDescriptor>()
505    }
506}
507
508impl_vector_iterator! { RsVec, RsVecIter, RsVecRefIter }
509
510impl PartialEq for RsVec {
511    /// Check if two `RsVec`s are equal. This method calls [`sparse_equals`] if the vector has more
512    /// than 4'000'000 bits, and [`full_equals`] otherwise.
513    ///
514    /// This was determined with benchmarks on an `x86_64` machine,
515    /// on which [`sparse_equals`] outperforms [`full_equals`] consistently above this threshold.
516    ///
517    /// # Parameters
518    /// - `other`: The other `RsVec` to compare to.
519    ///
520    /// # Returns
521    /// `true` if the vectors' contents are equal, `false` otherwise.
522    ///
523    /// [`sparse_equals`]: RsVec::sparse_equals
524    /// [`full_equals`]: RsVec::full_equals
525    fn eq(&self, other: &Self) -> bool {
526        if self.len > 4_000_000 {
527            if self.rank1 > self.rank0 {
528                self.sparse_equals::<true>(other)
529            } else {
530                self.sparse_equals::<false>(other)
531            }
532        } else {
533            self.full_equals(other)
534        }
535    }
536}
537
538impl From<BitVec> for RsVec {
539    /// Build an [`RsVec`] from a [`BitVec`]. This will consume the [`BitVec`]. Since [`RsVec`]s are
540    /// immutable, this is the only way to construct an [`RsVec`].
541    ///
542    /// # Example
543    /// See the example for [`RsVec`].
544    ///
545    /// [`BitVec`]: BitVec
546    /// [`RsVec`]: RsVec
547    fn from(vec: BitVec) -> Self {
548        RsVec::from_bit_vec(vec)
549    }
550}
551
impl From<RsVec> for BitVec {
    /// Convert an [`RsVec`] back into a [`BitVec`], consuming it and discarding all
    /// rank/select meta-data. Equivalent to calling [`RsVec::into_bit_vec`].
    fn from(value: RsVec) -> Self {
        value.into_bit_vec()
    }
}
557
558// iter code in here to keep it more organized
559mod iter;
560// select code in here to keep it more organized
561mod select;
562
563#[cfg(all(
564    feature = "simd",
565    target_arch = "x86_64",
566    target_feature = "avx",
567    target_feature = "avx2",
568    target_feature = "avx512f",
569    target_feature = "avx512bw",
570))]
571mod bitset;
572
573#[cfg(test)]
574mod tests;