1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::ops::BitAnd;
8use std::ops::BitOr;
9use std::ops::BitXor;
10use std::ops::Bound;
11use std::ops::Not;
12use std::ops::RangeBounds;
13
14use crate::Alignment;
15use crate::BitBufferMut;
16use crate::Buffer;
17use crate::BufferMut;
18use crate::ByteBuffer;
19use crate::bit::BitChunks;
20use crate::bit::BitIndexIterator;
21use crate::bit::BitIterator;
22use crate::bit::BitSliceIterator;
23use crate::bit::UnalignedBitChunk;
24use crate::bit::collect_bool_word;
25use crate::bit::count_ones::count_ones;
26use crate::bit::get_bit_unchecked;
27use crate::bit::ops::bitwise_binary_op;
28use crate::bit::ops::bitwise_unary_op;
29use crate::bit::select::bit_select;
30use crate::buffer;
31
/// A bit-packed sequence of booleans stored in a [`ByteBuffer`].
///
/// The logical contents are the `len` bits that start `offset` bits into `buffer`.
/// `PartialEq` is implemented by hand (comparing lengths and then bit chunks), so two
/// buffers with different offsets or backing bytes can still compare equal.
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    /// Bytes that hold the packed bits.
    buffer: ByteBuffer,
    /// Position, in bits from the start of `buffer`, of the first logical bit.
    offset: usize,
    /// Number of logical bits.
    len: usize,
}
43
/// Number of leading bits printed by the `Display` impl when the formatter carries no
/// explicit precision.
const LIMIT_LEN: usize = 16;
45impl Display for BitBuffer {
46 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
47 let limit = f.precision().unwrap_or(LIMIT_LEN);
48 let buf: Vec<bool> = self.into_iter().take(limit).collect();
49 f.debug_struct("BitBuffer")
50 .field("len", &self.len)
51 .field("buffer", &buf)
52 .finish()
53 }
54}
55
56impl PartialEq for BitBuffer {
57 fn eq(&self, other: &Self) -> bool {
58 if self.len != other.len {
59 return false;
60 }
61
62 self.chunks()
63 .iter_padded()
64 .zip(other.chunks().iter_padded())
65 .all(|(a, b)| a == b)
66 }
67}
68
69impl BitBuffer {
70 pub fn new(buffer: ByteBuffer, len: usize) -> Self {
74 assert!(
75 buffer.len() * 8 >= len,
76 "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
77 );
78
79 let buffer = buffer.aligned(Alignment::none());
81
82 Self {
83 buffer,
84 len,
85 offset: 0,
86 }
87 }
88
89 pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
94 assert!(
95 len.saturating_add(offset) <= buffer.len().saturating_mul(8),
96 "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
97 buffer.len()
98 );
99
100 let buffer = buffer.aligned(Alignment::none());
102
103 let byte_offset = offset / 8;
105 let offset = offset % 8;
106 let buffer = if byte_offset != 0 {
107 buffer.slice(byte_offset..)
108 } else {
109 buffer
110 };
111
112 Self {
113 buffer,
114 offset,
115 len,
116 }
117 }
118
119 pub fn new_set(len: usize) -> Self {
121 let words = len.div_ceil(8);
122 let buffer = buffer![0xFF; words];
123
124 Self {
125 buffer,
126 len,
127 offset: 0,
128 }
129 }
130
131 pub fn new_unset(len: usize) -> Self {
133 let words = len.div_ceil(8);
134 let buffer = Buffer::zeroed(words);
135
136 Self {
137 buffer,
138 len,
139 offset: 0,
140 }
141 }
142
    /// Creates a `BitBuffer` of `len` bits in which the bits at `indices` are set.
    ///
    /// Construction is delegated to `BitBufferMut::from_indices`, whose result is then
    /// frozen. The tests in this file expect an index `>= len` to panic with
    /// "index N exceeds len M".
    pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> BitBuffer {
        BitBufferMut::from_indices(len, indices).freeze()
    }
147
    /// Creates a `BitBuffer` with zero bits (implemented as `new_set(0)`).
    pub fn empty() -> Self {
        Self::new_set(0)
    }
152
153 pub fn full(value: bool, len: usize) -> Self {
155 if value {
156 Self::new_set(len)
157 } else {
158 Self::new_unset(len)
159 }
160 }
161
    /// Creates a `BitBuffer` of `len` bits where bit `i` is `f(i)`.
    ///
    /// Construction is delegated to `BitBufferMut::collect_bool`, whose result is then
    /// frozen.
    #[inline]
    pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, f: F) -> Self {
        BitBufferMut::collect_bool(len, f).freeze()
    }
167
168 pub fn map_cmp<F>(&self, mut f: F) -> Self
173 where
174 F: FnMut(usize, bool) -> bool,
175 {
176 let len = self.len;
177 let mut buffer: BufferMut<u64> = BufferMut::with_capacity(len.div_ceil(64));
178
179 let chunks_count = len / 64;
180 let remainder = len % 64;
181 let chunks = self.chunks();
182
183 for (chunk_idx, src_chunk) in chunks.iter().enumerate() {
184 let packed = collect_bool_word(64, |bit_idx| {
185 let i = bit_idx + chunk_idx * 64;
186 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
187 f(i, bit_value)
188 });
189
190 unsafe { buffer.push_unchecked(packed) }
192 }
193
194 if remainder != 0 {
195 let src_chunk = chunks.remainder_bits();
196 let packed = collect_bool_word(remainder, |bit_idx| {
197 let i = bit_idx + chunks_count * 64;
198 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
199 f(i, bit_value)
200 });
201
202 unsafe { buffer.push_unchecked(packed) }
204 }
205
206 buffer.truncate(len.div_ceil(8));
207
208 Self {
209 buffer: buffer.freeze().into_byte_buffer(),
210 offset: 0,
211 len,
212 }
213 }
214
215 pub fn clear(&mut self) {
217 self.buffer.clear();
218 self.len = 0;
219 self.offset = 0;
220 }
221
    /// Returns the number of logical bits in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }
230
231 #[inline]
233 pub fn is_empty(&self) -> bool {
234 self.len() == 0
235 }
236
    /// Returns the bit offset into the backing buffer at which the logical bits start.
    #[inline(always)]
    pub fn offset(&self) -> usize {
        self.offset
    }
242
    /// Returns a reference to the backing [`ByteBuffer`].
    ///
    /// The logical bits start `offset()` bits into this buffer.
    #[inline(always)]
    pub fn inner(&self) -> &ByteBuffer {
        &self.buffer
    }
248
    /// Returns the bit at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        assert!(index < self.len);
        // SAFETY: `index` was checked against `self.len` just above.
        unsafe { self.value_unchecked(index) }
    }
259
    /// Returns the bit at `index` without checking it against the length.
    ///
    /// # Safety
    ///
    /// No bounds check is performed here. The caller must ensure `index < self.len()`,
    /// as the checked [`Self::value`] does before delegating to this method.
    #[inline]
    pub unsafe fn value_unchecked(&self, index: usize) -> bool {
        // The logical index is shifted by the buffer's bit offset before the read.
        unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
    }
268
269 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
274 let start = match range.start_bound() {
275 Bound::Included(&s) => s,
276 Bound::Excluded(&s) => s + 1,
277 Bound::Unbounded => 0,
278 };
279 let end = match range.end_bound() {
280 Bound::Included(&e) => e + 1,
281 Bound::Excluded(&e) => e,
282 Bound::Unbounded => self.len,
283 };
284
285 assert!(start <= end);
286 assert!(start <= self.len);
287 assert!(end <= self.len);
288 let len = end - start;
289
290 Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
291 }
292
293 pub fn shrink_offset(self) -> Self {
295 let word_start = self.offset / 8;
296 let word_end = (self.offset + self.len).div_ceil(8);
297
298 let buffer = self.buffer.slice(word_start..word_end);
299
300 let bit_offset = self.offset % 8;
301 let len = self.len;
302 BitBuffer::new_with_offset(buffer, len, bit_offset)
303 }
304
305 pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
307 UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
308 }
309
310 pub fn chunks(&self) -> BitChunks<'_> {
314 BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
315 }
316
317 pub fn true_count(&self) -> usize {
319 count_ones(self.buffer.as_slice(), self.offset, self.len)
320 }
321
322 pub fn select(&self, nth: usize) -> Option<usize> {
329 bit_select(self.buffer.as_slice(), self.offset, self.len, nth)
330 }
331
332 pub fn false_count(&self) -> usize {
334 self.len - self.true_count()
335 }
336
337 pub fn iter(&self) -> BitIterator<'_> {
339 BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
340 }
341
342 pub fn set_indices(&self) -> BitIndexIterator<'_> {
344 BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
345 }
346
347 pub fn set_slices(&self) -> BitSliceIterator<'_> {
349 BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
350 }
351
352 pub fn sliced(&self) -> Self {
354 if self.offset.is_multiple_of(8) {
355 return Self::new(
356 self.buffer.slice(self.offset / 8..self.len.div_ceil(8)),
357 self.len,
358 );
359 }
360
361 bitwise_unary_op(self.clone(), |a| a)
362 }
363}
364
365impl BitBuffer {
368 pub fn into_inner(self) -> (usize, usize, ByteBuffer) {
370 (self.offset, self.len, self.buffer)
371 }
372
373 pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
375 match self.buffer.try_into_mut() {
376 Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
377 Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
378 }
379 }
380}
381
/// Builds a `BitBuffer` from a slice of booleans by way of `BitBufferMut`.
impl From<&[bool]> for BitBuffer {
    fn from(value: &[bool]) -> Self {
        BitBufferMut::from(value).freeze()
    }
}
387
/// Builds a `BitBuffer` from an owned vector of booleans by way of `BitBufferMut`.
impl From<Vec<bool>> for BitBuffer {
    fn from(value: Vec<bool>) -> Self {
        BitBufferMut::from(value).freeze()
    }
}
393
/// Collects an iterator of booleans into a `BitBuffer` by way of `BitBufferMut`.
impl FromIterator<bool> for BitBuffer {
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        BitBufferMut::from_iter(iter).freeze()
    }
}
399
400impl BitOr for BitBuffer {
401 type Output = Self;
402
403 #[inline]
404 fn bitor(self, rhs: Self) -> Self::Output {
405 BitOr::bitor(&self, &rhs)
406 }
407}
408
409impl BitOr for &BitBuffer {
410 type Output = BitBuffer;
411
412 #[inline]
413 fn bitor(self, rhs: Self) -> Self::Output {
414 bitwise_binary_op(self, rhs, |a, b| a | b)
415 }
416}
417
418impl BitOr<&BitBuffer> for BitBuffer {
419 type Output = BitBuffer;
420
421 #[inline]
422 fn bitor(self, rhs: &BitBuffer) -> Self::Output {
423 (&self).bitor(rhs)
424 }
425}
426
427impl BitAnd for &BitBuffer {
428 type Output = BitBuffer;
429
430 #[inline]
431 fn bitand(self, rhs: Self) -> Self::Output {
432 bitwise_binary_op(self, rhs, |a, b| a & b)
433 }
434}
435
436impl BitAnd<BitBuffer> for &BitBuffer {
437 type Output = BitBuffer;
438
439 #[inline]
440 fn bitand(self, rhs: BitBuffer) -> Self::Output {
441 self.bitand(&rhs)
442 }
443}
444
445impl BitAnd<&BitBuffer> for BitBuffer {
446 type Output = BitBuffer;
447
448 #[inline]
449 fn bitand(self, rhs: &BitBuffer) -> Self::Output {
450 (&self).bitand(rhs)
451 }
452}
453
454impl BitAnd<BitBuffer> for BitBuffer {
455 type Output = BitBuffer;
456
457 #[inline]
458 fn bitand(self, rhs: BitBuffer) -> Self::Output {
459 (&self).bitand(&rhs)
460 }
461}
462
463impl Not for &BitBuffer {
464 type Output = BitBuffer;
465
466 #[inline]
467 fn not(self) -> Self::Output {
468 !self.clone()
469 }
470}
471
472impl Not for BitBuffer {
473 type Output = BitBuffer;
474
475 #[inline]
476 fn not(self) -> Self::Output {
477 bitwise_unary_op(self, |a| !a)
478 }
479}
480
481impl BitXor for &BitBuffer {
482 type Output = BitBuffer;
483
484 #[inline]
485 fn bitxor(self, rhs: Self) -> Self::Output {
486 bitwise_binary_op(self, rhs, |a, b| a ^ b)
487 }
488}
489
490impl BitXor<&BitBuffer> for BitBuffer {
491 type Output = BitBuffer;
492
493 #[inline]
494 fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
495 (&self).bitxor(rhs)
496 }
497}
498
499impl BitBuffer {
500 pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
505 bitwise_binary_op(self, rhs, |a, b| a & !b)
506 }
507
    /// Calls `f(index, bit)` once for every logical bit, in increasing index order.
    ///
    /// `index` runs from `0` to `len - 1` and is relative to the start of the logical
    /// range, not to the backing buffer. The bytes are walked through a raw pointer in
    /// three phases: a partial first byte (when the bit offset is not byte-aligned),
    /// then whole bytes, then a partial last byte.
    #[inline]
    pub fn iter_bits<F>(&self, mut f: F)
    where
        F: FnMut(usize, bool),
    {
        let total_bits = self.len;
        if total_bits == 0 {
            return;
        }

        // Bits are numbered from the least-significant bit of each byte.
        let is_bit_set = |byte: u8, bit_idx: usize| (byte & (1 << bit_idx)) != 0;
        let bit_offset = self.offset % 8;
        // SAFETY: relies on the type invariant that `offset + len` bits fit inside
        // `buffer`; with `len > 0` the byte at `offset / 8` is then in bounds.
        let mut buffer_ptr = unsafe { self.buffer.as_ptr().add(self.offset / 8) };
        let mut callback_idx = 0;

        // Phase 1: the leading bits that share a byte with bits before the offset.
        if bit_offset > 0 {
            // A short buffer may end inside this first byte, hence the `min`.
            let bits_in_first_byte = (8 - bit_offset).min(total_bits);
            // SAFETY: `buffer_ptr` points at the in-bounds first byte (see above).
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..bits_in_first_byte {
                f(callback_idx, is_bit_set(byte, bit_offset + bit_idx));
                callback_idx += 1;
            }

            // SAFETY: advances by one byte, at most to one-past-the-end of the
            // bytes covering the logical range.
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Phase 2: bytes whose eight bits all belong to the logical range.
        let complete_bytes = (total_bits - callback_idx) / 8;
        for _ in 0..complete_bytes {
            // SAFETY: each of the `complete_bytes` bytes lies inside the logical range.
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..8 {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
            // SAFETY: same one-byte advance as above.
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Phase 3: the low bits of a trailing, partially used byte.
        let remaining_bits = total_bits - callback_idx;
        if remaining_bits > 0 {
            // SAFETY: bits remain, so the byte holding them is inside the logical range.
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..remaining_bits {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
        }
    }
568}
569
570impl<'a> IntoIterator for &'a BitBuffer {
571 type Item = bool;
572 type IntoIter = BitIterator<'a>;
573
574 fn into_iter(self) -> Self::IntoIter {
575 self.iter()
576 }
577}
578
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::ByteBuffer;
    use crate::bit::BitBuffer;
    use crate::buffer;

    /// Runs `iter_bits` over `buf` and records every `(index, bit)` pair it reports.
    fn collect_bits(buf: &BitBuffer) -> Vec<(usize, bool)> {
        let mut seen = Vec::new();
        buf.iter_bits(|idx, is_set| {
            seen.push((idx, is_set));
        });
        seen
    }

    /// Length, counts, per-bit reads and byte-aligned slicing on a buffer in which
    /// only the top bit of every byte is set.
    #[test]
    fn test_bool() {
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        // Only bit 7 of each byte is expected to be set.
        for index in 0..1024 * 8 {
            assert_eq!(bools.value(index), index % 8 == 7);
        }

        // Slice out exactly one byte.
        let sliced = bools.slice(64..72);

        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        for bit in 0..8 {
            assert_eq!(sliced.value(bit), bit == 7);
        }
    }

    /// Equality of slices shorter than a 64-bit word must depend only on the bits
    /// inside the slice.
    #[test]
    fn test_padded_equaltiy() {
        let buf1 = BitBuffer::new_set(64);
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32);

        for i in 0..64 {
            if i < 32 {
                assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
            } else {
                assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
            }
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    /// Slicing at bit 10 drops one whole byte and leaves a sub-byte offset of 2.
    #[test]
    fn test_slice_offset_calculation() {
        let all_set = BitBuffer::collect_bool(16, |_| true);
        let tail = all_set.slice(10..16);

        assert_eq!(tail.len(), 6);
        assert_eq!(tail.offset(), 2);
    }

    /// `from_indices` with a dense index set that spans several 64-bit words.
    #[test]
    fn test_from_indices_dense_crosses_words() {
        let len = 130;
        let indices = (0..len).filter(|idx| idx % 3 != 1);
        let buf = BitBuffer::from_indices(len, indices);

        assert_eq!(buf.len(), len);
        for idx in 0..len {
            let expected = idx % 3 != 1;
            assert_eq!(buf.value(idx), expected, "mismatch at {idx}");
        }
    }

    /// An index equal to the length must be rejected.
    #[test]
    #[should_panic(expected = "index 5 exceeds len 5")]
    fn test_from_indices_out_of_bounds() {
        BitBuffer::from_indices(5, [0, 5]);
    }

    /// `iter_bits` reports every bit exactly once on offset-free buffers.
    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let collected = collect_bits(&buf);

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    /// `iter_bits` honours the bit offset for both sub-byte and multi-byte offsets.
    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let collected = collect_bits(&buf_with_offset);

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            // Index `idx` of the view corresponds to bit `offset + idx` of the source.
            assert_eq!(is_set, (offset + idx).is_multiple_of(2));
        }
    }

    /// Uses a pattern that alternates per byte, so starting from the wrong byte is
    /// detected.
    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Even-numbered bytes are all ones, odd-numbered bytes all zeros.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let collected = collect_bits(&buf_with_offset);

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let expected_is_set = (bit_position / 8).is_multiple_of(2);

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }

    /// `map_cmp` with the identity mapping reproduces the input.
    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    #[case(128)]
    fn test_map_cmp_identity(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 3 == 0);
        let mapped = buf.map_cmp(|_idx, bit| bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    /// `map_cmp` with negation flips every bit.
    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    fn test_map_cmp_negate(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
        let mapped = buf.map_cmp(|_idx, bit| !bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(!buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    /// `map_cmp` passes both the index and the source bit to the closure.
    #[test]
    fn test_map_cmp_conditional() {
        let len = 100;
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        // Keep a bit only when it is set in the source and its index is a multiple of 4.
        let mapped = buf.map_cmp(|idx, bit| bit && idx % 4 == 0);

        for i in 0..len {
            let expected = (i % 2 == 0) && (i % 4 == 0);
            assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
        }
    }
}