vortex_array/pipeline/bits/
vector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Formatter};
5use std::ops::Not;
6use std::sync::{Arc, LazyLock};
7
8use bitvec::array::BitArray;
9use bitvec::order::Lsb0;
10
11use super::{BitView, BitViewMut};
12use crate::pipeline::{N, N_WORDS};
13
14static EMPTY: LazyLock<BitVector> = LazyLock::new(|| BitVector {
15    bits: Arc::new(BitArray::ZERO),
16    true_count: 0,
17});
18
19static FULL: LazyLock<BitVector> = LazyLock::new(|| BitVector {
20    bits: Arc::new(BitArray::ZERO.not()),
21    true_count: N,
22});
23
24/// An owned fixed-size bit vector of length `N` bits, represented as an array of usize words.
25///
26/// Internally, it uses a [`BitArray`] to store the bits, but this crate has some
27/// performance foot-guns in cases where we can lean on better assumptions, and therefore we wrap
28/// it up for use within Vortex.
29/// Owned bit vector for storing boolean selection masks.
30#[derive(Clone)]
31pub struct BitVector {
32    pub(super) bits: Arc<BitArray<[usize; N_WORDS], Lsb0>>,
33    pub(super) true_count: usize,
34}
35
36impl Debug for BitVector {
37    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("BitVector")
39            .field("true_count", &self.true_count)
40            //.field("bits", &self.bits.as_raw_slice())
41            .finish()
42    }
43}
44
45impl PartialEq for BitVector {
46    fn eq(&self, other: &Self) -> bool {
47        Arc::ptr_eq(&self.bits, &other.bits)
48            || (self.true_count == other.true_count && self.bits == other.bits)
49    }
50}
51
52impl Eq for BitVector {}
53
54impl BitVector {
55    pub fn empty() -> &'static BitVector {
56        &EMPTY
57    }
58
59    pub fn full() -> &'static BitVector {
60        &FULL
61    }
62
63    pub fn true_until(n: usize) -> Self {
64        assert!(n <= N, "Cannot create a BitVector with more than N bits");
65
66        let mut bits = Arc::new(BitArray::<[usize; N_WORDS], Lsb0>::ZERO);
67        let bits_mut = Arc::make_mut(&mut bits);
68
69        let mut word = 0;
70        let mut remaining = n;
71        while remaining >= usize::BITS as usize {
72            bits_mut.as_raw_mut_slice()[word] = usize::MAX;
73            remaining -= usize::BITS as usize;
74            word += 1;
75        }
76
77        if remaining > 0 {
78            // For LSB ordering, set the lower bits (0 to remaining-1)
79            bits_mut.as_raw_mut_slice()[word] = (1usize << remaining) - 1;
80        }
81
82        BitVector {
83            bits,
84            true_count: n,
85        }
86    }
87
88    pub fn true_count(&self) -> usize {
89        self.true_count
90    }
91
92    pub fn as_raw(&self) -> &[usize; N_WORDS] {
93        // It's actually remarkably hard to get a reference to the underlying array!
94        let raw = self.bits.as_raw_slice();
95        unsafe { &*(raw.as_ptr() as *const [usize; N_WORDS]) }
96    }
97
98    pub fn as_raw_mut(&mut self) -> &mut [usize; N_WORDS] {
99        // SAFETY: We assume that the bits are mutable and that the view is valid.
100        let raw = Arc::make_mut(&mut self.bits).as_raw_mut_slice();
101        unsafe { &mut *(raw.as_mut_ptr() as *mut [usize; N_WORDS]) }
102    }
103
104    pub fn fill_from<I>(&mut self, iter: I)
105    where
106        I: IntoIterator<Item = usize>,
107    {
108        let mut true_count = 0;
109        for (dst, word) in self.as_raw_mut().iter_mut().zip(iter) {
110            true_count += word.count_ones() as usize;
111            *dst = word;
112        }
113        self.true_count = true_count;
114    }
115
116    pub fn as_view(&self) -> BitView<'_> {
117        unsafe { BitView::new_unchecked(&self.bits, self.true_count) }
118    }
119
120    pub fn as_view_mut(&mut self) -> BitViewMut<'_> {
121        unsafe { BitViewMut::new_unchecked(Arc::make_mut(&mut self.bits), self.true_count) }
122    }
123}
124
125impl From<BitView<'_>> for BitVector {
126    fn from(value: BitView<'_>) -> Self {
127        let true_count = value.true_count();
128        let bits = Arc::new(BitArray::<[usize; N_WORDS], Lsb0>::from(*value.as_raw()));
129        BitVector { bits, true_count }
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the index of every set bit in `view`, in the order `iter_ones` reports them.
    fn set_indices(view: BitView<'_>) -> Vec<usize> {
        let mut indices = Vec::new();
        view.iter_ones(|idx| indices.push(idx));
        indices
    }

    #[test]
    fn test_fill_from() {
        let words = [
            0b1010101010101010usize,
            0b1111000011110000usize,
            usize::MAX,
            0,
        ];
        let mut vec = BitVector::empty().clone();
        vec.fill_from(words.iter().copied());

        // Each supplied word must land in the matching storage slot.
        let raw = vec.as_raw();
        for (slot, expected) in words.iter().enumerate() {
            assert_eq!(raw[slot], *expected);
        }

        // The population count is the total over the supplied words.
        let expected_count: usize = words.iter().map(|w| w.count_ones() as usize).sum();
        assert_eq!(vec.true_count(), expected_count);
    }

    #[test]
    fn test_as_view() {
        let vec = BitVector::true_until(100);
        let view = vec.as_view();

        assert_eq!(view.true_count(), 100);
        // The view must expose exactly the bits owned by the vector.
        assert_eq!(set_indices(view), (0..100).collect::<Vec<usize>>());
    }

    #[test]
    fn test_as_view_mut() {
        let mut vec = BitVector::true_until(64);
        // Only checks that a mutable view can be created and reports the same count.
        let seen = vec.as_view_mut().true_count();
        assert_eq!(seen, 64);
        assert_eq!(vec.true_count(), 64);
    }

    #[test]
    fn test_from_bitview() {
        let expected = [0b11111111usize, 0b11110000];
        let mut raw = [0usize; N_WORDS];
        raw[..expected.len()].copy_from_slice(&expected);

        let view = BitView::new(&raw);
        let vec = BitVector::from(view);

        assert_eq!(vec.true_count(), view.true_count());
        assert_eq!(vec.as_raw()[..expected.len()], expected);
    }

    #[test]
    fn test_lsb_ordering_verification() {
        // With Lsb0 ordering, `true_until(5)` must set exactly bits 0 through 4.
        let ones = set_indices(BitVector::true_until(5).as_view());
        assert_eq!(ones, (0..5).collect::<Vec<usize>>());
    }

    #[test]
    fn test_as_raw_mut() {
        let mut vec = BitVector::empty().clone();
        {
            let words = vec.as_raw_mut();
            words[0] = 0b1111;
            words[2] = usize::MAX;
        }

        // `as_raw_mut` is a low-level escape hatch that leaves `true_count` untouched, so the
        // caller restores it by hand.
        vec.true_count = 4 + 64;

        let raw = vec.as_raw();
        assert_eq!(raw[0], 0b1111);
        assert_eq!(raw[2], usize::MAX);
        assert_eq!(vec.true_count(), 68);
    }

    #[test]
    fn test_boundary_conditions() {
        for n in [1, 31, 32, 33, 63, 64, 65, 127, 128, 129, N - 1, N] {
            let vec = BitVector::true_until(n);
            assert_eq!(vec.true_count(), n);

            // The set bits must be exactly the contiguous prefix 0..n.
            let ones = set_indices(vec.as_view());
            assert_eq!(ones.len(), n);
            if let (Some(&first), Some(&last)) = (ones.first(), ones.last()) {
                assert_eq!(first, 0);
                assert_eq!(last, n - 1);
            }
        }
    }
}