vortex_buffer/bit/buf.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::ops::{BitAnd, BitOr, BitXor, Not, RangeBounds};

use crate::bit::ops::{bitwise_binary_op, bitwise_unary_op};
use crate::bit::{
    BitChunks, BitIndexIterator, BitIterator, BitSliceIterator, UnalignedBitChunk,
    get_bit_unchecked,
};
use crate::{Alignment, BitBufferMut, Buffer, BufferMut, ByteBuffer, buffer};

/// An immutable bitset stored as a packed byte buffer.
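///
/// # Example
///
/// A minimal usage sketch (assuming `BitBuffer` and the `buffer!` macro are exported from the
/// crate root; marked `ignore` since the exact public paths are not verified here):
///
/// ```ignore
/// use vortex_buffer::{BitBuffer, buffer};
///
/// // 16 bits backed by two bytes, all set.
/// let bits = BitBuffer::new(buffer![0xFF; 2], 16);
/// assert_eq!(bits.true_count(), 16);
/// ```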
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    buffer: ByteBuffer,
    /// Offset (in bits) of the start of the view into the backing buffer.
    ///
    /// This may be 8 or more; [`BitBuffer::shrink_offset`] folds any whole bytes of the offset
    /// back into the buffer slice, leaving an offset of less than 8.
    offset: usize,
    len: usize,
}

impl PartialEq for BitBuffer {
    fn eq(&self, other: &Self) -> bool {
        if self.len != other.len {
            return false;
        }

        self.chunks()
            .iter_padded()
            .zip(other.chunks().iter_padded())
            .all(|(a, b)| a == b)
    }
}

impl BitBuffer {
    /// Create a new `BitBuffer` backed by a [`ByteBuffer`] with `len` bits in view.
    ///
    /// Panics if the buffer is not large enough to hold `len` bits.
    pub fn new(buffer: ByteBuffer, len: usize) -> Self {
        assert!(
            buffer.len() * 8 >= len,
            "provided ByteBuffer not large enough to back BitBuffer with len {len}"
        );

        // BitBuffers make no assumptions on byte alignment, so we strip any alignment.
        let buffer = buffer.aligned(Alignment::none());

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Create a new `BitBuffer` backed by a [`ByteBuffer`] with `len` bits in view, starting at
    /// the given `offset` (in bits).
    ///
    /// Panics if the buffer is not large enough to hold `len` bits after the offset.
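    ///
    /// For example (a sketch, not a verified doctest):
    ///
    /// ```ignore
    /// // View 4 bits starting at bit 3 of the single byte 0b0111_1000.
    /// let bits = BitBuffer::new_with_offset(buffer![0b0111_1000u8; 1], 4, 3);
    /// assert_eq!(bits.true_count(), 4);
    /// ```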
    pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
        assert!(
            len.saturating_add(offset) <= buffer.len().saturating_mul(8),
            "provided ByteBuffer (len={}) not large enough to back BitBuffer with offset {offset} len {len}",
            buffer.len()
        );

        // BitBuffers make no assumptions on byte alignment, so we strip any alignment.
        let buffer = buffer.aligned(Alignment::none());

        Self {
            buffer,
            offset,
            len,
        }
    }

    /// Create a new `BitBuffer` of length `len` where all bits are set (true).
    pub fn new_set(len: usize) -> Self {
        let bytes = len.div_ceil(8);
        let buffer = buffer![0xFF; bytes];

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Create a new `BitBuffer` of length `len` where all bits are unset (false).
    pub fn new_unset(len: usize) -> Self {
        let bytes = len.div_ceil(8);
        let buffer = Buffer::zeroed(bytes);

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Create a new empty `BitBuffer`.
    pub fn empty() -> Self {
        Self::new_set(0)
    }

    /// Create a new `BitBuffer` of length `len` where all bits are set to `value`.
    pub fn full(value: bool, len: usize) -> Self {
        if value {
            Self::new_set(len)
        } else {
            Self::new_unset(len)
        }
    }

    /// Invokes `f` with indices `0..len`, collecting the boolean results into a new `BitBuffer`.
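    ///
    /// For example (a sketch, not a verified doctest):
    ///
    /// ```ignore
    /// // Set every third bit in a 10-bit buffer.
    /// let bits = BitBuffer::collect_bool(10, |i| i % 3 == 0);
    /// assert_eq!(bits.true_count(), 4); // bits 0, 3, 6 and 9
    /// ```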
    pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, mut f: F) -> Self {
        let mut buffer = BufferMut::with_capacity(len.div_ceil(64) * 8);

        let chunks = len / 64;
        let remainder = len % 64;
        for chunk in 0..chunks {
            let mut packed = 0;
            for bit_idx in 0..64 {
                let i = bit_idx + chunk * 64;
                packed |= (f(i) as u64) << bit_idx;
            }

            // SAFETY: Already allocated sufficient capacity
            unsafe { buffer.push_unchecked(packed) }
        }

        if remainder != 0 {
            let mut packed = 0;
            for bit_idx in 0..remainder {
                let i = bit_idx + chunks * 64;
                packed |= (f(i) as u64) << bit_idx;
            }

            // SAFETY: Already allocated sufficient capacity
            unsafe { buffer.push_unchecked(packed) }
        }

        buffer.truncate(len.div_ceil(8));

        Self::new(buffer.freeze().into_byte_buffer(), len)
    }

    /// Get the logical length of this `BitBuffer` in bits.
    ///
    /// This may differ from the physical length of the backing buffer, for example if it was
    /// created using the `new_with_offset` constructor, or if it was sliced.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the `BitBuffer` is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Offset of the start of the buffer in bits.
    #[inline(always)]
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Get a reference to the underlying buffer.
    #[inline(always)]
    pub fn inner(&self) -> &ByteBuffer {
        &self.buffer
    }

    /// Retrieve the value at the given index.
    ///
    /// Panics if the index is out of bounds.
    ///
    /// Note: if you call this repeatedly, prefer [`crate::get_bit`].
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        assert!(index < self.len);
        unsafe { self.value_unchecked(index) }
    }

    /// Retrieve the value at the given index without bounds checking.
    ///
    /// # Safety
    /// The caller must ensure that `index` is within the bounds of the buffer.
    #[inline]
    pub unsafe fn value_unchecked(&self, index: usize) -> bool {
        unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
    }

    /// Create a new zero-copy slice of this `BitBuffer` covering the given range of bit indices.
    ///
    /// Panics if the slice would extend beyond the end of the buffer.
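    ///
    /// For example (a sketch, not a verified doctest):
    ///
    /// ```ignore
    /// let bits = BitBuffer::collect_bool(16, |i| i < 4);
    /// let tail = bits.slice(8..); // the last 8 bits, all unset
    /// assert_eq!(tail.true_count(), 0);
    /// ```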
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        let start = match range.start_bound() {
            std::ops::Bound::Included(&s) => s,
            std::ops::Bound::Excluded(&s) => s + 1,
            std::ops::Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            std::ops::Bound::Included(&e) => e + 1,
            std::ops::Bound::Excluded(&e) => e,
            std::ops::Bound::Unbounded => self.len,
        };

        assert!(start <= end);
        assert!(start <= self.len);
        assert!(end <= self.len);
        let len = end - start;

        Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
    }

    /// Slice away any whole bytes covered by the offset, leaving a bit offset of less than 8.
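    ///
    /// For example (a sketch, not a verified doctest):
    ///
    /// ```ignore
    /// let bits = BitBuffer::new_with_offset(buffer![0xFF; 4], 8, 12);
    /// assert_eq!(bits.offset(), 12);
    /// assert_eq!(bits.shrink_offset().offset(), 4);
    /// ```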
    pub fn shrink_offset(self) -> Self {
        let bit_offset = self.offset % 8;
        let len = self.len;
        let buffer = self.into_inner();
        BitBuffer::new_with_offset(buffer, len, bit_offset)
    }

    /// Access chunks of the buffer aligned to an 8-byte boundary as [prefix, \<full chunks\>, suffix].
    pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
        UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Access chunks of the underlying buffer as 8-byte chunks with a final trailer.
    ///
    /// If you're performing operations on a single buffer, prefer [`BitBuffer::unaligned_chunks`].
    pub fn chunks(&self) -> BitChunks<'_> {
        BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Get the number of set bits in the buffer.
    pub fn true_count(&self) -> usize {
        self.unaligned_chunks().count_ones()
    }

    /// Get the number of unset bits in the buffer.
    pub fn false_count(&self) -> usize {
        self.len - self.true_count()
    }

    /// Iterator over bits in the buffer.
    pub fn iter(&self) -> BitIterator<'_> {
        BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Iterator over the indices of set bits in the buffer.
    pub fn set_indices(&self) -> BitIndexIterator<'_> {
        BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Iterator over contiguous runs of set bits in the buffer.
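    ///
    /// A sketch (assuming the iterator yields `(start, end)` pairs with exclusive end positions,
    /// as the equivalent Arrow iterator does; not a verified doctest):
    ///
    /// ```ignore
    /// let bits = BitBuffer::from(vec![true, true, false, true]);
    /// let runs: Vec<(usize, usize)> = bits.set_slices().collect();
    /// assert_eq!(runs, vec![(0, 2), (3, 4)]);
    /// ```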
    pub fn set_slices(&self) -> BitSliceIterator<'_> {
        BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Create a new `BitBuffer` with its offset reset to 0.
    pub fn sliced(&self) -> Self {
        if self.offset % 8 == 0 {
            return Self::new(
                self.buffer
                    .slice(self.offset / 8..(self.offset + self.len).div_ceil(8)),
                self.len,
            );
        }
        bitwise_unary_op(self, |a| a)
    }
}

// Conversions

impl BitBuffer {
    /// Consumes this `BitBuffer` and returns the backing `ByteBuffer` with any offset
    /// and length information applied.
    pub fn into_inner(self) -> ByteBuffer {
        let word_start = self.offset / 8;
        let word_end = (self.offset + self.len).div_ceil(8);

        self.buffer.slice(word_start..word_end)
    }

    /// Attempt to convert this `BitBuffer` into a mutable version.
    pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
        match self.buffer.try_into_mut() {
            Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
            Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
        }
    }

    /// Get a mutable version of this `BitBuffer`.
    ///
    /// If the caller does not hold the only reference to the underlying buffer, the data is
    /// copied into a new buffer.
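    ///
    /// For example (a sketch, not a verified doctest):
    ///
    /// ```ignore
    /// let bits = BitBuffer::new_set(10);
    /// // Obtain a mutable view; this copies only if the buffer is shared.
    /// let bits_mut = bits.into_mut();
    /// let bits_again = bits_mut.freeze();
    /// ```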
    pub fn into_mut(self) -> BitBufferMut {
        let offset = self.offset;
        let len = self.len;
        // TODO(robert): if we are copying here we could strip offset bits
        let inner = self.into_inner().into_mut();
        BitBufferMut::from_buffer(inner, offset, len)
    }
}

impl From<&[bool]> for BitBuffer {
    fn from(value: &[bool]) -> Self {
        BitBufferMut::from(value).freeze()
    }
}

impl From<Vec<bool>> for BitBuffer {
    fn from(value: Vec<bool>) -> Self {
        BitBufferMut::from(value).freeze()
    }
}

impl FromIterator<bool> for BitBuffer {
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        BitBufferMut::from_iter(iter).freeze()
    }
}

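// Bitwise operators. A usage sketch (assuming two equal-length buffers `a` and `b`):
//
//     let union = &a | &b;
//     let intersection = &a & &b;
//     let symmetric_difference = &a ^ &b;
//     let complement = !&a;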
impl BitOr for &BitBuffer {
    type Output = BitBuffer;

    fn bitor(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |a, b| a | b)
    }
}

impl BitOr<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    fn bitor(self, rhs: &BitBuffer) -> Self::Output {
        (&self).bitor(rhs)
    }
}

impl BitAnd for &BitBuffer {
    type Output = BitBuffer;

    fn bitand(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |a, b| a & b)
    }
}

impl BitAnd<BitBuffer> for &BitBuffer {
    type Output = BitBuffer;

    fn bitand(self, rhs: BitBuffer) -> Self::Output {
        self.bitand(&rhs)
    }
}

impl BitAnd<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    fn bitand(self, rhs: &BitBuffer) -> Self::Output {
        (&self).bitand(rhs)
    }
}

impl Not for &BitBuffer {
    type Output = BitBuffer;

    fn not(self) -> Self::Output {
        bitwise_unary_op(self, |a| !a)
    }
}

impl Not for BitBuffer {
    type Output = BitBuffer;

    fn not(self) -> Self::Output {
        (&self).not()
    }
}

impl BitXor for &BitBuffer {
    type Output = BitBuffer;

    fn bitxor(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |a, b| a ^ b)
    }
}

impl BitXor<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
        (&self).bitxor(rhs)
    }
}

impl BitBuffer {
    /// Create a new `BitBuffer` by performing a bitwise AND NOT operation between two `BitBuffer`s.
    ///
    /// This operation is sufficiently common that we provide a dedicated method for it, avoiding
    /// two passes over the data.
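    ///
    /// For example (a sketch, not a verified doctest): `a.bitand_not(&b)` keeps the bits that are
    /// set in `a` and not set in `b`:
    ///
    /// ```ignore
    /// let a = BitBuffer::from(vec![true, true, false]);
    /// let b = BitBuffer::from(vec![false, true, true]);
    /// assert_eq!(a.bitand_not(&b), BitBuffer::from(vec![true, false, false]));
    /// ```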
    pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
        bitwise_binary_op(self, rhs, |a, b| a & !b)
    }

    /// Iterate over all bits in the buffer.
    ///
    /// # Arguments
    ///
    /// * `f` - Callback function invoked with `(bit_index, is_set)` for each bit
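    ///
    /// For example (a sketch, not a verified doctest):
    ///
    /// ```ignore
    /// let bits = BitBuffer::from(vec![true, false, true]);
    /// bits.iter_bits(|idx, is_set| println!("bit {idx} = {is_set}"));
    /// ```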
    #[inline]
    pub fn iter_bits<F>(&self, mut f: F)
    where
        F: FnMut(usize, bool),
    {
        let total_bits = self.len;
        if total_bits == 0 {
            return;
        }

        let is_bit_set = |byte: u8, bit_idx: usize| (byte & (1 << bit_idx)) != 0;
        let bit_offset = self.offset % 8;
        let mut buffer_ptr = unsafe { self.buffer.as_ptr().add(self.offset / 8) };
        let mut callback_idx = 0;

        // Handle incomplete first byte.
        if bit_offset > 0 {
            let bits_in_first_byte = (8 - bit_offset).min(total_bits);
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..bits_in_first_byte {
                f(callback_idx, is_bit_set(byte, bit_offset + bit_idx));
                callback_idx += 1;
            }

            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Process complete bytes.
        let complete_bytes = (total_bits - callback_idx) / 8;
        for _ in 0..complete_bytes {
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..8 {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Handle remaining bits at the end.
        let remaining_bits = total_bits - callback_idx;
        if remaining_bits > 0 {
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..remaining_bits {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
        }
    }
}

impl<'a> IntoIterator for &'a BitBuffer {
    type Item = bool;
    type IntoIter = BitIterator<'a>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}

#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::bit::BitBuffer;
    use crate::{ByteBuffer, buffer};

    #[test]
    fn test_bool() {
        // Create a new ByteBuffer of 1024 bytes where the high bit of every byte is set.
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        // sanity checks
        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        // Check all the values
        for word in 0..1024 {
            for bit in 0..8 {
                if bit == 7 {
                    assert!(bools.value(word * 8 + bit));
                } else {
                    assert!(!bools.value(word * 8 + bit));
                }
            }
        }

        // Slice the buffer to create a new subset view.
        let sliced = bools.slice(64..72);

        // sanity checks
        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        // Check all of the values like before
        for bit in 0..8 {
            if bit == 7 {
                assert!(sliced.value(bit));
            } else {
                assert!(!sliced.value(bit));
            }
        }
    }

    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64); // All bits set.
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32); // First half set, other half unset.

        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }

        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.offset(), 10);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            // The bits should match the original buffer at positions offset + idx
            assert_eq!(is_set, (offset + idx) % 2 == 0);
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Alternating pattern to catch byte offset errors: Bits are set for even indexed bytes.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index % 2 == 0;

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }
}