vortex_buffer/bit/
view.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::borrow::Cow;
5use std::fmt::{Debug, Formatter};
6use std::marker::PhantomData;
7
8use vortex_error::VortexResult;
9
10use crate::{BitBuffer, BitBufferMut};
11
/// A borrowed fixed-size mask of length `N` bits.
///
/// Since const generic expressions are not yet stable, we instead define the type over the
/// number of bytes `NB`, and compute `N` as `NB * 8`.
///
/// This struct is designed to provide a view over a Vortex [`BitBuffer`], therefore the
/// bit-ordering is LSB0 (least-significant-bit first).
///
/// Note that [`BitView`] does not support an offset. Therefore, bits are assumed to start at
/// index `0` and end at index `N - 1`.
pub struct BitView<'a, const NB: usize> {
    // The raw bytes of the view: borrowed from the caller where possible, owned when the
    // view had to materialize its own array.
    bits: Cow<'a, [u8; NB]>,
    // TODO(ngates): we may want to expose this for optimizations.
    // If set to Selection::Prefix, then all true bits are at the start of the array.
    // selection: Selection,
    // Number of set bits in `bits`; the iteration methods use it to short-circuit the
    // all-false and all-true cases.
    true_count: usize,
}
29
30impl<const NB: usize> Debug for BitView<'_, NB> {
31    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct(&format!("BitView[{}]", NB * 8))
33            .field("true_count", &self.true_count)
34            .field("bits", &self.as_raw())
35            .finish()
36    }
37}
38
39impl<const NB: usize> BitView<'static, NB> {
40    const ALL_TRUE: [u8; NB] = [u8::MAX; NB];
41    const ALL_FALSE: [u8; NB] = [0; NB];
42
43    /// Creates a [`BitView`] with all bits set to `true`.
44    pub const fn all_true() -> Self {
45        unsafe { BitView::new_unchecked(&Self::ALL_TRUE, NB * 8) }
46    }
47
48    /// Creates a [`BitView`] with all bits set to `false`.
49    pub const fn all_false() -> Self {
50        unsafe { BitView::new_unchecked(&Self::ALL_FALSE, 0) }
51    }
52}
53
54impl<'a, const NB: usize> BitView<'a, NB> {
55    const N: usize = NB * 8;
56    const N_WORDS: usize = NB * 8 / (usize::BITS as usize);
57
58    const _ASSERT_MULTIPLE_OF_8: () = assert!(
59        NB % 8 == 0,
60        "NB must be a multiple of 8 for N to be a multiple of 64"
61    );
62
63    /// Creates a [`BitView`] from raw bits, computing the true count.
64    pub fn new(bits: &'a [u8; NB]) -> Self {
65        let ptr = bits.as_ptr().cast::<usize>();
66        let true_count = (0..Self::N_WORDS)
67            .map(|idx| unsafe { ptr.add(idx).read_unaligned().count_ones() as usize })
68            .sum();
69        BitView {
70            bits: Cow::Borrowed(bits),
71            true_count,
72        }
73    }
74
75    /// Creates a [`BitView`] from owned raw bits.
76    pub fn new_owned(bits: [u8; NB]) -> Self {
77        let ptr = bits.as_ptr().cast::<usize>();
78        let true_count = (0..Self::N_WORDS)
79            .map(|idx| unsafe { ptr.add(idx).read_unaligned().count_ones() as usize })
80            .sum();
81        BitView {
82            bits: Cow::Owned(bits),
83            true_count,
84        }
85    }
86
87    /// Creates a [`BitView`] from raw bits and a known true count.
88    ///
89    /// # Safety
90    ///
91    /// The caller must ensure that `true_count` is correct for the provided `bits`.
92    pub(crate) const unsafe fn new_unchecked(bits: &'a [u8; NB], true_count: usize) -> Self {
93        BitView {
94            bits: Cow::Borrowed(bits),
95            true_count,
96        }
97    }
98
99    /// Creates a [`BitView`] from a byte slice.
100    ///
101    /// # Panics
102    ///
103    /// If the length of the slice is not equal to `NB`.
104    pub fn from_slice(bits: &'a [u8]) -> Self {
105        assert_eq!(bits.len(), NB);
106        let bits_array = unsafe { &*(bits.as_ptr() as *const [u8; NB]) };
107        BitView::new(bits_array)
108    }
109
110    /// Creates a [`BitView`] from a mutable byte array, populating it with the requested prefix
111    /// of `true` bits.
112    pub fn with_prefix(n_true: usize) -> Self {
113        assert!(n_true <= Self::N);
114
115        // We're going to own our own array of bits
116        let mut bits = [0u8; NB];
117
118        // All-true words first
119        let n_full_words = n_true / (usize::BITS as usize);
120        let remaining_bits = n_true % (usize::BITS as usize);
121
122        let ptr = bits.as_mut_ptr().cast::<usize>();
123
124        // Fill the all-true words
125        for word_idx in 0..n_full_words {
126            unsafe { ptr.add(word_idx).write_unaligned(usize::MAX) };
127        }
128
129        // Fill the remaining bits in the next word
130        if remaining_bits > 0 {
131            let mask = (1usize << remaining_bits) - 1;
132            unsafe { ptr.add(n_full_words).write_unaligned(mask) };
133        }
134
135        Self {
136            bits: Cow::Owned(bits),
137            true_count: n_true,
138        }
139    }
140
141    /// Returns the number of `true` bits in the view.
142    pub fn true_count(&self) -> usize {
143        self.true_count
144    }
145
146    /// Iterate the [`BitView`] in fixed-size words.
147    ///
148    /// The words are loaded using unaligned loads to ensure correct bit ordering.
149    /// For example, bit 0 is located in `word & 1 << 0`, bit 63 is located in `word & 1 << 63`,
150    /// assuming the word size is 64 bits.
151    fn iter_words(&self) -> impl Iterator<Item = usize> + '_ {
152        let ptr = self.bits.as_ptr().cast::<usize>();
153        // We use constant N_WORDS to trigger loop unrolling.
154        (0..Self::N_WORDS).map(move |idx| unsafe { ptr.add(idx).read_unaligned() })
155    }
156
157    /// Runs the provided function `f` for each index of a `true` bit in the view.
158    pub fn iter_ones<F>(&self, mut f: F)
159    where
160        F: FnMut(usize),
161    {
162        match self.true_count {
163            0 => {}
164            n if n == Self::N => (0..Self::N).for_each(&mut f),
165            _ => {
166                let mut bit_idx = 0;
167                for mut raw in self.iter_words() {
168                    while raw != 0 {
169                        let bit_pos = raw.trailing_zeros();
170                        f(bit_idx + bit_pos as usize);
171                        raw &= raw - 1; // Clear the bit at `bit_pos`
172                    }
173                    bit_idx += usize::BITS as usize;
174                }
175            }
176        }
177    }
178
179    /// Runs the provided function `f` for each index of a `true` bit in the view.
180    pub fn try_iter_ones<F>(&self, mut f: F) -> VortexResult<()>
181    where
182        F: FnMut(usize) -> VortexResult<()>,
183    {
184        match self.true_count {
185            0 => Ok(()),
186            n if n == Self::N => {
187                for i in 0..Self::N {
188                    f(i)?;
189                }
190                Ok(())
191            }
192            _ => {
193                let mut bit_idx = 0;
194                for mut raw in self.iter_words() {
195                    while raw != 0 {
196                        let bit_pos = raw.trailing_zeros();
197                        f(bit_idx + bit_pos as usize)?;
198                        raw &= raw - 1; // Clear the bit at `bit_pos`
199                    }
200                    bit_idx += usize::BITS as usize;
201                }
202                Ok(())
203            }
204        }
205    }
206
207    /// Runs the provided function `f` for each index of a `true` bit in the view.
208    pub fn iter_zeros<F>(&self, mut f: F)
209    where
210        F: FnMut(usize),
211    {
212        match self.true_count {
213            0 => (0..Self::N).for_each(&mut f),
214            n if n == Self::N => {}
215            _ => {
216                let mut bit_idx = 0;
217                for mut raw in self.iter_words() {
218                    while raw != usize::MAX {
219                        let bit_pos = raw.trailing_ones();
220                        f(bit_idx + bit_pos as usize);
221                        raw |= 1usize << bit_pos; // Set the zero bit to 1
222                    }
223                    bit_idx += usize::BITS as usize;
224                }
225            }
226        }
227    }
228
229    /// Runs the provided function `f` for each range of `true` bits in the view.
230    ///
231    /// The function `f` receives a [`BitSlice`] containing the inclusive `start` bit as well as
232    /// the length.
233    ///
234    /// FIXME(ngates): this is still broken.
235    pub fn iter_slices<F>(&self, mut f: F)
236    where
237        F: FnMut(BitSlice),
238    {
239        if self.true_count == 0 {
240            return;
241        }
242
243        let mut abs_bit_offset: usize = 0; // Absolute bit index of the *current* word being processed
244        let mut slice_start_bit: usize = 0; // Absolute start index of the run of 1s being tracked
245        let mut slice_length: usize = 0; // Accumulated length of the run of 1s
246
247        for mut word in self.iter_words() {
248            match word {
249                0 => {
250                    // If a slice was being tracked, the run ends at the start of this word.
251                    if slice_length > 0 {
252                        f(BitSlice {
253                            start: slice_start_bit,
254                            len: slice_length,
255                        });
256                        slice_length = 0;
257                    }
258                }
259                usize::MAX => {
260                    // If a slice was not already open, it starts at the beginning of this word.
261                    if slice_length == 0 {
262                        slice_start_bit = abs_bit_offset;
263                    }
264                    // Extend the length by a full word (64 bits).
265                    slice_length += usize::BITS as usize;
266                }
267                _ => {
268                    while word != 0 {
269                        // Find the first set bit (start of a run of 1s)
270                        let zeros = word.trailing_zeros() as usize;
271
272                        // If a run was open, and we hit a zero gap, report the finished slice
273                        if slice_length > 0 && zeros > 0 {
274                            f(BitSlice {
275                                start: slice_start_bit,
276                                len: slice_length,
277                            });
278                            slice_length = 0; // Reset state for a new slice
279                        }
280
281                        // Advance past the zeros
282                        word >>= zeros;
283
284                        if word == 0 {
285                            break;
286                        }
287
288                        // Find the contiguous ones (the length of the current run segment)
289                        let ones = word.trailing_ones() as usize;
290
291                        // If slice_length is 0, we found the *absolute* start of a new slice.
292                        if slice_length == 0 {
293                            // Calculate the bit index within the *entire* mask where this run starts
294                            let current_word_idx = abs_bit_offset + zeros;
295                            slice_start_bit = current_word_idx;
296                        }
297
298                        // Accumulate the length of the slice
299                        slice_length += ones;
300
301                        // Advance past the ones
302                        word >>= ones;
303                    }
304                }
305            }
306
307            abs_bit_offset += usize::BITS as usize;
308        }
309
310        if slice_length > 0 {
311            f(BitSlice {
312                start: slice_start_bit,
313                len: slice_length,
314            });
315        }
316    }
317
318    /// Returns the raw bits of the view.
319    pub fn as_raw(&self) -> &[u8; NB] {
320        self.bits.as_ref()
321    }
322}
323
/// A slice of bits within a [`BitBuffer`].
///
/// We use this struct, rather than a bare tuple, to avoid the common mistake of assuming the
/// slices represent `(start, end)` ranges: the second component is a length, not an end index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BitSlice {
    /// The starting bit index of the slice.
    pub start: usize,
    /// The length of the slice in bits.
    pub len: usize,
}
333
334impl BitBuffer {
335    /// Iterate the buffer as [`BitView`]s of size `NB` where the number of bits in each view
336    /// is `NB * 8`.
337    ///
338    /// The final chunk will be filled with unset padding bits if the bit buffer's length is not
339    /// a multiple of `N`.
340    ///
341    /// The number of bits `N` must be a multiple of 64.
342    ///
343    /// # Panics
344    ///
345    /// If the bit offset is not zero
346    pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
347        assert_eq!(
348            self.offset(),
349            0,
350            "BitView iteration requires zero bit offset"
351        );
352        BitViewIterator::new(self.inner().as_ref())
353    }
354}
355
356impl BitBufferMut {
357    /// Iterate the buffer as [`BitView`]s of size `NB` where the number of bits in each view
358    /// is `NB * 8`.
359    ///
360    /// The final chunk will be filled with unset padding bits if the bit buffer's length is not
361    /// a multiple of `N`.
362    ///
363    /// The number of bits `N` must be a multiple of 64.
364    ///
365    /// # Panics
366    ///
367    /// If the bit offset is not zero
368    pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
369        assert_eq!(
370            self.offset(),
371            0,
372            "BitView iteration requires zero bit offset"
373        );
374        BitViewIterator::new(self.inner().as_ref())
375    }
376}
377
/// Iterator over fixed-size [`BitView`]s within a byte slice.
///
/// Each item covers `NB` consecutive bytes; a trailing partial chunk is zero-padded to `NB`
/// bytes and yielded as an owned view.
pub(super) struct BitViewIterator<'a, const NB: usize> {
    /// The bytes being split into views.
    bits: &'a [u8],
    /// The index of the view to be returned next
    view_idx: usize,
    /// The total number of views
    n_views: usize,
    /// Phantom to capture `NB`
    _phantom: PhantomData<[u8; NB]>,
}
388
389impl<'a, const NB: usize> BitViewIterator<'a, NB> {
390    /// Create a new [`BitViewIterator`].
391    pub fn new(bits: &'a [u8]) -> Self {
392        let n_views = bits.len().div_ceil(NB);
393        BitViewIterator {
394            bits,
395            view_idx: 0,
396            n_views,
397            _phantom: PhantomData,
398        }
399    }
400}
401
402impl<'a, const NB: usize> Iterator for BitViewIterator<'a, NB> {
403    type Item = BitView<'a, NB>;
404
405    fn next(&mut self) -> Option<Self::Item> {
406        if self.view_idx == self.n_views {
407            return None;
408        }
409
410        let start_byte = self.view_idx * NB;
411        let end_byte = start_byte + NB;
412
413        let bits = if end_byte <= self.bits.len() {
414            // Full view from the original bits
415            BitView::from_slice(&self.bits[start_byte..end_byte])
416        } else {
417            // Partial view, copy to scratch
418            let remaining_bytes = self.bits.len() - start_byte;
419            let mut remaining = [0u8; NB];
420            remaining[..remaining_bytes].copy_from_slice(&self.bits[start_byte..]);
421            BitView::new_owned(remaining)
422        };
423
424        self.view_idx += 1;
425        Some(bits)
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    const NB: usize = 128; // Number of bytes
    const N: usize = NB * 8; // Number of bits

    /// Collects every index reported by [`BitView::iter_ones`].
    fn ones_of(view: &BitView<'_, NB>) -> Vec<usize> {
        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        ones
    }

    /// Collects every index reported by [`BitView::iter_zeros`].
    fn zeros_of(view: &BitView<'_, NB>) -> Vec<usize> {
        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));
        zeros
    }

    /// Collects every run reported by [`BitView::iter_slices`] as `(start, len)` pairs.
    fn slices_of(view: &BitView<'_, NB>) -> Vec<(usize, usize)> {
        let mut slices = Vec::new();
        view.iter_slices(|slice| slices.push((slice.start, slice.len)));
        slices
    }

    /// Builds an LSB0 byte array in which exactly the given bit indices are set.
    fn bytes_with_bits(indices: impl IntoIterator<Item = usize>) -> [u8; NB] {
        let mut bits = [0u8; NB];
        for idx in indices {
            bits[idx / 8] |= 1u8 << (idx % 8);
        }
        bits
    }

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0u8; NB];
        let view = BitView::<NB>::new(&bits);

        assert!(ones_of(&view).is_empty());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::<NB>::all_true();

        let ones = ones_of(&view);
        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0u8; NB];
        let view = BitView::<NB>::new(&bits);

        let zeros = zeros_of(&view);
        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::<NB>::all_true();

        assert!(zeros_of(&view).is_empty());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0u8; NB];
        bits[0] = 1; // Set bit 0 (LSB)
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [u8::MAX; NB];
        bits[0] = u8::MAX ^ 1; // Clear bit 0 (LSB)
        let view = BitView::new(&bits);

        assert_eq!(zeros_of(&view), vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0u8; NB];
        bits[0] = 0b1010101; // Set bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [u8::MAX; NB];
        bits[0] = !0b1010101; // Clear bits 0, 2, 4, 6
        let view = BitView::new(&bits);

        assert_eq!(zeros_of(&view), vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0u8; NB];
        bits[0] = 0b11111111; // Set bits 0-7 (LSB ordering)
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::<NB>::all_false();

        assert!(ones_of(&view).is_empty());
        assert_eq!(zeros_of(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        // An all-true view must report exactly the indices an all-true mask would.
        let view = BitView::<NB>::all_true();

        assert_eq!(ones_of(&view), (0..N).collect::<Vec<usize>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        // An all-false view must report no ones and every index as a zero.
        let view = BitView::<NB>::all_false();

        assert!(ones_of(&view).is_empty());
        assert_eq!(zeros_of(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        // Indices chosen to hit word boundaries (63/64) and the very last bit.
        let indices = vec![0, 10, 20, 63, 64, 100, 500, 1023];

        let bits = bytes_with_bits(indices.iter().copied());
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), indices);
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        // Half-open (start, end) ranges of set bits.
        let slices = vec![(0, 10), (100, 110), (500, 510)];

        let bits = bytes_with_bits(slices.iter().flat_map(|&(start, end)| start..end));
        let view = BitView::new(&bits);

        let expected_indices: Vec<usize> = slices
            .iter()
            .flat_map(|&(start, end)| start..end)
            .collect();

        assert_eq!(ones_of(&view), expected_indices);
        assert_eq!(view.true_count(), expected_indices.len());

        // The same ranges must come back from `iter_slices`, expressed as (start, len).
        let expected_slices: Vec<(usize, usize)> = slices
            .iter()
            .map(|&(start, end)| (start, end - start))
            .collect();
        assert_eq!(slices_of(&view), expected_slices);
    }

    #[test]
    fn test_with_prefix() {
        let empty = BitView::<NB>::with_prefix(0);
        assert_eq!(empty.true_count(), 0);
        assert!(ones_of(&empty).is_empty());
        assert!(slices_of(&empty).is_empty());

        // May as well test all the possible prefix lengths, including the full view!
        for i in 1..=N {
            let view = BitView::<NB>::with_prefix(i);

            assert_eq!(view.true_count(), i);
            assert_eq!(ones_of(&view), (0..i).collect::<Vec<_>>());

            // There should be exactly one slice covering bits `0..i`.
            assert_eq!(slices_of(&view), vec![(0, i)]);
        }
    }

    #[test]
    fn test_bit_view_iterator_pads_final_view() {
        // One full view followed by a single trailing byte that must be zero-padded.
        let bytes = [u8::MAX; NB + 1];
        let views: Vec<_> = BitViewIterator::<NB>::new(&bytes).collect();

        assert_eq!(views.len(), 2);
        assert_eq!(views[0].true_count(), N);
        assert_eq!(views[1].true_count(), 8);
        assert_eq!(ones_of(&views[1]), (0..8).collect::<Vec<_>>());
    }
}
658        }
659    }
660}