vers_vecs/bit_vec/mask.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
//! This module defines a struct for lazily masking [`BitVec`]. It offers all immutable operations
//! of `BitVec` but applies a bit-mask during the operation. The struct is created through
//! [`BitVec::mask_xor`], [`BitVec::mask_and`], [`BitVec::mask_or`], or [`BitVec::mask_custom`].
use super::WORD_SIZE;
use crate::BitVec;
/// A lazily masked view of a [`BitVec`]. Every read combines a limb of the underlying vector
/// with the matching limb of the mask through a binary function, so the struct offers the same
/// immutable functions as an unmasked vector without materializing the masked result.
#[derive(Debug, Clone)]
pub struct MaskedBitVec<'a, 'b, F>
where
    F: Fn(u64, u64) -> u64,
{
    /// The vector being masked; its limbs are passed as the first argument of `bin_op`.
    vec: &'a BitVec,
    /// The mask; its limbs are passed as the second argument of `bin_op`.
    mask: &'b BitVec,
    /// The masking function applied to one limb of `vec` and one limb of `mask`.
    bin_op: F,
}
impl<'a, 'b, F> MaskedBitVec<'a, 'b, F>
where
F: Fn(u64, u64) -> u64,
{
#[inline]
pub(crate) fn new(vec: &'a BitVec, mask: &'b BitVec, bin_op: F) -> Result<Self, String> {
if vec.len != mask.len {
return Err(String::from(
"mask cannot have different length than vector",
));
}
Ok(MaskedBitVec { vec, mask, bin_op })
}
/// Iterate over the limbs of the masked vector. Each yielded limb is the masking function
/// applied to one limb of the vector and the limb of the mask at the same index.
#[inline]
fn iter_limbs<'s>(&'s self) -> impl Iterator<Item = u64> + 's
where
    'a: 's,
    'b: 's,
{
    let combine = &self.bin_op;
    let vec_limbs = self.vec.data.iter();
    let mask_limbs = self.mask.data.iter();
    vec_limbs
        .zip(mask_limbs)
        .map(move |(&limb, &mask_limb)| combine(limb, mask_limb))
}
/// Return the masked bit at the given position.
/// The bit takes the least significant bit of the returned u64 word.
/// If the position is larger than or equal to the length of the vector, None is returned.
#[inline]
#[must_use]
pub fn get(&self, pos: usize) -> Option<u64> {
    if pos >= self.vec.len {
        return None;
    }
    Some(self.get_unchecked(pos))
}
/// Return the masked bit at the given position without checking the position against the
/// length of the vector.
/// The bit takes the least significant bit of the returned u64 word.
///
/// # Panics
/// If the position is larger than the length of the vector,
/// the function will either return unpredictable data, or panic.
/// Use [`get`] to properly handle this case with an `Option`.
///
/// [`get`]: MaskedBitVec::get
#[inline]
#[must_use]
pub fn get_unchecked(&self, pos: usize) -> u64 {
    let limb_index = pos / WORD_SIZE;
    let bit_offset = pos % WORD_SIZE;
    // Mask the whole limb first, then move the requested bit into the lowest position.
    let masked_limb = (self.bin_op)(self.vec.data[limb_index], self.mask.data[limb_index]);
    (masked_limb >> bit_offset) & 1
}
/// Return whether the masked bit at the given position is set.
/// If the position is larger than or equal to the length of the vector, None is returned.
#[inline]
#[must_use]
pub fn is_bit_set(&self, pos: usize) -> Option<bool> {
    (pos < self.vec.len).then(|| self.is_bit_set_unchecked(pos))
}
/// Return whether the masked bit at the given position is set, without checking the position
/// against the length of the vector.
///
/// # Panics
/// If the position is larger than the length of the vector,
/// the function will either return unpredictable data, or panic.
/// Use [`is_bit_set`] to properly handle this case with an `Option`.
///
/// [`is_bit_set`]: MaskedBitVec::is_bit_set
#[inline]
#[must_use]
pub fn is_bit_set_unchecked(&self, pos: usize) -> bool {
    // `get_unchecked` only ever yields 0 or 1, so comparing against 1 is exhaustive.
    self.get_unchecked(pos) == 1
}
/// Return multiple bits at the given position. The number of bits to return is given by `len`.
/// At most 64 bits can be returned.
/// If the position at the end of the query is larger than the length of the vector,
/// None is returned (even if the query partially overlaps with the vector).
/// If the length of the query is zero or larger than 64, None is returned.
/// A query whose end position `pos + len` does not fit into a `usize` also yields None.
#[inline]
#[must_use]
pub fn get_bits(&self, pos: usize, len: usize) -> Option<u64> {
    if len == 0 || len > WORD_SIZE {
        return None;
    }
    // `checked_add` keeps a huge `pos` from overflowing: a plain `pos + len` would panic in
    // debug builds and wrap around in release builds, letting the out-of-range query slip
    // past the bounds check.
    match pos.checked_add(len) {
        Some(end) if end <= self.vec.len => Some(self.get_bits_unchecked(pos, len)),
        _ => None,
    }
}
/// Return multiple bits at the given position. The number of bits to return is given by `len`.
/// At most 64 bits can be returned. The bit at `pos` takes the least significant bit of the
/// returned word; all bits above the requested `len` bits are zero.
///
/// This function is marked `#[inline]`, because it gains a lot from loop optimization and
/// can utilize the processor pre-fetcher better if it is inlined.
///
/// # Errors
/// If the length of the query is larger than 64, unpredictable data will be returned.
/// Use [`get_bits`] to avoid this.
///
/// # Panics
/// If the position or interval is larger than the length of the vector,
/// the function will either return any valid results padded with unpredictable
/// data or panic.
///
/// [`get_bits`]: MaskedBitVec::get_bits
#[must_use]
#[inline]
pub fn get_bits_unchecked(&self, pos: usize, len: usize) -> u64 {
    debug_assert!(len <= WORD_SIZE);
    let limb_index = pos / WORD_SIZE;
    let bit_offset = pos % WORD_SIZE;

    // Bits of the first limb, moved down so that the bit at `pos` is the least significant one.
    let mut bits =
        (self.bin_op)(self.vec.data[limb_index], self.mask.data[limb_index]) >> bit_offset;

    // The query crosses a limb boundary: the remaining high bits come from the next limb.
    // `bit_offset` is non-zero here for any `len <= WORD_SIZE`, so the shift stays below 64.
    if bit_offset + len > WORD_SIZE {
        let next_limb = (self.bin_op)(
            self.vec.data[limb_index + 1],
            self.mask.data[limb_index + 1],
        );
        bits |= next_limb << (WORD_SIZE - bit_offset);
    }

    // Keep only the lowest `len` bits. A full 64-bit query needs the all-ones mask spelled
    // out: reducing `len` modulo the word size would produce an empty mask and zero the
    // result of an unaligned 64-bit read.
    let keep = if len >= WORD_SIZE {
        u64::MAX
    } else {
        (1_u64 << len) - 1
    };
    bits & keep
}
/// Return the number of zeros in the masked bit vector, computed as the length of the
/// vector minus the number of ones.
/// This method calls [`count_ones`].
///
/// [`count_ones`]: MaskedBitVec::count_ones
#[inline]
#[must_use]
pub fn count_zeros(&self) -> u64 {
    let total_bits = self.vec.len as u64;
    let set_bits = self.count_ones();
    total_bits - set_bits
}
/// Return the number of ones in the masked bit vector.
/// Completely used limbs are counted whole; of a partially used trailing limb only the bits
/// below the length of the vector are counted.
#[inline]
#[must_use]
#[allow(clippy::missing_panics_doc)] // can't panic because of bounds check
pub fn count_ones(&self) -> u64 {
    let full_limbs = self.vec.len / WORD_SIZE;
    let tail_bits = self.vec.len % WORD_SIZE;

    let ones_in_full_limbs: u64 = self
        .iter_limbs()
        .take(full_limbs)
        .map(|limb| u64::from(limb.count_ones()))
        .sum();

    if tail_bits == 0 {
        return ones_in_full_limbs;
    }

    // The trailing limb is only partially used: mask away everything above the vector length
    // before counting, since the masking function may have set those bits.
    let last_limb = (self.bin_op)(
        *self.vec.data.last().unwrap(),
        *self.mask.data.last().unwrap(),
    );
    let used_bits = last_limb & ((1 << tail_bits) - 1);
    ones_in_full_limbs + u64::from(used_bits.count_ones())
}
/// Collect the masked [`BitVec`] into a new `BitVec` by applying the mask to every limb.
/// The new vector has the same length as the underlying vector.
#[inline]
#[must_use]
pub fn to_bit_vec(&self) -> BitVec {
    let data = self.iter_limbs().collect();
    let len = self.vec.len;
    BitVec { data, len }
}
}