vortex_array/pipeline/bits/
view_mut.rs1use bitvec::array::BitArray;
5use bitvec::order::Lsb0;
6
7use crate::pipeline::bits::BitView;
8use crate::pipeline::{N, N_WORDS};
9
10#[derive(Debug)]
14pub struct BitViewMut<'a> {
15 bits: &'a mut BitArray<[usize; N_WORDS], Lsb0>,
16 true_count: usize,
17}
18
19impl<'a> BitViewMut<'a> {
20 pub fn new(bits: &'a mut [usize; N_WORDS]) -> Self {
21 let true_count = bits.iter().map(|&word| word.count_ones() as usize).sum();
22 let bits: &mut BitArray<[usize; N_WORDS], Lsb0> = unsafe { std::mem::transmute(bits) };
23 BitViewMut { bits, true_count }
24 }
25
26 pub(crate) unsafe fn new_unchecked(
27 bits: &'a mut BitArray<[usize; N_WORDS], Lsb0>,
28 true_count: usize,
29 ) -> Self {
30 BitViewMut { bits, true_count }
31 }
32
33 pub fn true_count(&self) -> usize {
34 self.true_count
35 }
36
37 pub fn intersect_prefix(&mut self, mut len: usize) {
39 assert!(len <= N, "BitViewMut::truncate: length exceeds N");
40
41 let bit_slice = self.bits.as_raw_mut_slice();
42
43 let mut word = 0;
44 let mut true_count = 0;
45 while len >= usize::BITS as usize {
46 true_count += bit_slice[word].count_ones() as usize;
47 len -= usize::BITS as usize;
48 word += 1;
49 }
50
51 if len > 0 {
52 bit_slice[word] &= !(usize::MAX << len);
53 true_count += bit_slice[word].count_ones() as usize;
54 word += 1;
55 }
56
57 while word < N_WORDS {
58 bit_slice[word] = 0;
59 word += 1;
60 }
61
62 self.set_true_count(true_count);
63 }
64
65 pub fn clear(&mut self) {
66 self.bits.as_raw_mut_slice().fill(0);
67 self.set_true_count(0);
68 }
69
70 pub fn fill_with_words(&mut self, mut iter: impl Iterator<Item = u64>) {
71 let mut true_count = 0;
72
73 let dst_bytes = unsafe {
74 std::slice::from_raw_parts_mut(
75 self.bits.as_raw_mut_slice().as_mut_ptr() as *mut u64,
76 N_WORDS,
77 )
78 };
79
80 for word in 0..N / 64 {
81 if let Some(value) = iter.next() {
82 dst_bytes[word] = value;
83 true_count += value.count_ones() as usize;
84 }
85 }
86 self.set_true_count(true_count);
87 }
88
89 pub fn as_view(&self) -> BitView<'_> {
90 unsafe { BitView::new_unchecked(self.bits, self.true_count) }
91 }
92
93 pub fn as_raw_mut(&mut self) -> &mut [usize; N_WORDS] {
94 unsafe { std::mem::transmute(&mut self.bits) }
95 }
96
97 #[inline(always)]
98 fn set_true_count(&mut self, true_count: usize) {
99 self.true_count = true_count;
100 debug_assert_eq!(
101 self.true_count,
102 self.bits
103 .as_raw_slice()
104 .iter()
105 .map(|&word| word.count_ones() as usize)
106 .sum::<usize>()
107 );
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::bits::BitVector;

    /// Total number of bits held by the backing word array.
    const TOTAL_BITS: usize = N_WORDS * usize::BITS as usize;

    #[test]
    fn test_intersect_prefix() {
        let mut bit_vec = BitVector::full().clone();

        let mut view_mut = bit_vec.as_view_mut();
        assert_eq!(view_mut.true_count(), N);

        view_mut.intersect_prefix(N - 1);
        assert_eq!(view_mut.true_count(), N - 1);

        view_mut.intersect_prefix(64);
        assert_eq!(view_mut.true_count(), 64);

        view_mut.intersect_prefix(10);
        assert_eq!(view_mut.true_count(), 10);

        view_mut.intersect_prefix(0);
        assert_eq!(view_mut.true_count(), 0);
    }

    /// Checks the actual bits, not just the cached count, across a word boundary.
    #[test]
    fn test_intersect_prefix_bits() {
        let word_bits = usize::BITS as usize;
        let mut words = [usize::MAX; N_WORDS];

        let mut view_mut = BitViewMut::new(&mut words);
        view_mut.intersect_prefix(word_bits + 6);
        assert_eq!(view_mut.true_count(), word_bits + 6);

        assert_eq!(words[0], usize::MAX);
        assert_eq!(words[1], 0b11_1111);
        assert!(words[2..].iter().all(|&word| word == 0));
    }

    #[test]
    fn test_clear() {
        let mut words = [usize::MAX; N_WORDS];

        let mut view_mut = BitViewMut::new(&mut words);
        assert_eq!(view_mut.true_count(), TOTAL_BITS);

        view_mut.clear();
        assert_eq!(view_mut.true_count(), 0);

        assert!(words.iter().all(|&word| word == 0));
    }

    #[test]
    fn test_fill_with_words() {
        let mut words = [0usize; N_WORDS];

        let mut view_mut = BitViewMut::new(&mut words);
        view_mut.fill_with_words(std::iter::repeat(0b1011u64).take(N / 64));
        assert_eq!(view_mut.true_count(), 3 * (N / 64));

        // The lowest bits of the first u64 land in the lowest bits of the first word.
        assert_eq!(words[0] & 0b1111, 0b1011);
    }
}