1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::ops::BitAnd;
8use std::ops::BitOr;
9use std::ops::BitXor;
10use std::ops::Bound;
11use std::ops::Not;
12use std::ops::RangeBounds;
13
14use crate::Alignment;
15use crate::BitBufferMut;
16use crate::Buffer;
17use crate::BufferMut;
18use crate::ByteBuffer;
19use crate::bit::BitChunks;
20use crate::bit::BitIndexIterator;
21use crate::bit::BitIterator;
22use crate::bit::BitSliceIterator;
23use crate::bit::UnalignedBitChunk;
24use crate::bit::collect_bool_word;
25use crate::bit::count_ones::count_ones;
26use crate::bit::get_bit_unchecked;
27use crate::bit::ops::bitwise_binary_op;
28use crate::bit::ops::bitwise_binary_op_lhs_owned;
29use crate::bit::ops::bitwise_unary_op;
30use crate::bit::ops::bitwise_unary_op_copy;
31use crate::bit::select::bit_select;
32use crate::buffer;
33
/// A sequence of bits packed into a [`ByteBuffer`].
///
/// Bits are addressed least-significant-bit first within each byte (bit `i` of the
/// sequence lives at bit `(offset + i) % 8` of byte `(offset + i) / 8`).
///
/// `PartialEq` is implemented by hand so that only the `len` logical bits are compared,
/// regardless of `offset` or of any unused bits in the backing bytes.
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    /// Backing bytes holding the packed bits.
    buffer: ByteBuffer,
    /// Bit position inside `buffer` at which the logical bits start. The constructors
    /// in this file reduce it modulo 8 and slice whole bytes off the buffer instead.
    offset: usize,
    /// Number of logical bits.
    len: usize,
}
45
46const LIMIT_LEN: usize = 16;
47impl Display for BitBuffer {
48 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
49 let limit = f.precision().unwrap_or(LIMIT_LEN);
50 let buf: Vec<bool> = self.into_iter().take(limit).collect();
51 f.debug_struct("BitBuffer")
52 .field("len", &self.len)
53 .field("buffer", &buf)
54 .finish()
55 }
56}
57
58impl PartialEq for BitBuffer {
59 fn eq(&self, other: &Self) -> bool {
60 if self.len != other.len {
61 return false;
62 }
63
64 if self.len == 0 {
65 return true;
66 }
67
68 if self.offset == 0 && other.offset == 0 {
70 let full_bytes = self.len / 8;
71 let self_bytes = &self.buffer.as_slice()[..full_bytes];
72 let other_bytes = &other.buffer.as_slice()[..full_bytes];
73 if self_bytes != other_bytes {
74 return false;
75 }
76 let rem = self.len % 8;
78 if rem != 0 {
79 let mask = (1u8 << rem) - 1;
80 let a = self.buffer.as_slice()[full_bytes] & mask;
81 let b = other.buffer.as_slice()[full_bytes] & mask;
82 return a == b;
83 }
84 return true;
85 }
86
87 self.chunks()
88 .iter_padded()
89 .zip(other.chunks().iter_padded())
90 .all(|(a, b)| a == b)
91 }
92}
93
94impl BitBuffer {
95 pub fn new(buffer: ByteBuffer, len: usize) -> Self {
99 assert!(
100 buffer.len() * 8 >= len,
101 "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
102 );
103
104 let buffer = buffer.aligned(Alignment::none());
106
107 Self {
108 buffer,
109 len,
110 offset: 0,
111 }
112 }
113
114 pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
119 assert!(
120 len.saturating_add(offset) <= buffer.len().saturating_mul(8),
121 "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
122 buffer.len()
123 );
124
125 let buffer = buffer.aligned(Alignment::none());
127
128 let byte_offset = offset / 8;
130 let offset = offset % 8;
131 let buffer = if byte_offset != 0 {
132 buffer.slice(byte_offset..)
133 } else {
134 buffer
135 };
136
137 Self {
138 buffer,
139 offset,
140 len,
141 }
142 }
143
144 pub fn new_set(len: usize) -> Self {
146 let words = len.div_ceil(8);
147 let buffer = buffer![0xFF; words];
148
149 Self {
150 buffer,
151 len,
152 offset: 0,
153 }
154 }
155
156 pub fn new_unset(len: usize) -> Self {
158 let words = len.div_ceil(8);
159 let buffer = Buffer::zeroed(words);
160
161 Self {
162 buffer,
163 len,
164 offset: 0,
165 }
166 }
167
    /// Builds a `BitBuffer` of `len` bits from `indices` by delegating to
    /// `BitBufferMut::from_indices` and freezing the result.
    ///
    /// Presumably the bits at `indices` are set and all others unset, and an index
    /// `>= len` panics (the tests in this file rely on both) — see `BitBufferMut`.
    pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> BitBuffer {
        BitBufferMut::from_indices(len, indices).freeze()
    }
172
    /// Creates a `BitBuffer` with zero bits.
    pub fn empty() -> Self {
        Self::new_set(0)
    }
177
178 pub fn full(value: bool, len: usize) -> Self {
180 if value {
181 Self::new_set(len)
182 } else {
183 Self::new_unset(len)
184 }
185 }
186
    /// Builds a `BitBuffer` of `len` bits from the predicate `f` by delegating to
    /// `BitBufferMut::collect_bool` and freezing the result.
    ///
    /// As used in this file, bit `i` of the result is `f(i)`.
    #[inline]
    pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, f: F) -> Self {
        BitBufferMut::collect_bool(len, f).freeze()
    }
192
193 pub fn map_cmp<F>(&self, mut f: F) -> Self
198 where
199 F: FnMut(usize, bool) -> bool,
200 {
201 let len = self.len;
202 let mut buffer: BufferMut<u64> = BufferMut::with_capacity(len.div_ceil(64));
203
204 let chunks_count = len / 64;
205 let remainder = len % 64;
206 let chunks = self.chunks();
207
208 for (chunk_idx, src_chunk) in chunks.iter().enumerate() {
209 let packed = collect_bool_word(64, |bit_idx| {
210 let i = bit_idx + chunk_idx * 64;
211 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
212 f(i, bit_value)
213 });
214
215 unsafe { buffer.push_unchecked(packed) }
217 }
218
219 if remainder != 0 {
220 let src_chunk = chunks.remainder_bits();
221 let packed = collect_bool_word(remainder, |bit_idx| {
222 let i = bit_idx + chunks_count * 64;
223 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
224 f(i, bit_value)
225 });
226
227 unsafe { buffer.push_unchecked(packed) }
229 }
230
231 buffer.truncate(len.div_ceil(8));
232
233 Self {
234 buffer: buffer.freeze().into_byte_buffer(),
235 offset: 0,
236 len,
237 }
238 }
239
    /// Clears the backing buffer and resets both the length and the bit offset to 0.
    pub fn clear(&mut self) {
        self.buffer.clear();
        self.len = 0;
        self.offset = 0;
    }
246
    /// Returns the number of logical bits in the buffer.
    ///
    /// This is a bit count, not the byte length of the backing buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }
255
    /// Returns `true` when the buffer contains no bits.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
261
    /// Returns the bit offset at which the logical bits start inside the backing buffer.
    #[inline(always)]
    pub fn offset(&self) -> usize {
        self.offset
    }
267
    /// Returns a reference to the backing byte buffer.
    ///
    /// The logical bits start `self.offset()` bits into this buffer.
    #[inline(always)]
    pub fn inner(&self) -> &ByteBuffer {
        &self.buffer
    }
273
    /// Returns the bit at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        assert!(index < self.len);
        // SAFETY: `index` was checked against `self.len` just above.
        unsafe { self.value_unchecked(index) }
    }
284
    /// Returns the bit at `index` without checking that `index` is in bounds.
    ///
    /// # Safety
    ///
    /// The caller must ensure `index < self.len()`, as the checked `value` accessor
    /// does; the bit is read straight from the backing pointer at `index + offset`.
    #[inline]
    pub unsafe fn value_unchecked(&self, index: usize) -> bool {
        unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
    }
293
294 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
299 let start = match range.start_bound() {
300 Bound::Included(&s) => s,
301 Bound::Excluded(&s) => s + 1,
302 Bound::Unbounded => 0,
303 };
304 let end = match range.end_bound() {
305 Bound::Included(&e) => e + 1,
306 Bound::Excluded(&e) => e,
307 Bound::Unbounded => self.len,
308 };
309
310 assert!(start <= end);
311 assert!(start <= self.len);
312 assert!(end <= self.len);
313 let len = end - start;
314
315 Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
316 }
317
318 pub fn shrink_offset(self) -> Self {
320 let word_start = self.offset / 8;
321 let word_end = (self.offset + self.len).div_ceil(8);
322
323 let buffer = self.buffer.slice(word_start..word_end);
324
325 let bit_offset = self.offset % 8;
326 let len = self.len;
327 BitBuffer::new_with_offset(buffer, len, bit_offset)
328 }
329
    /// Returns an [`UnalignedBitChunk`] view over the logical bits of this buffer
    /// (constructed from the backing bytes, the bit offset and the bit length).
    pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
        UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
    }
334
    /// Returns a [`BitChunks`] view over the logical bits of this buffer.
    ///
    /// As used in this file, `iter()` on the result yields the full 64-bit words and
    /// `remainder_bits()` yields the trailing partial word.
    pub fn chunks(&self) -> BitChunks<'_> {
        BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
    }
341
    /// Returns the number of set bits among the logical bits of this buffer.
    #[inline]
    pub fn true_count(&self) -> usize {
        count_ones(self.buffer.as_slice(), self.offset, self.len)
    }
347
    /// Delegates to `bit_select` over the logical bits of this buffer.
    ///
    /// Presumably returns the position of the `nth` set bit, or `None` when there are
    /// not enough set bits — confirm indexing base and semantics against `bit_select`.
    pub fn select(&self, nth: usize) -> Option<usize> {
        bit_select(self.buffer.as_slice(), self.offset, self.len, nth)
    }
357
    /// Returns the number of unset bits, computed as `len - true_count()`.
    #[inline]
    pub fn false_count(&self) -> usize {
        self.len - self.true_count()
    }
363
    /// Returns a [`BitIterator`] over the logical bits of this buffer, yielding `bool`s.
    pub fn iter(&self) -> BitIterator<'_> {
        BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }
368
    /// Returns a [`BitIndexIterator`] over the logical bits of this buffer.
    ///
    /// Presumably yields the indices of the set bits — confirm against
    /// `BitIndexIterator`.
    pub fn set_indices(&self) -> BitIndexIterator<'_> {
        BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }
373
    /// Returns a [`BitSliceIterator`] over the logical bits of this buffer.
    ///
    /// Presumably yields contiguous runs of set bits — confirm against
    /// `BitSliceIterator`.
    pub fn set_slices(&self) -> BitSliceIterator<'_> {
        BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }
378
379 pub fn sliced(&self) -> Self {
381 if self.offset.is_multiple_of(8) {
382 return Self::new(
383 self.buffer
384 .slice(self.offset / 8..(self.offset + self.len).div_ceil(8)),
385 self.len,
386 );
387 }
388
389 bitwise_unary_op_copy(self, |a| a)
391 }
392}
393
394impl BitBuffer {
397 pub fn into_inner(self) -> (usize, usize, ByteBuffer) {
399 (self.offset, self.len, self.buffer)
400 }
401
402 pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
404 match self.buffer.try_into_mut() {
405 Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
406 Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
407 }
408 }
409}
410
/// Builds a `BitBuffer` from a slice of booleans by converting it into a
/// `BitBufferMut` and freezing that.
impl From<&[bool]> for BitBuffer {
    fn from(value: &[bool]) -> Self {
        BitBufferMut::from(value).freeze()
    }
}
416
/// Builds a `BitBuffer` from a vector of booleans by converting it into a
/// `BitBufferMut` and freezing that.
impl From<Vec<bool>> for BitBuffer {
    fn from(value: Vec<bool>) -> Self {
        BitBufferMut::from(value).freeze()
    }
}
422
/// Collects an iterator of booleans into a `BitBuffer` by building a `BitBufferMut`
/// and freezing it.
impl FromIterator<bool> for BitBuffer {
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        BitBufferMut::from_iter(iter).freeze()
    }
}
428
/// Bitwise OR of two owned buffers; the left-hand side is passed by value to
/// `bitwise_binary_op_lhs_owned` (presumably so its storage can be reused).
impl BitOr for BitBuffer {
    type Output = Self;

    #[inline]
    fn bitor(self, rhs: Self) -> Self::Output {
        bitwise_binary_op_lhs_owned(self, &rhs, |a, b| a | b)
    }
}
437
/// Bitwise OR of two borrowed buffers, producing a new `BitBuffer`.
impl BitOr for &BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn bitor(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |a, b| a | b)
    }
}
446
/// Bitwise OR of an owned buffer with a borrowed one; the left-hand side is passed
/// by value to `bitwise_binary_op_lhs_owned`.
impl BitOr<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn bitor(self, rhs: &BitBuffer) -> Self::Output {
        bitwise_binary_op_lhs_owned(self, rhs, |a, b| a | b)
    }
}
455
/// Bitwise AND of two borrowed buffers, producing a new `BitBuffer`.
impl BitAnd for &BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn bitand(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |a, b| a & b)
    }
}
464
/// Bitwise AND of a borrowed buffer with an owned one; forwards to the
/// `&BitBuffer & &BitBuffer` implementation by borrowing `rhs`.
impl BitAnd<BitBuffer> for &BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn bitand(self, rhs: BitBuffer) -> Self::Output {
        self.bitand(&rhs)
    }
}
473
/// Bitwise AND of an owned buffer with a borrowed one; the left-hand side is passed
/// by value to `bitwise_binary_op_lhs_owned`.
impl BitAnd<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn bitand(self, rhs: &BitBuffer) -> Self::Output {
        bitwise_binary_op_lhs_owned(self, rhs, |a, b| a & b)
    }
}
482
/// Bitwise AND of two owned buffers; the left-hand side is passed by value to
/// `bitwise_binary_op_lhs_owned` and the right-hand side is only borrowed.
impl BitAnd<BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn bitand(self, rhs: BitBuffer) -> Self::Output {
        bitwise_binary_op_lhs_owned(self, &rhs, |a, b| a & b)
    }
}
491
/// Bitwise NOT of a borrowed buffer; uses the copying unary helper because the
/// input cannot be consumed.
impl Not for &BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn not(self) -> Self::Output {
        bitwise_unary_op_copy(self, |a| !a)
    }
}
502
/// Bitwise NOT of an owned buffer; the buffer is passed by value to
/// `bitwise_unary_op`.
impl Not for BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn not(self) -> Self::Output {
        bitwise_unary_op(self, |a| !a)
    }
}
511
/// Bitwise XOR of two borrowed buffers, producing a new `BitBuffer`.
impl BitXor for &BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn bitxor(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |a, b| a ^ b)
    }
}
520
/// Bitwise XOR of an owned buffer with a borrowed one; the left-hand side is passed
/// by value to `bitwise_binary_op_lhs_owned`.
impl BitXor<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    #[inline]
    fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
        bitwise_binary_op_lhs_owned(self, rhs, |a, b| a ^ b)
    }
}
529
530impl BitBuffer {
531 pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
536 bitwise_binary_op(self, rhs, |a, b| a & !b)
537 }
538
539 pub fn into_bitand_not(self, rhs: &BitBuffer) -> BitBuffer {
541 bitwise_binary_op_lhs_owned(self, rhs, |a, b| a & !b)
542 }
543
544 #[inline]
554 pub fn iter_bits<F>(&self, mut f: F)
555 where
556 F: FnMut(usize, bool),
557 {
558 let total_bits = self.len;
559 if total_bits == 0 {
560 return;
561 }
562
563 let chunks = self.chunks();
565 let chunks_count = total_bits / 64;
566 let remainder = total_bits % 64;
567
568 for (chunk_idx, chunk) in chunks.iter().enumerate() {
569 let base = chunk_idx * 64;
570 for bit_idx in 0..64 {
571 f(base + bit_idx, (chunk >> bit_idx) & 1 == 1);
572 }
573 }
574
575 if remainder != 0 {
576 let rem_chunk = chunks.remainder_bits();
577 let base = chunks_count * 64;
578 for bit_idx in 0..remainder {
579 f(base + bit_idx, (rem_chunk >> bit_idx) & 1 == 1);
580 }
581 }
582 }
583}
584
/// Iterating over `&BitBuffer` yields each logical bit as a `bool`, exactly like
/// [`BitBuffer::iter`].
impl<'a> IntoIterator for &'a BitBuffer {
    type Item = bool;
    type IntoIter = BitIterator<'a>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
593
594#[cfg(test)]
595mod tests {
596 use rstest::rstest;
597
598 use crate::ByteBuffer;
599 use crate::bit::BitBuffer;
600 use crate::buffer;
601
602 #[test]
603 fn test_bool() {
604 let buffer: ByteBuffer = buffer![1 << 7; 1024];
606 let bools = BitBuffer::new(buffer, 1024 * 8);
607
608 assert_eq!(bools.len(), 1024 * 8);
610 assert!(!bools.is_empty());
611 assert_eq!(bools.true_count(), 1024);
612 assert_eq!(bools.false_count(), 1024 * 7);
613
614 for word in 0..1024 {
616 for bit in 0..8 {
617 if bit == 7 {
618 assert!(bools.value(word * 8 + bit));
619 } else {
620 assert!(!bools.value(word * 8 + bit));
621 }
622 }
623 }
624
625 let sliced = bools.slice(64..72);
627
628 assert_eq!(sliced.len(), 8);
630 assert!(!sliced.is_empty());
631 assert_eq!(sliced.true_count(), 1);
632 assert_eq!(sliced.false_count(), 7);
633
634 for bit in 0..8 {
636 if bit == 7 {
637 assert!(sliced.value(bit));
638 } else {
639 assert!(!sliced.value(bit));
640 }
641 }
642 }
643
644 #[test]
645 fn test_padded_equaltiy() {
646 let buf1 = BitBuffer::new_set(64); let buf2 = BitBuffer::collect_bool(64, |x| x < 32); for i in 0..32 {
650 assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
651 }
652
653 for i in 32..64 {
654 assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
655 }
656
657 assert_eq!(
658 buf1.slice(0..32),
659 buf2.slice(0..32),
660 "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
661 );
662 assert_ne!(
663 buf1.slice(32..64),
664 buf2.slice(32..64),
665 "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
666 );
667 }
668
    /// Slicing at bit 10 must store a sub-byte offset (10 % 8 == 2), not the raw start.
    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.len(), 6);
        // The stored offset is reduced modulo 8; the whole byte is sliced away.
        assert_eq!(sliced.offset(), 2);
    }
677
678 #[test]
679 fn test_from_indices_dense_crosses_words() {
680 let len = 130;
681 let indices = (0..len).filter(|idx| idx % 3 != 1);
682 let buf = BitBuffer::from_indices(len, indices);
683
684 assert_eq!(buf.len(), len);
685 for idx in 0..len {
686 assert_eq!(buf.value(idx), idx % 3 != 1, "mismatch at {idx}");
687 }
688 }
689
    /// An index equal to `len` is out of bounds and must panic with the expected message.
    #[test]
    #[should_panic(expected = "index 5 exceeds len 5")]
    fn test_from_indices_out_of_bounds() {
        BitBuffer::from_indices(5, [0, 5]);
    }
695
696 #[rstest]
697 #[case(5)]
698 #[case(8)]
699 #[case(10)]
700 #[case(13)]
701 #[case(16)]
702 #[case(23)]
703 #[case(100)]
704 fn test_iter_bits(#[case] len: usize) {
705 let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
706
707 let mut collected = Vec::new();
708 buf.iter_bits(|idx, is_set| {
709 collected.push((idx, is_set));
710 });
711
712 assert_eq!(collected.len(), len);
713
714 for (idx, is_set) in collected {
715 assert_eq!(is_set, idx % 2 == 0);
716 }
717 }
718
719 #[rstest]
720 #[case(3, 5)]
721 #[case(3, 8)]
722 #[case(5, 10)]
723 #[case(2, 16)]
724 #[case(8, 16)]
725 #[case(9, 16)]
726 #[case(17, 16)]
727 fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
728 let total_bits = offset + len;
729 let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
730 let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);
731
732 let mut collected = Vec::new();
733 buf_with_offset.iter_bits(|idx, is_set| {
734 collected.push((idx, is_set));
735 });
736
737 assert_eq!(collected.len(), len);
738
739 for (idx, is_set) in collected {
740 assert_eq!(is_set, (offset + idx).is_multiple_of(2));
742 }
743 }
744
745 #[rstest]
746 #[case(8, 10)]
747 #[case(9, 7)]
748 #[case(16, 8)]
749 #[case(17, 10)]
750 fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
751 let total_bits = offset + len;
752 let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);
754
755 let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);
756
757 let mut collected = Vec::new();
758 buf_with_offset.iter_bits(|idx, is_set| {
759 collected.push((idx, is_set));
760 });
761
762 assert_eq!(collected.len(), len);
763
764 for (idx, is_set) in collected {
765 let bit_position = offset + idx;
766 let byte_index = bit_position / 8;
767 let expected_is_set = byte_index.is_multiple_of(2);
768
769 assert_eq!(
770 is_set, expected_is_set,
771 "Bit mismatch at index {}: expected {} got {}",
772 bit_position, expected_is_set, is_set
773 );
774 }
775 }
776
777 #[rstest]
778 #[case(5)]
779 #[case(8)]
780 #[case(10)]
781 #[case(64)]
782 #[case(65)]
783 #[case(100)]
784 #[case(128)]
785 fn test_map_cmp_identity(#[case] len: usize) {
786 let buf = BitBuffer::collect_bool(len, |i| i % 3 == 0);
788 let mapped = buf.map_cmp(|_idx, bit| bit);
789
790 assert_eq!(buf.len(), mapped.len());
791 for i in 0..len {
792 assert_eq!(buf.value(i), mapped.value(i), "Mismatch at index {}", i);
793 }
794 }
795
796 #[rstest]
797 #[case(5)]
798 #[case(8)]
799 #[case(64)]
800 #[case(65)]
801 #[case(100)]
802 fn test_map_cmp_negate(#[case] len: usize) {
803 let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
805 let mapped = buf.map_cmp(|_idx, bit| !bit);
806
807 assert_eq!(buf.len(), mapped.len());
808 for i in 0..len {
809 assert_eq!(!buf.value(i), mapped.value(i), "Mismatch at index {}", i);
810 }
811 }
812
813 #[test]
814 fn test_map_cmp_conditional() {
815 let len = 100;
817 let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
818
819 let mapped = buf.map_cmp(|idx, bit| bit && idx % 4 == 0);
821
822 for i in 0..len {
823 let expected = (i % 2 == 0) && (i % 4 == 0);
824 assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
825 }
826 }
827}