vortex_array/pipeline/bits/view.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::fmt::{Debug, Formatter};
5
6use bitvec::prelude::*;
7use vortex_error::{VortexError, VortexResult, vortex_err};
8
9use crate::pipeline::{N, N_WORDS};
10
11/// A borrowed fixed-size bit vector of length `N` bits, represented as an array of usize words.
12///
13/// Internally, it uses a [`BitArray`] to store the bits, but this crate has some
14/// performance foot-guns in cases where we can lean on better assumptions, and therefore we wrap
15/// it up for use within Vortex.
16/// Read-only view into a bit array for selection masking in operator operations.
17#[derive(Clone, Copy)]
18pub struct BitView<'a> {
19    bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
20    // TODO(ngates): we may want to expose this for optimizations.
21    // If set to Selection::Prefix, then all true bits are at the start of the array.
22    // selection: Selection,
23    true_count: usize,
24}
25
26impl Debug for BitView<'_> {
27    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("BitView")
29            .field("true_count", &self.true_count)
30            .field("bits", &self.as_raw())
31            .finish()
32    }
33}
34
35impl BitView<'static> {
36    pub fn all_true() -> Self {
37        static ALL_TRUE: [usize; N_WORDS] = [usize::MAX; N_WORDS];
38        unsafe {
39            BitView::new_unchecked(
40                std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
41                    &ALL_TRUE,
42                ),
43                N,
44            )
45        }
46    }
47
48    pub fn all_false() -> Self {
49        static ALL_FALSE: [usize; N_WORDS] = [0; N_WORDS];
50        unsafe {
51            BitView::new_unchecked(
52                std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
53                    &ALL_FALSE,
54                ),
55                0,
56            )
57        }
58    }
59}
60
61impl<'a> BitView<'a> {
62    pub fn new(bits: &[usize; N_WORDS]) -> Self {
63        let true_count = bits.iter().map(|&word| word.count_ones() as usize).sum();
64        let bits: &BitArray<[usize; N_WORDS], Lsb0> = unsafe {
65            std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(bits)
66        };
67        BitView { bits, true_count }
68    }
69
70    pub(crate) unsafe fn new_unchecked(
71        bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
72        true_count: usize,
73    ) -> Self {
74        BitView { bits, true_count }
75    }
76
77    /// Returns the number of `true` bits in the view.
78    pub fn true_count(&self) -> usize {
79        self.true_count
80    }
81
82    /// Runs the provided function `f` for each index of a `true` bit in the view.
83    pub fn iter_ones<F>(&self, mut f: F)
84    where
85        F: FnMut(usize),
86    {
87        match self.true_count {
88            0 => {}
89            N => (0..N).for_each(&mut f),
90            _ => {
91                let mut bit_idx = 0;
92                for mut raw in self.bits.into_inner() {
93                    while raw != 0 {
94                        let bit_pos = raw.trailing_zeros();
95                        f(bit_idx + bit_pos as usize);
96                        raw &= raw - 1; // Clear the bit at `bit_pos`
97                    }
98                    bit_idx += usize::BITS as usize;
99                }
100            }
101        }
102    }
103
104    /// Runs the provided function `f` for each index of a `true` bit in the view.
105    pub fn try_iter_ones<F>(&self, mut f: F) -> VortexResult<()>
106    where
107        F: FnMut(usize) -> VortexResult<()>,
108    {
109        match self.true_count {
110            0 => Ok(()),
111            N => {
112                for i in 0..N {
113                    f(i)?;
114                }
115                Ok(())
116            }
117            _ => {
118                let mut bit_idx = 0;
119                for mut raw in self.bits.into_inner() {
120                    while raw != 0 {
121                        let bit_pos = raw.trailing_zeros();
122                        f(bit_idx + bit_pos as usize)?;
123                        raw &= raw - 1; // Clear the bit at `bit_pos`
124                    }
125                    bit_idx += usize::BITS as usize;
126                }
127                Ok(())
128            }
129        }
130    }
131
132    /// Runs the provided function `f` for each index of a `true` bit in the view.
133    pub fn iter_zeros<F>(&self, mut f: F)
134    where
135        F: FnMut(usize),
136    {
137        match self.true_count {
138            0 => (0..N).for_each(&mut f),
139            N => {}
140            _ => {
141                let mut bit_idx = 0;
142                for mut raw in self.bits.into_inner() {
143                    while raw != usize::MAX {
144                        let bit_pos = raw.trailing_ones();
145                        f(bit_idx + bit_pos as usize);
146                        raw |= 1usize << bit_pos; // Set the zero bit to 1
147                    }
148                    bit_idx += usize::BITS as usize;
149                }
150            }
151        }
152    }
153
154    /// Runs the provided function `f` for each range of `true` bits in the view.
155    ///
156    /// The function `f` receives a tuple `(start, len)` where `start` is the index of the first
157    /// `true` bit and `len` is the number of consecutive `true` bits.
158    pub fn iter_slices<F>(&self, mut f: F)
159    where
160        F: FnMut((usize, usize)),
161    {
162        match self.true_count {
163            0 => {}
164            N => f((0, N)),
165            _ => {
166                let mut bit_idx = 0;
167                for mut raw in self.bits.into_inner() {
168                    let mut offset = 0;
169                    while raw != 0 {
170                        // Skip leading zeros first
171                        let zeros = raw.leading_zeros();
172                        offset += zeros;
173                        raw <<= zeros;
174
175                        if offset >= 64 {
176                            break;
177                        }
178
179                        // Count leading ones
180                        let ones = raw.leading_ones();
181                        if ones > 0 {
182                            f((bit_idx + offset as usize, ones as usize));
183                            offset += ones;
184                            raw <<= ones;
185                        }
186                    }
187                    bit_idx += usize::BITS as usize; // Move to next word
188                }
189            }
190        }
191    }
192
193    pub fn as_raw(&self) -> &[usize; N_WORDS] {
194        // It's actually remarkably hard to get a reference to the underlying array!
195        let raw = self.bits.as_raw_slice();
196        unsafe { &*(raw.as_ptr() as *const [usize; N_WORDS]) }
197    }
198}
199
200impl<'a> From<&'a [usize; N_WORDS]> for BitView<'a> {
201    fn from(value: &'a [usize; N_WORDS]) -> Self {
202        Self::new(value)
203    }
204}
205
206impl<'a> From<&'a BitArray<[usize; N_WORDS], Lsb0>> for BitView<'a> {
207    fn from(bits: &'a BitArray<[usize; N_WORDS], Lsb0>) -> Self {
208        BitView::new(unsafe {
209            std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
210        })
211    }
212}
213
214impl<'a> TryFrom<&'a BitSlice<usize, Lsb0>> for BitView<'a> {
215    type Error = VortexError;
216
217    fn try_from(value: &'a BitSlice<usize, Lsb0>) -> Result<Self, Self::Error> {
218        let bits: &BitArray<[usize; N_WORDS], Lsb0> = value
219            .try_into()
220            .map_err(|e| vortex_err!("Failed to convert BitSlice to BitArray: {}", e))?;
221        Ok(BitView::new(unsafe {
222            std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
223        }))
224    }
225}
226
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::bits::BitVector;

    /// Collects, in ascending order, the indices at which `mask` is `true`.
    fn mask_true_indices(mask: &Mask) -> Vec<usize> {
        mask.iter_bools(|iter| {
            iter.enumerate()
                .filter(|(_, b)| *b)
                .map(|(i, _)| i)
                .collect::<Vec<_>>()
        })
    }

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0usize; N_WORDS];
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::all_true();

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0usize; N_WORDS];
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::all_true();

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 1; // Set bit 0 (LSB)
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = usize::MAX ^ 1; // Clear bit 0 (LSB)
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b1010101; // Set bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = !0b1010101; // Clear bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_iter_ones_across_words() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 1 << 63; // Set bit 63 of first word
        bits[1] = 1; // Set bit 0 of second word (bit 64 overall)
        bits[2] = 1 << 31; // Set bit 31 of third word (bit 159 overall)
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![63, 64, 159]);
        assert_eq!(view.true_count(), 3);
    }

    #[test]
    fn test_iter_zeros_across_words() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = !(1 << 63); // Clear bit 63 of first word
        bits[1] = !1; // Clear bit 0 of second word (bit 64 overall)
        bits[2] = !(1 << 31); // Clear bit 31 of third word (bit 159 overall)
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, vec![63, 64, 159]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b11111111; // Set bits 0-7 (LSB ordering)
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_iter_ones_and_zeros_complement() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0xAAAAAAAAAAAAAAAA; // Alternating pattern in the first word only
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        let mut zeros = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        view.iter_zeros(|idx| zeros.push(idx));

        // Together, ones and zeros must cover every index exactly once.
        let mut all_indices = ones.clone();
        all_indices.extend(&zeros);
        all_indices.sort_unstable();

        assert_eq!(all_indices, (0..N).collect::<Vec<_>>());

        // And no index may appear in both.
        for one_idx in &ones {
            assert!(!zeros.contains(one_idx));
        }
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::all_false();

        let mut ones = Vec::new();
        let mut zeros = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(ones, Vec::<usize>::new());
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        let view = BitView::all_true();
        let mask = Mask::AllTrue(N);

        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // The BitView must agree with an all-true Mask of the same length.
        assert_eq!(bitview_ones, mask_true_indices(&mask));
        assert_eq!(bitview_ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        let view = BitView::all_false();
        let mask = Mask::AllFalse(N);

        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        let mut bitview_zeros = Vec::new();
        view.iter_zeros(|idx| bitview_zeros.push(idx));

        // The BitView must agree with an all-false Mask of the same length.
        assert_eq!(bitview_ones, mask_true_indices(&mask));
        assert_eq!(bitview_ones, Vec::<usize>::new());
        assert_eq!(bitview_zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        // Indices of the bits to set, spanning several words.
        let indices = vec![0, 10, 20, 63, 64, 100, 500, 1023];

        // Build the BitView by setting each index directly in the raw words.
        let mut bits = [0usize; N_WORDS];
        for idx in &indices {
            let word_idx = idx / 64;
            let bit_idx = idx % 64;
            bits[word_idx] |= 1usize << bit_idx;
        }
        let view = BitView::new(&bits);

        // Build the equivalent Mask from the same indices.
        let mask = Mask::from_indices(N, indices.clone());

        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        assert_eq!(bitview_ones, indices);
        assert_eq!(bitview_ones, mask_true_indices(&mask));
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        // Ranges of set bits, as half-open `(start, end)` pairs.
        let slices = vec![(0, 10), (100, 110), (500, 510)];

        // Build the BitView by setting every index in each range.
        let mut bits = [0usize; N_WORDS];
        for (start, end) in &slices {
            for idx in *start..*end {
                let word_idx = idx / 64;
                let bit_idx = idx % 64;
                bits[word_idx] |= 1usize << bit_idx;
            }
        }
        let view = BitView::new(&bits);

        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Expected indices, expanded from the ranges.
        let mut expected_indices = Vec::new();
        for (start, end) in &slices {
            expected_indices.extend(*start..*end);
        }

        // Build the equivalent Mask from the expanded indices.
        let mask = Mask::from_indices(N, expected_indices.clone());

        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(bitview_ones, mask_true_indices(&mask));
        assert_eq!(view.true_count(), expected_indices.len());
    }

    #[test]
    fn test_mask_and_bitview_iter_match() {
        // Alternating bits in the first word and alternating bytes in the second.
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0xAAAAAAAAAAAAAAAA; // Alternating 1s and 0s
        bits[1] = 0xFF00FF00FF00FF00; // Alternating bytes

        let view = BitView::new(&bits);

        // Collect indices from the BitView.
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Create a Mask from the same indices.
        let mask = Mask::from_indices(N, bitview_ones.clone());

        // The Mask's boolean iterator must agree with the BitView at every position.
        mask.iter_bools(|iter| {
            let mask_bools: Vec<bool> = iter.collect();

            for i in 0..N {
                let expected = bitview_ones.contains(&i);
                assert_eq!(mask_bools[i], expected, "Mismatch at index {}", i);
            }
        });
    }

    #[test]
    fn test_mask_and_bitview_all_true() {
        let mask = Mask::AllTrue(5);

        let vector = BitVector::true_until(5);

        let view = vector.as_view();

        // Collect indices from the BitView.
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Collect indices from the Mask.
        let mask_ones = mask_true_indices(&mask);

        assert_eq!(bitview_ones, mask_ones);
    }

    #[test]
    fn test_bitview_zeros_complement_mask() {
        // Alternating nibbles in the low 32 bits of the first word.
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b11110000111100001111000011110000;

        let view = BitView::new(&bits);

        // Collect ones and zeros from the BitView.
        let mut bitview_ones = Vec::new();
        let mut bitview_zeros = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));
        view.iter_zeros(|idx| bitview_zeros.push(idx));

        // Build one Mask from the ones and another from the zeros.
        let ones_mask = Mask::from_indices(N, bitview_ones);
        let zeros_mask = Mask::from_indices(N, bitview_zeros);

        // The two masks must be exact complements of each other.
        ones_mask.iter_bools(|ones_iter| {
            zeros_mask.iter_bools(|zeros_iter| {
                let ones_bools: Vec<bool> = ones_iter.collect();
                let zeros_bools: Vec<bool> = zeros_iter.collect();

                for i in 0..N {
                    // Each index should be either in ones or zeros, but not both
                    assert_ne!(
                        ones_bools[i], zeros_bools[i],
                        "Index {} should be in exactly one set",
                        i
                    );
                }
            });
        });
    }
}