vortex_buffer/bit/
buf.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::BitXor;
7use std::ops::Bound;
8use std::ops::Not;
9use std::ops::RangeBounds;
10
11use crate::Alignment;
12use crate::BitBufferMut;
13use crate::Buffer;
14use crate::ByteBuffer;
15use crate::bit::BitChunks;
16use crate::bit::BitIndexIterator;
17use crate::bit::BitIterator;
18use crate::bit::BitSliceIterator;
19use crate::bit::UnalignedBitChunk;
20use crate::bit::get_bit_unchecked;
21use crate::bit::ops::bitwise_binary_op;
22use crate::bit::ops::bitwise_unary_op;
23use crate::buffer;
24
25/// An immutable bitset stored as a packed byte buffer.
26#[derive(Debug, Clone, Eq)]
27#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
28pub struct BitBuffer {
29    buffer: ByteBuffer,
30    /// Represents the offset of the bit buffer into the first byte.
31    ///
32    /// This is always less than 8 (for when the bit buffer is not aligned to a byte).
33    offset: usize,
34    len: usize,
35}
36
37impl PartialEq for BitBuffer {
38    fn eq(&self, other: &Self) -> bool {
39        if self.len != other.len {
40            return false;
41        }
42
43        self.chunks()
44            .iter_padded()
45            .zip(other.chunks().iter_padded())
46            .all(|(a, b)| a == b)
47    }
48}
49
impl BitBuffer {
    /// Create a new `BitBuffer` backed by a [`ByteBuffer`] with `len` bits in view.
    ///
    /// Panics if the buffer is not large enough to hold `len` bits.
    pub fn new(buffer: ByteBuffer, len: usize) -> Self {
        assert!(
            buffer.len() * 8 >= len,
            "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
        );

        // BitBuffers make no assumptions on byte alignment, so we strip any alignment.
        let buffer = buffer.aligned(Alignment::none());

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Create a new `BitBuffer` backed by a [`ByteBuffer`] with `len` bits in view, starting at
    /// the given `offset` (in bits).
    ///
    /// Panics if the buffer is not large enough to hold `len` bits after the offset.
    pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
        assert!(
            len.saturating_add(offset) <= buffer.len().saturating_mul(8),
            "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
            buffer.len()
        );

        // BitBuffers make no assumptions on byte alignment, so we strip any alignment.
        let buffer = buffer.aligned(Alignment::none());

        // Slice the buffer to ensure the offset is within the first byte, maintaining the
        // struct's `offset < 8` invariant.
        let byte_offset = offset / 8;
        let offset = offset % 8;
        let buffer = buffer.slice(byte_offset..);

        Self {
            buffer,
            offset,
            len,
        }
    }

    /// Create a new `BitBuffer` of length `len` where all bits are set (true).
    ///
    /// Note that any padding bits in the final byte are also set, since the backing bytes are
    /// filled with `0xFF`.
    pub fn new_set(len: usize) -> Self {
        // Number of bytes needed to hold `len` bits.
        let words = len.div_ceil(8);
        let buffer = buffer![0xFF; words];

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Create a new `BitBuffer` of length `len` where all bits are unset (false).
    pub fn new_unset(len: usize) -> Self {
        // Number of bytes needed to hold `len` bits.
        let words = len.div_ceil(8);
        let buffer = Buffer::zeroed(words);

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Create a new empty `BitBuffer`.
    pub fn empty() -> Self {
        Self::new_set(0)
    }

    /// Create a new `BitBuffer` of length `len` where all bits are set to `value`.
    pub fn full(value: bool, len: usize) -> Self {
        if value {
            Self::new_set(len)
        } else {
            Self::new_unset(len)
        }
    }

    /// Invokes `f` with indexes `0..len` collecting the boolean results into a new [`BitBuffer`].
    pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, f: F) -> Self {
        BitBufferMut::collect_bool(len, f).freeze()
    }

    /// Clear all bits in the buffer, preserving existing capacity.
    ///
    /// After this call the buffer is an empty view (len 0, offset 0) over the retained
    /// allocation.
    pub fn clear(&mut self) {
        self.buffer.clear();
        self.len = 0;
        self.offset = 0;
    }

    /// Get the logical length of this `BitBuffer` in bits.
    ///
    /// This may differ from the physical length of the backing buffer, for example if it was
    /// created using the `new_with_offset` constructor, or if it was sliced.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the `BitBuffer` is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Offset of the start of the buffer in bits (always `< 8`, see struct invariant).
    #[inline(always)]
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Get a reference to the underlying buffer.
    #[inline(always)]
    pub fn inner(&self) -> &ByteBuffer {
        &self.buffer
    }

    /// Retrieve the value at the given index.
    ///
    /// Panics if the index is out of bounds.
    ///
    /// When calling this repeatedly (e.g. in a loop), prefer [`crate::get_bit`].
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        assert!(index < self.len);
        // SAFETY: the assert above guarantees `index` is within the logical view.
        unsafe { self.value_unchecked(index) }
    }

    /// Retrieve the value at the given index without bounds checking
    ///
    /// # SAFETY
    /// Caller must ensure that index is within the range of the buffer
    #[inline]
    pub unsafe fn value_unchecked(&self, index: usize) -> bool {
        // The physical bit position includes the first-byte offset.
        unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
    }

    /// Create a new zero-copy slice of this `BitBuffer` over the given bit range.
    ///
    /// Panics if the slice would extend beyond the end of the buffer.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        // Normalize the generic range bounds into a half-open [start, end) pair.
        let start = match range.start_bound() {
            Bound::Included(&s) => s,
            Bound::Excluded(&s) => s + 1,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&e) => e + 1,
            Bound::Excluded(&e) => e,
            Bound::Unbounded => self.len,
        };

        assert!(start <= end);
        assert!(start <= self.len);
        assert!(end <= self.len);
        let len = end - start;

        // `new_with_offset` re-normalizes the combined offset back below 8.
        Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
    }

    /// Slice any full bytes from the buffer, leaving the offset < 8.
    ///
    /// This also trims trailing bytes of the backing buffer that lie past the logical view.
    pub fn shrink_offset(self) -> Self {
        // First and one-past-last byte touched by the logical view.
        let word_start = self.offset / 8;
        let word_end = (self.offset + self.len).div_ceil(8);

        let buffer = self.buffer.slice(word_start..word_end);

        let bit_offset = self.offset % 8;
        let len = self.len;
        BitBuffer::new_with_offset(buffer, len, bit_offset)
    }

    /// Access chunks of the buffer aligned to 8 byte boundary as [prefix, \<full chunks\>, suffix]
    pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
        UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Access chunks of the underlying buffer as 8 byte chunks with a final trailer
    ///
    /// If you're performing operations on a single buffer, prefer [BitBuffer::unaligned_chunks]
    pub fn chunks(&self) -> BitChunks<'_> {
        BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Get the number of set bits in the buffer.
    pub fn true_count(&self) -> usize {
        self.unaligned_chunks().count_ones()
    }

    /// Get the number of unset bits in the buffer.
    pub fn false_count(&self) -> usize {
        self.len - self.true_count()
    }

    /// Iterator over bits in the buffer
    pub fn iter(&self) -> BitIterator<'_> {
        BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Iterator over set indices of the underlying buffer
    pub fn set_indices(&self) -> BitIndexIterator<'_> {
        BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Iterator over set slices of the underlying buffer
    pub fn set_slices(&self) -> BitSliceIterator<'_> {
        BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Create a new `BitBuffer` with offset reset to 0.
    pub fn sliced(&self) -> Self {
        // Given the `offset < 8` invariant, `offset % 8 == 0` implies `offset == 0`, so this
        // fast path is a zero-copy byte slice starting at byte 0.
        if self.offset % 8 == 0 {
            return Self::new(
                self.buffer.slice(self.offset / 8..self.len.div_ceil(8)),
                self.len,
            );
        }
        // Unaligned: fall back to rewriting the bits through an identity op, which
        // materializes a fresh buffer without a bit offset.
        bitwise_unary_op(self, |a| a)
    }
}
277
278// Conversions
279
280impl BitBuffer {
281    /// Returns the offset, len and underlying buffer.
282    pub fn into_inner(self) -> (usize, usize, ByteBuffer) {
283        (self.offset, self.len, self.buffer)
284    }
285
286    /// Attempt to convert this `BitBuffer` into a mutable version.
287    pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
288        match self.buffer.try_into_mut() {
289            Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
290            Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
291        }
292    }
293
294    /// Get a mutable version of this `BitBuffer` along with bit offset in the first byte.
295    ///
296    /// If the caller doesn't hold only reference to the underlying buffer, a copy is created.
297    /// The second value of the tuple is a bit_offset of the first value in the first byte
298    pub fn into_mut(self) -> BitBufferMut {
299        let (offset, len, inner) = self.into_inner();
300        // TODO(robert): if we are copying here we could strip offset bits
301        BitBufferMut::from_buffer(inner.into_mut(), offset, len)
302    }
303}
304
305impl From<&[bool]> for BitBuffer {
306    fn from(value: &[bool]) -> Self {
307        BitBufferMut::from(value).freeze()
308    }
309}
310
311impl From<Vec<bool>> for BitBuffer {
312    fn from(value: Vec<bool>) -> Self {
313        BitBufferMut::from(value).freeze()
314    }
315}
316
317impl FromIterator<bool> for BitBuffer {
318    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
319        BitBufferMut::from_iter(iter).freeze()
320    }
321}
322
323impl BitOr for &BitBuffer {
324    type Output = BitBuffer;
325
326    fn bitor(self, rhs: Self) -> Self::Output {
327        bitwise_binary_op(self, rhs, |a, b| a | b)
328    }
329}
330
331impl BitOr<&BitBuffer> for BitBuffer {
332    type Output = BitBuffer;
333
334    fn bitor(self, rhs: &BitBuffer) -> Self::Output {
335        (&self).bitor(rhs)
336    }
337}
338
339impl BitAnd for &BitBuffer {
340    type Output = BitBuffer;
341
342    fn bitand(self, rhs: Self) -> Self::Output {
343        bitwise_binary_op(self, rhs, |a, b| a & b)
344    }
345}
346
347impl BitAnd<BitBuffer> for &BitBuffer {
348    type Output = BitBuffer;
349
350    fn bitand(self, rhs: BitBuffer) -> Self::Output {
351        self.bitand(&rhs)
352    }
353}
354
355impl BitAnd<&BitBuffer> for BitBuffer {
356    type Output = BitBuffer;
357
358    fn bitand(self, rhs: &BitBuffer) -> Self::Output {
359        (&self).bitand(rhs)
360    }
361}
362
363impl Not for &BitBuffer {
364    type Output = BitBuffer;
365
366    fn not(self) -> Self::Output {
367        bitwise_unary_op(self, |a| !a)
368    }
369}
370
371impl Not for BitBuffer {
372    type Output = BitBuffer;
373
374    fn not(self) -> Self::Output {
375        (&self).not()
376    }
377}
378
379impl BitXor for &BitBuffer {
380    type Output = BitBuffer;
381
382    fn bitxor(self, rhs: Self) -> Self::Output {
383        bitwise_binary_op(self, rhs, |a, b| a ^ b)
384    }
385}
386
387impl BitXor<&BitBuffer> for BitBuffer {
388    type Output = BitBuffer;
389
390    fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
391        (&self).bitxor(rhs)
392    }
393}
394
impl BitBuffer {
    /// Create a new BitBuffer by performing a bitwise AND NOT operation between two BitBuffers.
    ///
    /// This operation is sufficiently common that we provide a dedicated method for it to avoid
    /// making two passes over the data.
    pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
        bitwise_binary_op(self, rhs, |a, b| a & !b)
    }

    /// Iterate through bits in a buffer, invoking `f` once per bit.
    ///
    /// # Arguments
    ///
    /// * `f` - Callback function taking (bit_index, is_set), where `bit_index` is the logical
    ///   (offset-relative) index starting at 0.
    #[inline]
    pub fn iter_bits<F>(&self, mut f: F)
    where
        F: FnMut(usize, bool),
    {
        let total_bits = self.len;
        if total_bits == 0 {
            return;
        }

        // Tests bit `bit_idx` (LSB-first) within `byte`.
        let is_bit_set = |byte: u8, bit_idx: usize| (byte & (1 << bit_idx)) != 0;
        let bit_offset = self.offset % 8;
        // SAFETY: constructors guarantee `offset + len <= buffer.len() * 8`, so byte
        // `offset / 8` lies within the backing buffer.
        let mut buffer_ptr = unsafe { self.buffer.as_ptr().add(self.offset / 8) };
        let mut callback_idx = 0;

        // Handle incomplete first byte (when the view starts mid-byte).
        if bit_offset > 0 {
            let bits_in_first_byte = (8 - bit_offset).min(total_bits);
            // SAFETY: `buffer_ptr` points at an in-bounds byte (see above).
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..bits_in_first_byte {
                f(callback_idx, is_bit_set(byte, bit_offset + bit_idx));
                callback_idx += 1;
            }

            // SAFETY: advancing one byte is at most one-past-the-end, which is valid.
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Process complete bytes.
        let complete_bytes = (total_bits - callback_idx) / 8;
        for _ in 0..complete_bytes {
            // SAFETY: every byte read here holds bits below `offset + len`, so it is in bounds.
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..8 {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Handle remaining bits at the end (a trailing partial byte).
        let remaining_bits = total_bits - callback_idx;
        if remaining_bits > 0 {
            // SAFETY: the trailing byte still holds bits of the logical view, so it is in bounds.
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..remaining_bits {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
        }
    }
}
465
466impl<'a> IntoIterator for &'a BitBuffer {
467    type Item = bool;
468    type IntoIter = BitIterator<'a>;
469
470    fn into_iter(self) -> Self::IntoIter {
471        self.iter()
472    }
473}
474
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::ByteBuffer;
    use crate::bit::BitBuffer;
    use crate::buffer;

    #[test]
    fn test_bool() {
        // Create a 1024-byte buffer where the top bit (bit 7) of every byte is set.
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        // sanity checks
        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        // Check all the values: only bit 7 of each byte should be set.
        for word in 0..1024 {
            for bit in 0..8 {
                if bit == 7 {
                    assert!(bools.value(word * 8 + bit));
                } else {
                    assert!(!bools.value(word * 8 + bit));
                }
            }
        }

        // Slice the buffer to create a new subset view.
        let sliced = bools.slice(64..72);

        // sanity checks
        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        // Check all of the values like before
        for bit in 0..8 {
            if bit == 7 {
                assert!(sliced.value(bit));
            } else {
                assert!(!sliced.value(bit));
            }
        }
    }

    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64); // All bits set.
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32); // First half set, other half unset.

        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }

        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.len(), 6);
        // Ensure the offset is modulo 8
        assert_eq!(sliced.offset(), 2);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        // Alternating bit pattern: even indices set.
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            // The bits should match the original buffer at positions offset + idx
            assert_eq!(is_set, (offset + idx) % 2 == 0);
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Alternating pattern to catch byte offset errors: Bits are set for even indexed bytes.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index % 2 == 0;

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }
}