vortex_buffer/bit/
view.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::borrow::Cow;
5use std::fmt::Debug;
6use std::fmt::Formatter;
7use std::marker::PhantomData;
8
9use vortex_error::VortexResult;
10
11use crate::BitBuffer;
12use crate::BitBufferMut;
13
/// A fixed-size bit mask of `N` bits whose storage is either borrowed or owned.
///
/// Const generic expressions are not yet stable, so the type is parameterised by its size in
/// bytes, `NB`, and the bit count is derived as `N = NB * 8`.
///
/// The struct is designed as a view over a Vortex [`BitBuffer`], so the bit-ordering is LSB0
/// (least-significant-bit first).
///
/// A [`BitView`] carries no bit offset: its bits always start at index `0` and end at index
/// `N - 1`.
pub struct BitView<'a, const NB: usize> {
    /// The raw bytes of the mask.
    bits: Cow<'a, [u8; NB]>,
    // TODO(ngates): we may want to expose this for optimizations.
    // If set to Selection::Prefix, then all true bits are at the start of the array.
    // selection: Selection,
    /// The number of set bits in `bits`.
    true_count: usize,
}
31
32impl<const NB: usize> Debug for BitView<'_, NB> {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct(&format!("BitView[{}]", NB * 8))
35            .field("true_count", &self.true_count)
36            .field("bits", &self.as_raw())
37            .finish()
38    }
39}
40
41impl<const NB: usize> BitView<'static, NB> {
42    const ALL_TRUE: [u8; NB] = [u8::MAX; NB];
43    const ALL_FALSE: [u8; NB] = [0; NB];
44
45    /// Creates a [`BitView`] with all bits set to `true`.
46    pub const fn all_true() -> Self {
47        unsafe { BitView::new_unchecked(&Self::ALL_TRUE, NB * 8) }
48    }
49
50    /// Creates a [`BitView`] with all bits set to `false`.
51    pub const fn all_false() -> Self {
52        unsafe { BitView::new_unchecked(&Self::ALL_FALSE, 0) }
53    }
54}
55
56impl<'a, const NB: usize> BitView<'a, NB> {
57    /// The number of bits in the view.
58    pub const N: usize = NB * 8;
59    /// The number of machine words in the view.
60    pub const N_WORDS: usize = NB * 8 / (usize::BITS as usize);
61
62    const _ASSERT_MULTIPLE_OF_8: () = assert!(
63        NB % 8 == 0,
64        "NB must be a multiple of 8 for N to be a multiple of 64"
65    );
66
67    /// Creates a [`BitView`] from raw bits, computing the true count.
68    pub fn new(bits: &'a [u8; NB]) -> Self {
69        let ptr = bits.as_ptr().cast::<usize>();
70        let true_count = (0..Self::N_WORDS)
71            .map(|idx| unsafe { ptr.add(idx).read_unaligned().count_ones() as usize })
72            .sum();
73        BitView {
74            bits: Cow::Borrowed(bits),
75            true_count,
76        }
77    }
78
79    /// Creates a [`BitView`] from owned raw bits.
80    pub fn new_owned(bits: [u8; NB]) -> Self {
81        let ptr = bits.as_ptr().cast::<usize>();
82        let true_count = (0..Self::N_WORDS)
83            .map(|idx| unsafe { ptr.add(idx).read_unaligned().count_ones() as usize })
84            .sum();
85        BitView {
86            bits: Cow::Owned(bits),
87            true_count,
88        }
89    }
90
91    /// Creates a [`BitView`] from raw bits and a known true count.
92    ///
93    /// # Safety
94    ///
95    /// The caller must ensure that `true_count` is correct for the provided `bits`.
96    pub(crate) const unsafe fn new_unchecked(bits: &'a [u8; NB], true_count: usize) -> Self {
97        BitView {
98            bits: Cow::Borrowed(bits),
99            true_count,
100        }
101    }
102
103    /// Creates a [`BitView`] from a byte slice.
104    ///
105    /// # Panics
106    ///
107    /// If the length of the slice is not equal to `NB`.
108    pub fn from_slice(bits: &'a [u8]) -> Self {
109        assert_eq!(bits.len(), NB);
110        let bits_array = unsafe { &*(bits.as_ptr() as *const [u8; NB]) };
111        BitView::new(bits_array)
112    }
113
114    /// Creates a [`BitView`] from a mutable byte array, populating it with the requested prefix
115    /// of `true` bits.
116    pub fn with_prefix(n_true: usize) -> Self {
117        assert!(n_true <= Self::N);
118
119        // We're going to own our own array of bits
120        let mut bits = [0u8; NB];
121
122        // All-true words first
123        let n_full_words = n_true / (usize::BITS as usize);
124        let remaining_bits = n_true % (usize::BITS as usize);
125
126        let ptr = bits.as_mut_ptr().cast::<usize>();
127
128        // Fill the all-true words
129        for word_idx in 0..n_full_words {
130            unsafe { ptr.add(word_idx).write_unaligned(usize::MAX) };
131        }
132
133        // Fill the remaining bits in the next word
134        if remaining_bits > 0 {
135            let mask = (1usize << remaining_bits) - 1;
136            unsafe { ptr.add(n_full_words).write_unaligned(mask) };
137        }
138
139        Self {
140            bits: Cow::Owned(bits),
141            true_count: n_true,
142        }
143    }
144
145    /// Returns the number of `true` bits in the view.
146    pub fn true_count(&self) -> usize {
147        self.true_count
148    }
149
150    /// Iterate the [`BitView`] in fixed-size words.
151    ///
152    /// The words are loaded using unaligned loads to ensure correct bit ordering.
153    /// For example, bit 0 is located in `word & 1 << 0`, bit 63 is located in `word & 1 << 63`,
154    /// assuming the word size is 64 bits.
155    pub fn iter_words(&self) -> impl Iterator<Item = usize> + '_ {
156        let ptr = self.bits.as_ptr().cast::<usize>();
157        // We use constant N_WORDS to trigger loop unrolling.
158        (0..Self::N_WORDS).map(move |idx| unsafe { ptr.add(idx).read_unaligned() })
159    }
160
161    /// Runs the provided function `f` for each index of a `true` bit in the view.
162    pub fn iter_ones<F>(&self, mut f: F)
163    where
164        F: FnMut(usize),
165    {
166        match self.true_count {
167            0 => {}
168            n if n == Self::N => (0..Self::N).for_each(&mut f),
169            _ => {
170                let mut bit_idx = 0;
171                for mut raw in self.iter_words() {
172                    while raw != 0 {
173                        let bit_pos = raw.trailing_zeros();
174                        f(bit_idx + bit_pos as usize);
175                        raw &= raw - 1; // Clear the bit at `bit_pos`
176                    }
177                    bit_idx += usize::BITS as usize;
178                }
179            }
180        }
181    }
182
183    /// Runs the provided function `f` for each index of a `true` bit in the view.
184    pub fn try_iter_ones<F>(&self, mut f: F) -> VortexResult<()>
185    where
186        F: FnMut(usize) -> VortexResult<()>,
187    {
188        match self.true_count {
189            0 => Ok(()),
190            n if n == Self::N => {
191                for i in 0..Self::N {
192                    f(i)?;
193                }
194                Ok(())
195            }
196            _ => {
197                let mut bit_idx = 0;
198                for mut raw in self.iter_words() {
199                    while raw != 0 {
200                        let bit_pos = raw.trailing_zeros();
201                        f(bit_idx + bit_pos as usize)?;
202                        raw &= raw - 1; // Clear the bit at `bit_pos`
203                    }
204                    bit_idx += usize::BITS as usize;
205                }
206                Ok(())
207            }
208        }
209    }
210
211    /// Runs the provided function `f` for each index of a `true` bit in the view.
212    pub fn iter_zeros<F>(&self, mut f: F)
213    where
214        F: FnMut(usize),
215    {
216        match self.true_count {
217            0 => (0..Self::N).for_each(&mut f),
218            n if n == Self::N => {}
219            _ => {
220                let mut bit_idx = 0;
221                for mut raw in self.iter_words() {
222                    while raw != usize::MAX {
223                        let bit_pos = raw.trailing_ones();
224                        f(bit_idx + bit_pos as usize);
225                        raw |= 1usize << bit_pos; // Set the zero bit to 1
226                    }
227                    bit_idx += usize::BITS as usize;
228                }
229            }
230        }
231    }
232
233    /// Runs the provided function `f` for each range of `true` bits in the view.
234    ///
235    /// The function `f` receives a [`BitSlice`] containing the inclusive `start` bit as well as
236    /// the length.
237    ///
238    /// FIXME(ngates): this is still broken.
239    pub fn iter_slices<F>(&self, mut f: F)
240    where
241        F: FnMut(BitSlice),
242    {
243        if self.true_count == 0 {
244            return;
245        }
246
247        let mut abs_bit_offset: usize = 0; // Absolute bit index of the *current* word being processed
248        let mut slice_start_bit: usize = 0; // Absolute start index of the run of 1s being tracked
249        let mut slice_length: usize = 0; // Accumulated length of the run of 1s
250
251        for mut word in self.iter_words() {
252            match word {
253                0 => {
254                    // If a slice was being tracked, the run ends at the start of this word.
255                    if slice_length > 0 {
256                        f(BitSlice {
257                            start: slice_start_bit,
258                            len: slice_length,
259                        });
260                        slice_length = 0;
261                    }
262                }
263                usize::MAX => {
264                    // If a slice was not already open, it starts at the beginning of this word.
265                    if slice_length == 0 {
266                        slice_start_bit = abs_bit_offset;
267                    }
268                    // Extend the length by a full word (64 bits).
269                    slice_length += usize::BITS as usize;
270                }
271                _ => {
272                    while word != 0 {
273                        // Find the first set bit (start of a run of 1s)
274                        let zeros = word.trailing_zeros() as usize;
275
276                        // If a run was open, and we hit a zero gap, report the finished slice
277                        if slice_length > 0 && zeros > 0 {
278                            f(BitSlice {
279                                start: slice_start_bit,
280                                len: slice_length,
281                            });
282                            slice_length = 0; // Reset state for a new slice
283                        }
284
285                        // Advance past the zeros
286                        word >>= zeros;
287
288                        if word == 0 {
289                            break;
290                        }
291
292                        // Find the contiguous ones (the length of the current run segment)
293                        let ones = word.trailing_ones() as usize;
294
295                        // If slice_length is 0, we found the *absolute* start of a new slice.
296                        if slice_length == 0 {
297                            // Calculate the bit index within the *entire* mask where this run starts
298                            let current_word_idx = abs_bit_offset + zeros;
299                            slice_start_bit = current_word_idx;
300                        }
301
302                        // Accumulate the length of the slice
303                        slice_length += ones;
304
305                        // Advance past the ones
306                        word >>= ones;
307                    }
308                }
309            }
310
311            abs_bit_offset += usize::BITS as usize;
312        }
313
314        if slice_length > 0 {
315            f(BitSlice {
316                start: slice_start_bit,
317                len: slice_length,
318            });
319        }
320    }
321
322    /// Returns the raw bits of the view.
323    pub fn as_raw(&self) -> &[u8; NB] {
324        self.bits.as_ref()
325    }
326}
327
/// A contiguous run of bits within a [`BitBuffer`], described by its start and its length.
///
/// We use this struct, rather than a bare tuple, to avoid a common mistake of assuming the
/// slices represent `(start, end)` ranges: the second value is a length, not an end index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitSlice {
    /// The starting bit index of the slice.
    pub start: usize,
    /// The length of the slice in bits.
    pub len: usize,
}
337
338impl BitBuffer {
339    /// Iterate the buffer as [`BitView`]s of size `NB` where the number of bits in each view
340    /// is `NB * 8`.
341    ///
342    /// The final chunk will be filled with unset padding bits if the bit buffer's length is not
343    /// a multiple of `N`.
344    ///
345    /// The number of bits `N` must be a multiple of 64.
346    ///
347    /// # Panics
348    ///
349    /// If the bit offset is not zero
350    pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
351        assert_eq!(
352            self.offset(),
353            0,
354            "BitView iteration requires zero bit offset"
355        );
356        BitViewIterator::new(self.inner().as_ref(), self.len())
357    }
358}
359
360impl BitBufferMut {
361    /// Iterate the buffer as [`BitView`]s of size `NB` where the number of bits in each view
362    /// is `NB * 8`.
363    ///
364    /// The final chunk will be filled with unset padding bits if the bit buffer's length is not
365    /// a multiple of `N`.
366    ///
367    /// The number of bits `N` must be a multiple of 64.
368    ///
369    /// # Panics
370    ///
371    /// If the bit offset is not zero
372    pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
373        assert_eq!(
374            self.offset(),
375            0,
376            "BitView iteration requires zero bit offset"
377        );
378        BitViewIterator::new(self.inner().as_ref(), self.len())
379    }
380}
381
382/// Iterator over fixed-size [`BitView`]s within a byte slice.
383pub(super) struct BitViewIterator<'a, const NB: usize> {
384    bits: &'a [u8],
385    // The index of the view to be returned next
386    view_idx: usize,
387    // The total number of views
388    n_views: usize,
389    // Phantom to capture `NB`
390    _phantom: PhantomData<[u8; NB]>,
391}
392
393impl<'a, const NB: usize> BitViewIterator<'a, NB> {
394    /// Create a new [`BitViewIterator`].
395    fn new(bits: &'a [u8], len: usize) -> Self {
396        debug_assert_eq!(len.div_ceil(8), bits.len());
397        let n_views = bits.len().div_ceil(NB);
398        BitViewIterator {
399            bits,
400            view_idx: 0,
401            n_views,
402            _phantom: PhantomData,
403        }
404    }
405}
406
407impl<'a, const NB: usize> Iterator for BitViewIterator<'a, NB> {
408    type Item = BitView<'a, NB>;
409
410    fn next(&mut self) -> Option<Self::Item> {
411        if self.view_idx == self.n_views {
412            return None;
413        }
414
415        let start_byte = self.view_idx * NB;
416        let end_byte = start_byte + NB;
417
418        let bits = if end_byte <= self.bits.len() {
419            // Full view from the original bits
420            BitView::from_slice(&self.bits[start_byte..end_byte])
421        } else {
422            // Partial view, copy to scratch
423            let remaining_bytes = self.bits.len() - start_byte;
424            let mut remaining = [0u8; NB];
425            remaining[..remaining_bytes].copy_from_slice(&self.bits[start_byte..]);
426            BitView::new_owned(remaining)
427        };
428
429        self.view_idx += 1;
430        Some(bits)
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    const NB: usize = 128; // Number of bytes
    const N: usize = NB * 8; // Number of bits

    /// Collects every index reported by `iter_ones`.
    fn collect_ones(view: &BitView<'_, NB>) -> Vec<usize> {
        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        ones
    }

    /// Collects every index reported by `iter_zeros`.
    fn collect_zeros(view: &BitView<'_, NB>) -> Vec<usize> {
        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));
        zeros
    }

    /// Builds raw LSB0 bytes in which exactly the given bit indices are set.
    fn bits_from_indices(indices: impl IntoIterator<Item = usize>) -> [u8; NB] {
        let mut bits = [0u8; NB];
        for idx in indices {
            let byte_idx = idx / 8;
            let bit_idx = idx % 8;
            bits[byte_idx] |= 1u8 << bit_idx;
        }
        bits
    }

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0; NB];
        let view = BitView::<NB>::new(&bits);

        assert_eq!(collect_ones(&view), Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::<NB>::all_true();

        let ones = collect_ones(&view);
        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0; NB];
        let view = BitView::<NB>::new(&bits);

        let zeros = collect_zeros(&view);
        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::<NB>::all_true();

        assert_eq!(collect_zeros(&view), Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0; NB];
        bits[0] = 1; // Set bit 0 (LSB)
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(&view), vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [u8::MAX; NB];
        bits[0] = u8::MAX ^ 1; // Clear bit 0 (LSB)
        let view = BitView::new(&bits);

        assert_eq!(collect_zeros(&view), vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0; NB];
        bits[0] = 0b1010101; // Set bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(&view), vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [u8::MAX; NB];
        bits[0] = !0b1010101; // Clear bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        assert_eq!(collect_zeros(&view), vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0; NB];
        bits[0] = 0b11111111; // Set bits 0-7 (LSB ordering)
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(&view), vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::<NB>::all_false();

        assert_eq!(collect_ones(&view), Vec::<usize>::new());
        assert_eq!(collect_zeros(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        let view = BitView::<NB>::all_true();

        // An all-true mask selects every index.
        let expected_indices: Vec<usize> = (0..N).collect();

        assert_eq!(collect_ones(&view), expected_indices);
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        let view = BitView::<NB>::all_false();

        assert_eq!(collect_ones(&view), Vec::<usize>::new());
        assert_eq!(collect_zeros(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        // Indices on both sides of byte and word boundaries, including the very last bit.
        let indices: Vec<usize> = vec![0, 10, 20, 63, 64, 100, 500, 1023];

        let bits = bits_from_indices(indices.iter().copied());
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(&view), indices);
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        // Half-open (start, end) ranges of set bits.
        let slices: Vec<(usize, usize)> = vec![(0, 10), (100, 110), (500, 510)];

        let expected_indices: Vec<usize> = slices
            .iter()
            .flat_map(|&(start, end)| start..end)
            .collect();

        let bits = bits_from_indices(expected_indices.iter().copied());
        let view = BitView::new(&bits);

        assert_eq!(collect_ones(&view), expected_indices);
        assert_eq!(view.true_count(), expected_indices.len());
    }

    #[test]
    fn test_with_prefix() {
        let empty = BitView::<NB>::with_prefix(0);
        assert_eq!(empty.true_count(), 0);
        assert_eq!(collect_ones(&empty), Vec::<usize>::new());
        let mut n_empty_slices = 0;
        empty.iter_slices(|_| n_empty_slices += 1);
        assert_eq!(n_empty_slices, 0);

        // May as well test all the possible prefix lengths, including the full view!
        for i in 1..=N {
            let view = BitView::<NB>::with_prefix(i);

            assert_eq!(view.true_count(), i);
            assert_eq!(collect_ones(&view), (0..i).collect::<Vec<_>>());

            // There should be exactly one slice, covering bits `0..i`.
            let mut slices = vec![];
            view.iter_slices(|slice| slices.push(slice));

            assert_eq!(slices.len(), 1);
            assert_eq!(slices[0].start, 0);
            assert_eq!(slices[0].len, i);
        }
    }
}