vortex_array/pipeline/bits/
view.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Formatter};
5
6use bitvec::prelude::*;
7use vortex_error::{VortexError, vortex_err};
8
9use crate::pipeline::{N, N_WORDS};
10
11/// A borrowed fixed-size bit vector of length `N` bits, represented as an array of usize words.
12///
13/// Internally, it uses a [`BitArray`] to store the bits, but this crate has some
14/// performance foot-guns in cases where we can lean on better assumptions, and therefore we wrap
15/// it up for use within Vortex.
16/// Read-only view into a bit array for selection masking in operator operations.
17#[derive(Clone, Copy)]
18pub struct BitView<'a> {
19    bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
20    true_count: usize,
21}
22
23impl Debug for BitView<'_> {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("BitView")
26            .field("true_count", &self.true_count)
27            .field("bits", &self.as_raw())
28            .finish()
29    }
30}
31
32impl BitView<'static> {
33    pub fn all_true() -> Self {
34        static ALL_TRUE: [usize; N_WORDS] = [usize::MAX; N_WORDS];
35        unsafe {
36            BitView::new_unchecked(
37                std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
38                    &ALL_TRUE,
39                ),
40                N,
41            )
42        }
43    }
44
45    pub fn all_false() -> Self {
46        static ALL_FALSE: [usize; N_WORDS] = [0; N_WORDS];
47        unsafe {
48            BitView::new_unchecked(
49                std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
50                    &ALL_FALSE,
51                ),
52                0,
53            )
54        }
55    }
56}
57
58impl<'a> BitView<'a> {
59    pub fn new(bits: &[usize; N_WORDS]) -> Self {
60        let true_count = bits.iter().map(|&word| word.count_ones() as usize).sum();
61        let bits: &BitArray<[usize; N_WORDS], Lsb0> = unsafe {
62            std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(bits)
63        };
64        BitView { bits, true_count }
65    }
66
67    pub(crate) unsafe fn new_unchecked(
68        bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
69        true_count: usize,
70    ) -> Self {
71        BitView { bits, true_count }
72    }
73
74    /// Returns the number of `true` bits in the view.
75    pub fn true_count(&self) -> usize {
76        self.true_count
77    }
78
79    /// Runs the provided function `f` for each index of a `true` bit in the view.
80    pub fn iter_ones<F>(&self, mut f: F)
81    where
82        F: FnMut(usize),
83    {
84        match self.true_count {
85            0 => {}
86            N => (0..N).for_each(&mut f),
87            _ => {
88                let mut bit_idx = 0;
89                for mut raw in self.bits.into_inner() {
90                    while raw != 0 {
91                        let bit_pos = raw.trailing_zeros();
92                        f(bit_idx + bit_pos as usize);
93                        raw &= raw - 1; // Clear the bit at `bit_pos`
94                    }
95                    bit_idx += usize::BITS as usize;
96                }
97            }
98        }
99    }
100
101    /// Runs the provided function `f` for each index of a `true` bit in the view.
102    pub fn iter_zeros<F>(&self, mut f: F)
103    where
104        F: FnMut(usize),
105    {
106        match self.true_count {
107            0 => (0..N).for_each(&mut f),
108            N => {}
109            _ => {
110                let mut bit_idx = 0;
111                for mut raw in self.bits.into_inner() {
112                    while raw != usize::MAX {
113                        let bit_pos = raw.trailing_ones();
114                        f(bit_idx + bit_pos as usize);
115                        raw |= 1usize << bit_pos; // Set the zero bit to 1
116                    }
117                    bit_idx += usize::BITS as usize;
118                }
119            }
120        }
121    }
122
123    /// Runs the provided function `f` for each range of `true` bits in the view.
124    ///
125    /// The function `f` receives a tuple `(start, len)` where `start` is the index of the first
126    /// `true` bit and `len` is the number of consecutive `true` bits.
127    pub fn iter_slices<F>(&self, mut f: F)
128    where
129        F: FnMut((usize, usize)),
130    {
131        match self.true_count {
132            0 => {}
133            N => f((0, N)),
134            _ => {
135                let mut bit_idx = 0;
136                for mut raw in self.bits.into_inner() {
137                    let mut offset = 0;
138                    while raw != 0 {
139                        // Skip leading zeros first
140                        let zeros = raw.leading_zeros();
141                        offset += zeros;
142                        raw <<= zeros;
143
144                        if offset >= 64 {
145                            break;
146                        }
147
148                        // Count leading ones
149                        let ones = raw.leading_ones();
150                        if ones > 0 {
151                            f((bit_idx + offset as usize, ones as usize));
152                            offset += ones;
153                            raw <<= ones;
154                        }
155                    }
156                    bit_idx += usize::BITS as usize; // Move to next word
157                }
158            }
159        }
160    }
161
162    pub fn as_raw(&self) -> &[usize; N_WORDS] {
163        // It's actually remarkably hard to get a reference to the underlying array!
164        let raw = self.bits.as_raw_slice();
165        unsafe { &*(raw.as_ptr() as *const [usize; N_WORDS]) }
166    }
167}
168
169impl<'a> From<&'a [usize; N_WORDS]> for BitView<'a> {
170    fn from(value: &'a [usize; N_WORDS]) -> Self {
171        Self::new(value)
172    }
173}
174
175impl<'a> From<&'a BitArray<[usize; N_WORDS], Lsb0>> for BitView<'a> {
176    fn from(bits: &'a BitArray<[usize; N_WORDS], Lsb0>) -> Self {
177        BitView::new(unsafe {
178            std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
179        })
180    }
181}
182
183impl<'a> TryFrom<&'a BitSlice<usize, Lsb0>> for BitView<'a> {
184    type Error = VortexError;
185
186    fn try_from(value: &'a BitSlice<usize, Lsb0>) -> Result<Self, Self::Error> {
187        let bits: &BitArray<[usize; N_WORDS], Lsb0> = value
188            .try_into()
189            .map_err(|e| vortex_err!("Failed to convert BitSlice to BitArray: {}", e))?;
190        Ok(BitView::new(unsafe {
191            std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
192        }))
193    }
194}
195
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::bits::BitVector;

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0usize; N_WORDS];
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::all_true();

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0usize; N_WORDS];
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::all_true();

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 1; // Set bit 0 (LSB)
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = usize::MAX ^ 1; // Clear bit 0 (LSB)
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b1010101; // Set bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = !0b1010101; // Clear bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_iter_ones_across_words() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 1 << 63; // Set bit 63 of first word
        bits[1] = 1; // Set bit 0 of second word (bit 64 overall)
        bits[2] = 1 << 31; // Set bit 31 of third word (bit 159 overall)
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![63, 64, 159]);
        assert_eq!(view.true_count(), 3);
    }

    #[test]
    fn test_iter_zeros_across_words() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = !(1 << 63); // Clear bit 63 of first word
        bits[1] = !1; // Clear bit 0 of second word (bit 64 overall)
        bits[2] = !(1 << 31); // Clear bit 31 of third word (bit 159 overall)
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, vec![63, 64, 159]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b11111111; // Set bits 0-7 (LSB ordering)
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_iter_ones_and_zeros_complement() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0xAAAAAAAAAAAAAAAA; // Alternating pattern
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        let mut zeros = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        view.iter_zeros(|idx| zeros.push(idx));

        // Check that ones and zeros together cover all indices
        let mut all_indices = ones.clone();
        all_indices.extend(&zeros);
        all_indices.sort_unstable();

        assert_eq!(all_indices, (0..N).collect::<Vec<_>>());

        // Check they don't overlap
        for one_idx in &ones {
            assert!(!zeros.contains(one_idx));
        }
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::all_false();

        let mut ones = Vec::new();
        let mut zeros = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(ones, Vec::<usize>::new());
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        // An all-true view
        let view = BitView::all_true();

        // Collect ones from BitView
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // An all-true mask of length N selects every index
        let expected_indices: Vec<usize> = (0..N).collect();

        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        // An all-false view
        let view = BitView::all_false();

        // Collect ones from BitView
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Collect zeros from BitView
        let mut bitview_zeros = Vec::new();
        view.iter_zeros(|idx| bitview_zeros.push(idx));

        assert_eq!(bitview_ones, Vec::<usize>::new());
        assert_eq!(bitview_zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        // Indices to set, in ascending order (iter_ones reports them in ascending order)
        let indices = vec![0, 10, 20, 63, 64, 100, 500, 1023];

        // Build the corresponding BitView
        let mut bits = [0usize; N_WORDS];
        for idx in &indices {
            let word_idx = idx / 64;
            let bit_idx = idx % 64;
            bits[word_idx] |= 1usize << bit_idx;
        }
        let view = BitView::new(&bits);

        // Collect ones from BitView
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        assert_eq!(bitview_ones, indices);
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        // Half-open `[start, end)` ranges of set bits
        let slices = vec![(0, 10), (100, 110), (500, 510)];

        // Build the corresponding BitView
        let mut bits = [0usize; N_WORDS];
        for (start, end) in &slices {
            for idx in *start..*end {
                let word_idx = idx / 64;
                let bit_idx = idx % 64;
                bits[word_idx] |= 1usize << bit_idx;
            }
        }
        let view = BitView::new(&bits);

        // Collect ones from BitView
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Expected indices from slices
        let mut expected_indices = Vec::new();
        for (start, end) in &slices {
            expected_indices.extend(*start..*end);
        }

        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(view.true_count(), expected_indices.len());
    }

    #[test]
    fn test_mask_and_bitview_iter_match() {
        // Patterns in the first two words
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0xAAAAAAAAAAAAAAAA; // Alternating 1s and 0s
        bits[1] = 0xFF00FF00FF00FF00; // Alternating bytes

        let view = BitView::new(&bits);

        // Collect indices from BitView
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Create Mask from the same indices
        let mask = Mask::from_indices(N, bitview_ones.clone());

        // Verify the mask agrees with the view at every position
        mask.iter_bools(|iter| {
            let mask_bools: Vec<bool> = iter.collect();

            // Check each bit matches
            for i in 0..N {
                let expected = bitview_ones.contains(&i);
                assert_eq!(mask_bools[i], expected, "Mismatch at index {}", i);
            }
        });
    }

    #[test]
    fn test_mask_and_bitview_all_true() {
        let mask = Mask::AllTrue(5);

        let vector = BitVector::true_until(5);

        let view = vector.as_view();

        // Collect indices from BitView
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Collect indices from Mask
        let mask_ones = mask.iter_bools(|iter| {
            iter.enumerate()
                .filter(|(_, b)| *b)
                .map(|(i, _)| i)
                .collect::<Vec<_>>()
        });

        assert_eq!(bitview_ones, mask_ones);
    }

    #[test]
    fn test_bitview_zeros_complement_mask() {
        // Create a pattern
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b11110000111100001111000011110000;

        let view = BitView::new(&bits);

        // Collect ones and zeros from BitView
        let mut bitview_ones = Vec::new();
        let mut bitview_zeros = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));
        view.iter_zeros(|idx| bitview_zeros.push(idx));

        // Create masks for ones and zeros
        let ones_mask = Mask::from_indices(N, bitview_ones);
        let zeros_mask = Mask::from_indices(N, bitview_zeros);

        // Verify they are complements
        ones_mask.iter_bools(|ones_iter| {
            zeros_mask.iter_bools(|zeros_iter| {
                let ones_bools: Vec<bool> = ones_iter.collect();
                let zeros_bools: Vec<bool> = zeros_iter.collect();

                for i in 0..N {
                    // Each index should be either in ones or zeros, but not both
                    assert_ne!(
                        ones_bools[i], zeros_bools[i],
                        "Index {} should be in exactly one set",
                        i
                    );
                }
            });
        });
    }

    #[test]
    fn test_iter_slices_all_true() {
        let view = BitView::all_true();

        let mut slices = Vec::new();
        view.iter_slices(|slice| slices.push(slice));

        // A fully set view is a single run covering every index
        assert_eq!(slices, vec![(0, N)]);
    }

    #[test]
    fn test_iter_slices_all_false() {
        let view = BitView::all_false();

        let mut slices = Vec::new();
        view.iter_slices(|slice| slices.push(slice));

        assert!(slices.is_empty());
    }

    #[test]
    fn test_from_words_roundtrip() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b10; // Bit 1
        bits[1] = 0b10; // Bit 65 overall
        bits[2] = 0b100; // Bit 130 overall

        let view = BitView::from(&bits);

        // The view exposes exactly the words it was built from
        assert_eq!(view.as_raw(), &bits);
        assert_eq!(view.true_count(), 3);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        assert_eq!(ones, vec![1, 65, 130]);
    }
}