vers_vecs/bit_vec/mask.rs
//! This module defines a struct for lazily masking [`BitVec`]. It offers all immutable operations
//! of `BitVec` but applies a bit-mask during the operation. The struct is created through
//! [`BitVec::mask_xor`], [`BitVec::mask_and`], [`BitVec::mask_or`], or [`BitVec::mask_custom`].
4
5use super::WORD_SIZE;
6use crate::BitVec;
7
8/// A bit vector that is masked with another bit vector via a masking function. Offers the same
9/// functions as an unmasked vector. The mask is applied lazily.
10#[derive(Debug, Clone)]
11pub struct MaskedBitVec<'a, 'b, F: Fn(u64, u64) -> u64> {
12 vec: &'a BitVec,
13 mask: &'b BitVec,
14 bin_op: F,
15}
16
17impl<'a, 'b, F> MaskedBitVec<'a, 'b, F>
18where
19 F: Fn(u64, u64) -> u64,
20{
21 #[inline]
22 pub(crate) fn new(vec: &'a BitVec, mask: &'b BitVec, bin_op: F) -> Result<Self, String> {
23 if vec.len != mask.len {
24 return Err(String::from(
25 "mask cannot have different length than vector",
26 ));
27 }
28
29 Ok(MaskedBitVec { vec, mask, bin_op })
30 }
31
32 /// Iterate over the limbs of the masked vector
33 #[inline]
34 fn iter_limbs<'s>(&'s self) -> impl Iterator<Item = u64> + 's
35 where
36 'a: 's,
37 'b: 's,
38 {
39 self.vec
40 .data
41 .iter()
42 .zip(&self.mask.data)
43 .map(|(&a, &b)| (self.bin_op)(a, b))
44 }
45
46 /// Return the bit at the given position.
47 /// The bit takes the least significant bit of the returned u64 word.
48 /// If the position is larger than the length of the vector, None is returned.
49 #[inline]
50 #[must_use]
51 pub fn get(&self, pos: usize) -> Option<u64> {
52 if pos >= self.vec.len {
53 None
54 } else {
55 Some(self.get_unchecked(pos))
56 }
57 }
58
59 /// Return the bit at the given position.
60 /// The bit takes the least significant bit of the returned u64 word.
61 ///
62 /// # Panics
63 /// If the position is larger than the length of the vector,
64 /// the function will either return unpredictable data, or panic.
65 /// Use [`get`] to properly handle this case with an `Option`.
66 ///
67 /// [`get`]: MaskedBitVec::get
68 #[inline]
69 #[must_use]
70 pub fn get_unchecked(&self, pos: usize) -> u64 {
71 ((self.bin_op)(
72 self.vec.data[pos / WORD_SIZE],
73 self.mask.data[pos / WORD_SIZE],
74 ) >> (pos % WORD_SIZE))
75 & 1
76 }
77
78 /// Return whether the bit at the given position is set.
79 /// If the position is larger than the length of the vector, None is returned.
80 #[inline]
81 #[must_use]
82 pub fn is_bit_set(&self, pos: usize) -> Option<bool> {
83 if pos >= self.vec.len {
84 None
85 } else {
86 Some(self.is_bit_set_unchecked(pos))
87 }
88 }
89
90 /// Return whether the bit at the given position is set.
91 ///
92 /// # Panics
93 /// If the position is larger than the length of the vector,
94 /// the function will either return unpredictable data, or panic.
95 /// Use [`is_bit_set`] to properly handle this case with an `Option`.
96 ///
97 /// [`is_bit_set`]: MaskedBitVec::is_bit_set
98 #[inline]
99 #[must_use]
100 pub fn is_bit_set_unchecked(&self, pos: usize) -> bool {
101 self.get_unchecked(pos) != 0
102 }
103
104 /// Return multiple bits at the given position. The number of bits to return is given by `len`.
105 /// At most 64 bits can be returned.
106 /// If the position at the end of the query is larger than the length of the vector,
107 /// None is returned (even if the query partially overlaps with the vector).
108 /// If the length of the query is larger than 64, None is returned.
109 #[inline]
110 #[must_use]
111 pub fn get_bits(&self, pos: usize, len: usize) -> Option<u64> {
112 if len > WORD_SIZE || len == 0 {
113 return None;
114 }
115 if pos + len > self.vec.len {
116 None
117 } else {
118 Some(self.get_bits_unchecked(pos, len))
119 }
120 }
121
122 /// Return multiple bits at the given position. The number of bits to return is given by `len`.
123 /// At most 64 bits can be returned.
124 ///
125 /// This function is always inlined, because it gains a lot from loop optimization and
126 /// can utilize the processor pre-fetcher better if it is.
127 ///
128 /// # Errors
129 /// If the length of the query is larger than 64, unpredictable data will be returned.
130 /// Use [`get_bits`] to avoid this.
131 ///
132 /// # Panics
133 /// If the position or interval is larger than the length of the vector,
134 /// the function will either return any valid results padded with unpredictable
135 /// data or panic.
136 ///
137 /// [`get_bits`]: MaskedBitVec::get_bits
138 #[must_use]
139 #[allow(clippy::inline_always)]
140 #[allow(clippy::comparison_chain)] // rust-clippy #5354
141 #[inline]
142 pub fn get_bits_unchecked(&self, pos: usize, len: usize) -> u64 {
143 debug_assert!(len <= WORD_SIZE);
144 let partial_word = (self.bin_op)(
145 self.vec.data[pos / WORD_SIZE],
146 self.mask.data[pos / WORD_SIZE],
147 ) >> (pos % WORD_SIZE);
148
149 if pos % WORD_SIZE + len == WORD_SIZE {
150 partial_word
151 } else if pos % WORD_SIZE + len < WORD_SIZE {
152 partial_word & ((1 << (len % WORD_SIZE)) - 1)
153 } else {
154 let next_half = (self.bin_op)(
155 self.vec.data[pos / WORD_SIZE + 1],
156 self.mask.data[pos / WORD_SIZE + 1],
157 ) << (WORD_SIZE - pos % WORD_SIZE);
158
159 (partial_word | next_half) & ((1 << (len % WORD_SIZE)) - 1)
160 }
161 }
162
163 /// Return the number of zeros in the masked bit vector.
164 /// This method calls [`count_ones`].
165 ///
166 /// [`count_ones`]: MaskedBitVec::count_ones
167 #[inline]
168 #[must_use]
169 pub fn count_zeros(&self) -> u64 {
170 self.vec.len as u64 - self.count_ones()
171 }
172
173 /// Return the number of ones in the masked bit vector.
174 #[inline]
175 #[must_use]
176 #[allow(clippy::missing_panics_doc)] // can't panic because of bounds check
177 pub fn count_ones(&self) -> u64 {
178 let mut ones = self
179 .iter_limbs()
180 .take(self.vec.len / WORD_SIZE)
181 .map(|limb| u64::from(limb.count_ones()))
182 .sum();
183 if !self.vec.len.is_multiple_of(WORD_SIZE) {
184 ones += u64::from(
185 ((self.bin_op)(
186 *self.vec.data.last().unwrap(),
187 *self.mask.data.last().unwrap(),
188 ) & ((1 << (self.vec.len % WORD_SIZE)) - 1))
189 .count_ones(),
190 );
191 }
192 ones
193 }
194
195 /// Collect the masked [`BitVec`] into a new `BitVec` by applying the mask to all bits.
196 #[inline]
197 #[must_use]
198 pub fn to_bit_vec(&self) -> BitVec {
199 BitVec {
200 data: self.iter_limbs().collect(),
201 len: self.vec.len,
202 }
203 }
204}