vortex_array/pipeline/bits/view_mut.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use bitvec::array::BitArray;
5use bitvec::order::Lsb0;
6
7use crate::pipeline::bits::BitView;
8use crate::pipeline::{N, N_WORDS};
9
10/// A mutable borrowed fixed-size bit vector of length `N` bits, represented as an array of
11/// usize words.
12/// Mutable view into a bit array for constructing selection masks.
13#[derive(Debug)]
14pub struct BitViewMut<'a> {
15    bits: &'a mut BitArray<[usize; N_WORDS], Lsb0>,
16    true_count: usize,
17}
18
19impl<'a> BitViewMut<'a> {
20    pub fn new(bits: &'a mut [usize; N_WORDS]) -> Self {
21        let true_count = bits.iter().map(|&word| word.count_ones() as usize).sum();
22        let bits: &mut BitArray<[usize; N_WORDS], Lsb0> = unsafe { std::mem::transmute(bits) };
23        BitViewMut { bits, true_count }
24    }
25
26    pub(crate) unsafe fn new_unchecked(
27        bits: &'a mut BitArray<[usize; N_WORDS], Lsb0>,
28        true_count: usize,
29    ) -> Self {
30        BitViewMut { bits, true_count }
31    }
32
33    pub fn true_count(&self) -> usize {
34        self.true_count
35    }
36
37    /// Mask the values in the mask up to the given length.
38    pub fn intersect_prefix(&mut self, mut len: usize) {
39        assert!(len <= N, "BitViewMut::truncate: length exceeds N");
40
41        let bit_slice = self.bits.as_raw_mut_slice();
42
43        let mut word = 0;
44        let mut true_count = 0;
45        while len >= usize::BITS as usize {
46            true_count += bit_slice[word].count_ones() as usize;
47            len -= usize::BITS as usize;
48            word += 1;
49        }
50
51        if len > 0 {
52            bit_slice[word] &= !(usize::MAX << len);
53            true_count += bit_slice[word].count_ones() as usize;
54            word += 1;
55        }
56
57        while word < N_WORDS {
58            bit_slice[word] = 0;
59            word += 1;
60        }
61
62        self.set_true_count(true_count);
63    }
64
65    pub fn clear(&mut self) {
66        self.bits.as_raw_mut_slice().fill(0);
67        self.set_true_count(0);
68    }
69
70    pub fn fill_with_words(&mut self, mut iter: impl Iterator<Item = u64>) {
71        let mut true_count = 0;
72
73        let dst_bytes = unsafe {
74            std::slice::from_raw_parts_mut(
75                self.bits.as_raw_mut_slice().as_mut_ptr() as *mut u64,
76                N_WORDS,
77            )
78        };
79
80        for word in 0..N / 64 {
81            if let Some(value) = iter.next() {
82                dst_bytes[word] = value;
83                true_count += value.count_ones() as usize;
84            }
85        }
86        self.set_true_count(true_count);
87    }
88
89    pub fn as_view(&self) -> BitView<'_> {
90        unsafe { BitView::new_unchecked(self.bits, self.true_count) }
91    }
92
93    pub fn as_raw_mut(&mut self) -> &mut [usize; N_WORDS] {
94        unsafe { std::mem::transmute(&mut self.bits) }
95    }
96
97    #[inline(always)]
98    fn set_true_count(&mut self, true_count: usize) {
99        self.true_count = true_count;
100        debug_assert_eq!(
101            self.true_count,
102            self.bits
103                .as_raw_slice()
104                .iter()
105                .map(|&word| word.count_ones() as usize)
106                .sum::<usize>()
107        );
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::bits::BitVector;

    #[test]
    fn test_intersect_prefix() {
        let mut bit_vec = BitVector::full().clone();

        let mut view_mut = bit_vec.as_view_mut();
        assert_eq!(view_mut.true_count(), N);

        view_mut.intersect_prefix(N - 1);
        assert_eq!(view_mut.true_count(), N - 1);

        view_mut.intersect_prefix(64);
        assert_eq!(view_mut.true_count(), 64);

        view_mut.intersect_prefix(10);
        assert_eq!(view_mut.true_count(), 10);

        view_mut.intersect_prefix(0);
        assert_eq!(view_mut.true_count(), 0);
    }

    #[test]
    fn test_intersect_prefix_full_length_is_noop() {
        let mut bit_vec = BitVector::full().clone();
        let mut view_mut = bit_vec.as_view_mut();

        view_mut.intersect_prefix(N);
        assert_eq!(view_mut.true_count(), N);
    }

    #[test]
    fn test_intersect_prefix_unaligned() {
        let mut bit_vec = BitVector::full().clone();
        let mut view_mut = bit_vec.as_view_mut();

        // A prefix that ends one bit into the second storage word.
        view_mut.intersect_prefix(65);
        assert_eq!(view_mut.true_count(), 65);

        // Widening the prefix afterwards cannot bring cleared bits back.
        view_mut.intersect_prefix(N);
        assert_eq!(view_mut.true_count(), 65);
    }

    #[test]
    fn test_clear() {
        let mut bit_vec = BitVector::full().clone();
        let mut view_mut = bit_vec.as_view_mut();

        view_mut.clear();
        assert_eq!(view_mut.true_count(), 0);
    }

    #[test]
    fn test_fill_with_words() {
        let mut bit_vec = BitVector::full().clone();
        let mut view_mut = bit_vec.as_view_mut();

        // Every input word has 32 of its 64 bits set, so a fully filled view is half set.
        view_mut.fill_with_words(std::iter::repeat(0x5555_5555_5555_5555u64));
        assert_eq!(view_mut.true_count(), N / 2);
    }
}