vers_vecs/bit_vec/
mask.rs

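//! This module defines a struct for lazily masking a [`BitVec`]. It offers the immutable query
//! operations of `BitVec` (single-bit and multi-bit access, population counts, and conversion
//! back into a `BitVec`), but combines each vector limb with the matching mask limb only when an
//! operation reads it. The struct is created through [`BitVec::mask_xor`], [`BitVec::mask_and`],
//! [`BitVec::mask_or`], or [`BitVec::mask_custom`].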
4
5use super::WORD_SIZE;
6use crate::BitVec;
7
8/// A bit vector that is masked with another bit vector via a masking function. Offers the same
9/// functions as an unmasked vector. The mask is applied lazily.
10#[derive(Debug, Clone)]
11pub struct MaskedBitVec<'a, 'b, F: Fn(u64, u64) -> u64> {
12    vec: &'a BitVec,
13    mask: &'b BitVec,
14    bin_op: F,
15}
16
17impl<'a, 'b, F> MaskedBitVec<'a, 'b, F>
18where
19    F: Fn(u64, u64) -> u64,
20{
21    #[inline]
22    pub(crate) fn new(vec: &'a BitVec, mask: &'b BitVec, bin_op: F) -> Result<Self, String> {
23        if vec.len != mask.len {
24            return Err(String::from(
25                "mask cannot have different length than vector",
26            ));
27        }
28
29        Ok(MaskedBitVec { vec, mask, bin_op })
30    }
31
32    /// Iterate over the limbs of the masked vector
33    #[inline]
34    fn iter_limbs<'s>(&'s self) -> impl Iterator<Item = u64> + 's
35    where
36        'a: 's,
37        'b: 's,
38    {
39        self.vec
40            .data
41            .iter()
42            .zip(&self.mask.data)
43            .map(|(&a, &b)| (self.bin_op)(a, b))
44    }
45
46    /// Return the bit at the given position.
47    /// The bit takes the least significant bit of the returned u64 word.
48    /// If the position is larger than the length of the vector, None is returned.
49    #[inline]
50    #[must_use]
51    pub fn get(&self, pos: usize) -> Option<u64> {
52        if pos >= self.vec.len {
53            None
54        } else {
55            Some(self.get_unchecked(pos))
56        }
57    }
58
59    /// Return the bit at the given position.
60    /// The bit takes the least significant bit of the returned u64 word.
61    ///
62    /// # Panics
63    /// If the position is larger than the length of the vector,
64    /// the function will either return unpredictable data, or panic.
65    /// Use [`get`] to properly handle this case with an `Option`.
66    ///
67    /// [`get`]: MaskedBitVec::get
68    #[inline]
69    #[must_use]
70    pub fn get_unchecked(&self, pos: usize) -> u64 {
71        ((self.bin_op)(
72            self.vec.data[pos / WORD_SIZE],
73            self.mask.data[pos / WORD_SIZE],
74        ) >> (pos % WORD_SIZE))
75            & 1
76    }
77
78    /// Return whether the bit at the given position is set.
79    /// If the position is larger than the length of the vector, None is returned.
80    #[inline]
81    #[must_use]
82    pub fn is_bit_set(&self, pos: usize) -> Option<bool> {
83        if pos >= self.vec.len {
84            None
85        } else {
86            Some(self.is_bit_set_unchecked(pos))
87        }
88    }
89
90    /// Return whether the bit at the given position is set.
91    ///
92    /// # Panics
93    /// If the position is larger than the length of the vector,
94    /// the function will either return unpredictable data, or panic.
95    /// Use [`is_bit_set`] to properly handle this case with an `Option`.
96    ///
97    /// [`is_bit_set`]: MaskedBitVec::is_bit_set
98    #[inline]
99    #[must_use]
100    pub fn is_bit_set_unchecked(&self, pos: usize) -> bool {
101        self.get_unchecked(pos) != 0
102    }
103
104    /// Return multiple bits at the given position. The number of bits to return is given by `len`.
105    /// At most 64 bits can be returned.
106    /// If the position at the end of the query is larger than the length of the vector,
107    /// None is returned (even if the query partially overlaps with the vector).
108    /// If the length of the query is larger than 64, None is returned.
109    #[inline]
110    #[must_use]
111    pub fn get_bits(&self, pos: usize, len: usize) -> Option<u64> {
112        if len > WORD_SIZE || len == 0 {
113            return None;
114        }
115        if pos + len > self.vec.len {
116            None
117        } else {
118            Some(self.get_bits_unchecked(pos, len))
119        }
120    }
121
122    /// Return multiple bits at the given position. The number of bits to return is given by `len`.
123    /// At most 64 bits can be returned.
124    ///
125    /// This function is always inlined, because it gains a lot from loop optimization and
126    /// can utilize the processor pre-fetcher better if it is.
127    ///
128    /// # Errors
129    /// If the length of the query is larger than 64, unpredictable data will be returned.
130    /// Use [`get_bits`] to avoid this.
131    ///
132    /// # Panics
133    /// If the position or interval is larger than the length of the vector,
134    /// the function will either return any valid results padded with unpredictable
135    /// data or panic.
136    ///
137    /// [`get_bits`]: MaskedBitVec::get_bits
138    #[must_use]
139    #[allow(clippy::inline_always)]
140    #[allow(clippy::comparison_chain)] // rust-clippy #5354
141    #[inline]
142    pub fn get_bits_unchecked(&self, pos: usize, len: usize) -> u64 {
143        debug_assert!(len <= WORD_SIZE);
144        let partial_word = (self.bin_op)(
145            self.vec.data[pos / WORD_SIZE],
146            self.mask.data[pos / WORD_SIZE],
147        ) >> (pos % WORD_SIZE);
148
149        if pos % WORD_SIZE + len == WORD_SIZE {
150            partial_word
151        } else if pos % WORD_SIZE + len < WORD_SIZE {
152            partial_word & ((1 << (len % WORD_SIZE)) - 1)
153        } else {
154            let next_half = (self.bin_op)(
155                self.vec.data[pos / WORD_SIZE + 1],
156                self.mask.data[pos / WORD_SIZE + 1],
157            ) << (WORD_SIZE - pos % WORD_SIZE);
158
159            (partial_word | next_half) & ((1 << (len % WORD_SIZE)) - 1)
160        }
161    }
162
163    /// Return the number of zeros in the masked bit vector.
164    /// This method calls [`count_ones`].
165    ///
166    /// [`count_ones`]: MaskedBitVec::count_ones
167    #[inline]
168    #[must_use]
169    pub fn count_zeros(&self) -> u64 {
170        self.vec.len as u64 - self.count_ones()
171    }
172
173    /// Return the number of ones in the masked bit vector.
174    #[inline]
175    #[must_use]
176    #[allow(clippy::missing_panics_doc)] // can't panic because of bounds check
177    pub fn count_ones(&self) -> u64 {
178        let mut ones = self
179            .iter_limbs()
180            .take(self.vec.len / WORD_SIZE)
181            .map(|limb| u64::from(limb.count_ones()))
182            .sum();
183        if !self.vec.len.is_multiple_of(WORD_SIZE) {
184            ones += u64::from(
185                ((self.bin_op)(
186                    *self.vec.data.last().unwrap(),
187                    *self.mask.data.last().unwrap(),
188                ) & ((1 << (self.vec.len % WORD_SIZE)) - 1))
189                    .count_ones(),
190            );
191        }
192        ones
193    }
194
195    /// Collect the masked [`BitVec`] into a new `BitVec` by applying the mask to all bits.
196    #[inline]
197    #[must_use]
198    pub fn to_bit_vec(&self) -> BitVec {
199        BitVec {
200            data: self.iter_limbs().collect(),
201            len: self.vec.len,
202        }
203    }
204}