vortex_array/pipeline/bits/
view.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Formatter};
5
6use bitvec::prelude::*;
7use vortex_error::{VortexError, vortex_err};
8
9use crate::pipeline::{N, N_WORDS};
10
11/// A borrowed fixed-size bit vector of length `N` bits, represented as an array of usize words.
12///
13/// Internally, it uses a [`BitArray`] to store the bits, but this crate has some
14/// performance foot-guns in cases where we can lean on better assumptions, and therefore we wrap
15/// it up for use within Vortex.
16/// Read-only view into a bit array for selection masking in pipeline operations.
17#[derive(Clone, Copy)]
18pub struct BitView<'a> {
19    bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
20    true_count: usize,
21}
22
23impl Debug for BitView<'_> {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("BitView")
26            .field("true_count", &self.true_count)
27            .field("bits", &self.as_raw())
28            .finish()
29    }
30}
31
32impl BitView<'static> {
33    pub fn all_true() -> Self {
34        static ALL_TRUE: [usize; N_WORDS] = [usize::MAX; N_WORDS];
35        unsafe {
36            BitView::new_unchecked(
37                std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
38                    &ALL_TRUE,
39                ),
40                N,
41            )
42        }
43    }
44
45    pub fn all_false() -> Self {
46        static ALL_FALSE: [usize; N_WORDS] = [0; N_WORDS];
47        unsafe {
48            BitView::new_unchecked(
49                std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
50                    &ALL_FALSE,
51                ),
52                0,
53            )
54        }
55    }
56}
57
58impl<'a> BitView<'a> {
59    pub fn new(bits: &[usize; N_WORDS]) -> Self {
60        let true_count = bits.iter().map(|&word| word.count_ones() as usize).sum();
61        let bits: &BitArray<[usize; N_WORDS], Lsb0> = unsafe {
62            std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(bits)
63        };
64        BitView { bits, true_count }
65    }
66
67    pub(crate) unsafe fn new_unchecked(
68        bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
69        true_count: usize,
70    ) -> Self {
71        BitView { bits, true_count }
72    }
73
74    /// Returns the number of `true` bits in the view.
75    pub fn true_count(&self) -> usize {
76        self.true_count
77    }
78
79    /// Runs the provided function `f` for each index of a `true` bit in the view.
80    pub fn iter_ones<F>(&self, mut f: F)
81    where
82        F: FnMut(usize),
83    {
84        match self.true_count {
85            0 => {}
86            N => (0..N).for_each(&mut f),
87            _ => {
88                let mut bit_idx = 0;
89                for mut raw in self.bits.into_inner() {
90                    while raw != 0 {
91                        let bit_pos = raw.trailing_zeros();
92                        f(bit_idx + bit_pos as usize);
93                        raw &= raw - 1; // Clear the bit at `bit_pos`
94                    }
95                    bit_idx += usize::BITS as usize;
96                }
97            }
98        }
99    }
100
101    /// Runs the provided function `f` for each index of a `true` bit in the view.
102    pub fn iter_zeros<F>(&self, mut f: F)
103    where
104        F: FnMut(usize),
105    {
106        match self.true_count {
107            0 => (0..N).for_each(&mut f),
108            N => {}
109            _ => {
110                let mut bit_idx = 0;
111                for mut raw in self.bits.into_inner() {
112                    while raw != usize::MAX {
113                        let bit_pos = raw.trailing_ones();
114                        f(bit_idx + bit_pos as usize);
115                        raw |= 1usize << bit_pos; // Set the zero bit to 1
116                    }
117                    bit_idx += usize::BITS as usize;
118                }
119            }
120        }
121    }
122
123    /// Runs the provided function `f` for each range of `true` bits in the view.
124    ///
125    /// The function `f` receives a tuple `(start, len)` where `start` is the index of the first
126    /// `true` bit and `len` is the number of consecutive `true` bits.
127    pub fn iter_slices<F>(&self, mut f: F)
128    where
129        F: FnMut((usize, usize)),
130    {
131        match self.true_count {
132            0 => {}
133            N => f((0, N)),
134            _ => {
135                let mut bit_idx = 0;
136                for mut raw in self.bits.into_inner() {
137                    let mut offset = 0;
138                    while raw != 0 {
139                        // Skip leading zeros first
140                        let zeros = raw.leading_zeros();
141                        offset += zeros;
142                        raw <<= zeros;
143
144                        if offset >= 64 {
145                            break;
146                        }
147
148                        // Count leading ones
149                        let ones = raw.leading_ones();
150                        if ones > 0 {
151                            f((bit_idx + offset as usize, ones as usize));
152                            offset += ones;
153                            raw <<= ones;
154                        }
155                    }
156                    bit_idx += usize::BITS as usize; // Move to next word
157                }
158            }
159        }
160    }
161
162    pub fn as_raw(&self) -> &[usize; N_WORDS] {
163        // It's actually remarkably hard to get a reference to the underlying array!
164        let raw = self.bits.as_raw_slice();
165        unsafe { &*(raw.as_ptr() as *const [usize; N_WORDS]) }
166    }
167}
168
169impl<'a> From<&'a [usize; N_WORDS]> for BitView<'a> {
170    fn from(value: &'a [usize; N_WORDS]) -> Self {
171        Self::new(value)
172    }
173}
174
175impl<'a> From<&'a BitArray<[usize; N_WORDS], Lsb0>> for BitView<'a> {
176    fn from(bits: &'a BitArray<[usize; N_WORDS], Lsb0>) -> Self {
177        BitView::new(unsafe {
178            std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
179        })
180    }
181}
182
183impl<'a> TryFrom<&'a BitSlice<usize, Lsb0>> for BitView<'a> {
184    type Error = VortexError;
185
186    fn try_from(value: &'a BitSlice<usize, Lsb0>) -> Result<Self, Self::Error> {
187        let bits: &BitArray<[usize; N_WORDS], Lsb0> = value
188            .try_into()
189            .map_err(|e| vortex_err!("Failed to convert BitSlice to BitArray: {}", e))?;
190        Ok(BitView::new(unsafe {
191            std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
192        }))
193    }
194}
195
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::bits::BitVector;

    /// Number of bits in one storage word.
    const WORD_BITS: usize = usize::BITS as usize;

    /// Collects every index reported by [`BitView::iter_ones`].
    fn collect_ones(view: BitView<'_>) -> Vec<usize> {
        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        ones
    }

    /// Collects every index reported by [`BitView::iter_zeros`].
    fn collect_zeros(view: BitView<'_>) -> Vec<usize> {
        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));
        zeros
    }

    /// Builds storage words in which exactly the given bit indices are set.
    fn words_with_bits(indices: impl IntoIterator<Item = usize>) -> [usize; N_WORDS] {
        let mut words = [0usize; N_WORDS];
        for idx in indices {
            words[idx / WORD_BITS] |= 1usize << (idx % WORD_BITS);
        }
        words
    }

    /// Returns the mask as one boolean per position.
    fn mask_bools(mask: &Mask) -> Vec<bool> {
        mask.iter_bools(|iter| iter.collect::<Vec<bool>>())
    }

    /// Returns the positions at which the mask is `true`.
    fn mask_true_indices(mask: &Mask) -> Vec<usize> {
        mask.iter_bools(|iter| {
            iter.enumerate()
                .filter(|(_, b)| *b)
                .map(|(i, _)| i)
                .collect::<Vec<_>>()
        })
    }

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0usize; N_WORDS];
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(view), Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::all_true();

        let ones = collect_ones(view);

        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0usize; N_WORDS];
        let view = BitView::new(&bits);

        let zeros = collect_zeros(view);

        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::all_true();

        assert_eq!(collect_zeros(view), Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 1; // Set bit 0 (LSB)
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(view), vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = usize::MAX ^ 1; // Clear bit 0 (LSB)
        let view = BitView::new(&bits);

        assert_eq!(collect_zeros(view), vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b1010101; // Set bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(view), vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = !0b1010101; // Clear bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        assert_eq!(collect_zeros(view), vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_iter_ones_across_words() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 1 << 63; // Set bit 63 of first word
        bits[1] = 1; // Set bit 0 of second word (bit 64 overall)
        bits[2] = 1 << 31; // Set bit 31 of third word (bit 159 overall)
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(view), vec![63, 64, 159]);
        assert_eq!(view.true_count(), 3);
    }

    #[test]
    fn test_iter_zeros_across_words() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = !(1 << 63); // Clear bit 63 of first word
        bits[1] = !1; // Clear bit 0 of second word (bit 64 overall)
        bits[2] = !(1 << 31); // Clear bit 31 of third word (bit 159 overall)
        let view = BitView::new(&bits);

        assert_eq!(collect_zeros(view), vec![63, 64, 159]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b11111111; // Set bits 0-7 (LSB ordering)
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(view), vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_iter_ones_and_zeros_complement() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0xAAAAAAAAAAAAAAAA; // Alternating pattern
        let view = BitView::new(&bits);

        let ones = collect_ones(view);
        let zeros = collect_zeros(view);

        // Check that ones and zeros together cover all indices
        let mut all_indices = ones.clone();
        all_indices.extend(&zeros);
        all_indices.sort_unstable();

        assert_eq!(all_indices, (0..N).collect::<Vec<_>>());

        // Check they don't overlap
        for one_idx in &ones {
            assert!(!zeros.contains(one_idx));
        }
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::all_false();

        assert_eq!(collect_ones(view), Vec::<usize>::new());
        assert_eq!(collect_zeros(view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_slices_uniform() {
        // Only the all-true / all-false fast paths are exercised here.
        let mut slices = Vec::new();
        BitView::all_true().iter_slices(|slice| slices.push(slice));
        assert_eq!(slices, vec![(0, N)]);

        slices.clear();
        BitView::all_false().iter_slices(|slice| slices.push(slice));
        assert!(slices.is_empty());
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        // Create a Mask with all bits set
        let mask = Mask::new_true(N);

        // Create corresponding BitView
        let view = BitView::all_true();

        let bitview_ones = collect_ones(view);

        assert_eq!(bitview_ones, (0..N).collect::<Vec<_>>());
        assert_eq!(bitview_ones, mask_true_indices(&mask));
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        // Create a Mask with no bits set
        let mask = Mask::new_false(N);

        // Create corresponding BitView
        let view = BitView::all_false();

        let bitview_ones = collect_ones(view);
        let bitview_zeros = collect_zeros(view);

        assert_eq!(bitview_ones, Vec::<usize>::new());
        assert_eq!(bitview_ones, mask_true_indices(&mask));
        assert_eq!(bitview_zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        // Create a Mask from specific indices
        let indices: Vec<usize> = vec![0, 10, 20, 63, 64, 100, 500, 1023];
        let mask = Mask::from_indices(N, indices.clone());

        // Create corresponding BitView
        let bits = words_with_bits(indices.iter().copied());
        let view = BitView::new(&bits);

        let bitview_ones = collect_ones(view);

        assert_eq!(bitview_ones, indices);
        assert_eq!(bitview_ones, mask_true_indices(&mask));
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        // Create a Mask from slices (ranges)
        let slices: Vec<(usize, usize)> = vec![(0, 10), (100, 110), (500, 510)];
        let mask = Mask::from_slices(N, slices.clone());

        // Expected indices from slices, treating each pair as `start..end`
        let expected_indices: Vec<usize> = slices
            .iter()
            .flat_map(|&(start, end)| start..end)
            .collect();

        // Create corresponding BitView
        let bits = words_with_bits(expected_indices.iter().copied());
        let view = BitView::new(&bits);

        let bitview_ones = collect_ones(view);

        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(bitview_ones, mask_true_indices(&mask));
        assert_eq!(view.true_count(), expected_indices.len());
    }

    #[test]
    fn test_mask_and_bitview_iter_match() {
        // Create a pattern with alternating bits in first word
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0xAAAAAAAAAAAAAAAA; // Alternating 1s and 0s
        bits[1] = 0xFF00FF00FF00FF00; // Alternating bytes

        let view = BitView::new(&bits);
        let bitview_ones = collect_ones(view);

        // Create Mask from the same indices
        let mask = Mask::from_indices(N, bitview_ones.clone());

        // Verify the mask agrees with the view at every position
        let mask_bools = mask_bools(&mask);
        let mut expected = vec![false; N];
        for &idx in &bitview_ones {
            expected[idx] = true;
        }
        for (i, (actual, wanted)) in mask_bools.iter().zip(&expected).enumerate() {
            assert_eq!(actual, wanted, "Mismatch at index {}", i);
        }
        assert_eq!(mask_bools.len(), N);
    }

    #[test]
    fn test_mask_and_bitview_all_true() {
        let mask = Mask::AllTrue(5);

        let vector = BitVector::true_until(5);

        let view = vector.as_view();

        // Collect indices from BitView
        let bitview_ones = collect_ones(view);

        // Collect indices from Mask
        let mask_ones = mask_true_indices(&mask);

        assert_eq!(bitview_ones, mask_ones);
    }

    #[test]
    fn test_bitview_zeros_complement_mask() {
        // Create a pattern
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b11110000111100001111000011110000;

        let view = BitView::new(&bits);

        // Collect ones and zeros from BitView
        let bitview_ones = collect_ones(view);
        let bitview_zeros = collect_zeros(view);

        // Create masks for ones and zeros
        let ones_mask = Mask::from_indices(N, bitview_ones);
        let zeros_mask = Mask::from_indices(N, bitview_zeros);

        // Verify they are complements
        let ones_bools = mask_bools(&ones_mask);
        let zeros_bools = mask_bools(&zeros_mask);
        assert_eq!(ones_bools.len(), N);
        assert_eq!(zeros_bools.len(), N);

        for (i, (one, zero)) in ones_bools.iter().zip(&zeros_bools).enumerate() {
            // Each index should be either in ones or zeros, but not both
            assert_ne!(one, zero, "Index {} should be in exactly one set", i);
        }
    }
}
559}