1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::ops::BitAnd;
8use std::ops::BitOr;
9use std::ops::BitXor;
10use std::ops::Bound;
11use std::ops::Not;
12use std::ops::RangeBounds;
13
14use crate::Alignment;
15use crate::BitBufferMut;
16use crate::Buffer;
17use crate::BufferMut;
18use crate::ByteBuffer;
19use crate::bit::BitChunks;
20use crate::bit::BitIndexIterator;
21use crate::bit::BitIterator;
22use crate::bit::BitSliceIterator;
23use crate::bit::UnalignedBitChunk;
24use crate::bit::collect_bool_word;
25use crate::bit::count_ones::count_ones;
26use crate::bit::get_bit_unchecked;
27use crate::bit::ops::bitwise_binary_op;
28use crate::bit::ops::bitwise_binary_op_lhs_owned;
29use crate::bit::ops::bitwise_unary_op;
30use crate::bit::ops::bitwise_unary_op_copy;
31use crate::bit::select::bit_select;
32use crate::buffer;
33
/// An immutable, packed sequence of bits backed by a [`ByteBuffer`].
///
/// Logical bit `i` is stored at bit position `offset + i` of `buffer`. The
/// constructors in this file keep `offset` below 8 by folding whole bytes of
/// offset into a byte-level slice of `buffer`.
///
/// Equality is implemented by hand rather than derived, so that two buffers
/// holding the same logical bits compare equal regardless of their offsets or
/// of the padding bits outside `offset..offset + len`.
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    /// Backing bytes; may contain bits before `offset` and after `offset + len`.
    buffer: ByteBuffer,
    /// Bit position, within `buffer`, of logical bit 0.
    offset: usize,
    /// Number of logical bits.
    len: usize,
}
45
46const LIMIT_LEN: usize = 16;
47impl Display for BitBuffer {
48 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
49 let limit = f.precision().unwrap_or(LIMIT_LEN);
50 let buf: Vec<bool> = self.into_iter().take(limit).collect();
51 f.debug_struct("BitBuffer")
52 .field("len", &self.len)
53 .field("buffer", &buf)
54 .finish()
55 }
56}
57
58impl PartialEq for BitBuffer {
59 fn eq(&self, other: &Self) -> bool {
60 if self.len != other.len {
61 return false;
62 }
63
64 if self.len == 0 {
65 return true;
66 }
67
68 if self.offset == 0 && other.offset == 0 {
70 let full_bytes = self.len / 8;
71 let self_bytes = &self.buffer.as_slice()[..full_bytes];
72 let other_bytes = &other.buffer.as_slice()[..full_bytes];
73 if self_bytes != other_bytes {
74 return false;
75 }
76 let rem = self.len % 8;
78 if rem != 0 {
79 let mask = (1u8 << rem) - 1;
80 let a = self.buffer.as_slice()[full_bytes] & mask;
81 let b = other.buffer.as_slice()[full_bytes] & mask;
82 return a == b;
83 }
84 return true;
85 }
86
87 self.chunks()
88 .iter_padded()
89 .zip(other.chunks().iter_padded())
90 .all(|(a, b)| a == b)
91 }
92}
93
94impl BitBuffer {
95 pub fn new(buffer: ByteBuffer, len: usize) -> Self {
99 assert!(
100 buffer.len() * 8 >= len,
101 "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
102 );
103
104 let buffer = buffer.aligned(Alignment::none());
106
107 Self {
108 buffer,
109 len,
110 offset: 0,
111 }
112 }
113
114 pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
119 assert!(
120 len.saturating_add(offset) <= buffer.len().saturating_mul(8),
121 "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
122 buffer.len()
123 );
124
125 let buffer = buffer.aligned(Alignment::none());
127
128 let byte_offset = offset / 8;
130 let offset = offset % 8;
131 let buffer = if byte_offset != 0 {
132 buffer.slice(byte_offset..)
133 } else {
134 buffer
135 };
136
137 Self {
138 buffer,
139 offset,
140 len,
141 }
142 }
143
144 pub fn new_set(len: usize) -> Self {
146 let words = len.div_ceil(8);
147 let buffer = buffer![0xFF; words];
148
149 Self {
150 buffer,
151 len,
152 offset: 0,
153 }
154 }
155
156 pub fn new_unset(len: usize) -> Self {
158 let words = len.div_ceil(8);
159 let buffer = Buffer::zeroed(words);
160
161 Self {
162 buffer,
163 len,
164 offset: 0,
165 }
166 }
167
168 pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> BitBuffer {
170 BitBufferMut::from_indices(len, indices).freeze()
171 }
172
173 pub fn empty() -> Self {
175 Self::new_set(0)
176 }
177
178 pub fn full(value: bool, len: usize) -> Self {
180 if value {
181 Self::new_set(len)
182 } else {
183 Self::new_unset(len)
184 }
185 }
186
187 #[inline]
189 pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, f: F) -> Self {
190 BitBufferMut::collect_bool(len, f).freeze()
191 }
192
193 pub fn map_cmp<F>(&self, mut f: F) -> Self
198 where
199 F: FnMut(usize, bool) -> bool,
200 {
201 let len = self.len;
202 let mut buffer: BufferMut<u64> = BufferMut::with_capacity(len.div_ceil(64));
203
204 let chunks_count = len / 64;
205 let remainder = len % 64;
206 let chunks = self.chunks();
207
208 for (chunk_idx, src_chunk) in chunks.iter().enumerate() {
209 let packed = collect_bool_word(64, |bit_idx| {
210 let i = bit_idx + chunk_idx * 64;
211 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
212 f(i, bit_value)
213 });
214
215 unsafe { buffer.push_unchecked(packed) }
217 }
218
219 if remainder != 0 {
220 let src_chunk = chunks.remainder_bits();
221 let packed = collect_bool_word(remainder, |bit_idx| {
222 let i = bit_idx + chunks_count * 64;
223 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
224 f(i, bit_value)
225 });
226
227 unsafe { buffer.push_unchecked(packed) }
229 }
230
231 let mut bytes = buffer.into_byte_buffer();
232 bytes.truncate(len.div_ceil(8));
233
234 Self {
235 buffer: bytes.freeze(),
236 offset: 0,
237 len,
238 }
239 }
240
241 pub fn clear(&mut self) {
243 self.buffer.clear();
244 self.len = 0;
245 self.offset = 0;
246 }
247
248 #[inline]
253 pub fn len(&self) -> usize {
254 self.len
255 }
256
257 #[inline]
259 pub fn is_empty(&self) -> bool {
260 self.len() == 0
261 }
262
263 #[inline(always)]
265 pub fn offset(&self) -> usize {
266 self.offset
267 }
268
269 #[inline(always)]
271 pub fn inner(&self) -> &ByteBuffer {
272 &self.buffer
273 }
274
275 #[inline]
281 pub fn byte_aligned_bytes(&self) -> Option<&[u8]> {
282 if !self.offset.is_multiple_of(8) {
283 return None;
284 }
285
286 let n_bytes = self.len.div_ceil(8);
287 let start = self.offset / 8;
288 let end = start + n_bytes;
289 Some(&self.buffer.as_slice()[start..end])
290 }
291
292 #[inline]
298 pub fn value(&self, index: usize) -> bool {
299 assert!(index < self.len);
300 unsafe { self.value_unchecked(index) }
301 }
302
303 #[inline]
308 pub unsafe fn value_unchecked(&self, index: usize) -> bool {
309 unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
310 }
311
312 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
317 let start = match range.start_bound() {
318 Bound::Included(&s) => s,
319 Bound::Excluded(&s) => s + 1,
320 Bound::Unbounded => 0,
321 };
322 let end = match range.end_bound() {
323 Bound::Included(&e) => e + 1,
324 Bound::Excluded(&e) => e,
325 Bound::Unbounded => self.len,
326 };
327
328 assert!(start <= end);
329 assert!(start <= self.len);
330 assert!(end <= self.len);
331 let len = end - start;
332
333 let offset = self.offset + start;
334 let byte_offset = offset / 8;
335 let bit_offset = offset % 8;
336
337 let buffer = if byte_offset != 0 {
340 self.buffer.slice_unaligned(byte_offset..)
341 } else {
342 self.buffer.clone().aligned(Alignment::none())
343 };
344
345 Self {
346 buffer,
347 offset: bit_offset,
348 len,
349 }
350 }
351
352 pub fn shrink_offset(self) -> Self {
354 let word_start = self.offset / 8;
355 let word_end = (self.offset + self.len).div_ceil(8);
356
357 let buffer = self.buffer.slice(word_start..word_end);
358
359 let bit_offset = self.offset % 8;
360 let len = self.len;
361 BitBuffer::new_with_offset(buffer, len, bit_offset)
362 }
363
364 pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
366 UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
367 }
368
369 pub fn chunks(&self) -> BitChunks<'_> {
373 BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
374 }
375
376 #[inline]
378 pub fn true_count(&self) -> usize {
379 count_ones(self.buffer.as_slice(), self.offset, self.len)
380 }
381
382 pub fn select(&self, nth: usize) -> Option<usize> {
389 bit_select(self.buffer.as_slice(), self.offset, self.len, nth)
390 }
391
392 #[inline]
394 pub fn false_count(&self) -> usize {
395 self.len - self.true_count()
396 }
397
398 pub fn iter(&self) -> BitIterator<'_> {
400 BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
401 }
402
403 pub fn set_indices(&self) -> BitIndexIterator<'_> {
405 BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
406 }
407
408 pub fn set_slices(&self) -> BitSliceIterator<'_> {
410 BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
411 }
412
413 pub fn sliced(&self) -> Self {
415 if self.offset.is_multiple_of(8) {
416 return Self::new(
417 self.buffer
418 .slice(self.offset / 8..(self.offset + self.len).div_ceil(8)),
419 self.len,
420 );
421 }
422
423 bitwise_unary_op_copy(self, |a| a)
425 }
426}
427
428impl BitBuffer {
431 pub fn into_inner(self) -> (usize, usize, ByteBuffer) {
433 (self.offset, self.len, self.buffer)
434 }
435
436 pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
438 match self.buffer.try_into_mut() {
439 Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
440 Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
441 }
442 }
443}
444
445impl From<&[bool]> for BitBuffer {
446 fn from(value: &[bool]) -> Self {
447 BitBufferMut::from(value).freeze()
448 }
449}
450
451impl From<Vec<bool>> for BitBuffer {
452 fn from(value: Vec<bool>) -> Self {
453 BitBufferMut::from(value).freeze()
454 }
455}
456
457impl FromIterator<bool> for BitBuffer {
458 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
459 BitBufferMut::from_iter(iter).freeze()
460 }
461}
462
463impl BitOr for BitBuffer {
464 type Output = Self;
465
466 #[inline]
467 fn bitor(self, rhs: Self) -> Self::Output {
468 bitwise_binary_op_lhs_owned(self, &rhs, |a, b| a | b)
469 }
470}
471
472impl BitOr for &BitBuffer {
473 type Output = BitBuffer;
474
475 #[inline]
476 fn bitor(self, rhs: Self) -> Self::Output {
477 bitwise_binary_op(self, rhs, |a, b| a | b)
478 }
479}
480
481impl BitOr<&BitBuffer> for BitBuffer {
482 type Output = BitBuffer;
483
484 #[inline]
485 fn bitor(self, rhs: &BitBuffer) -> Self::Output {
486 bitwise_binary_op_lhs_owned(self, rhs, |a, b| a | b)
487 }
488}
489
490impl BitAnd for &BitBuffer {
491 type Output = BitBuffer;
492
493 #[inline]
494 fn bitand(self, rhs: Self) -> Self::Output {
495 bitwise_binary_op(self, rhs, |a, b| a & b)
496 }
497}
498
499impl BitAnd<BitBuffer> for &BitBuffer {
500 type Output = BitBuffer;
501
502 #[inline]
503 fn bitand(self, rhs: BitBuffer) -> Self::Output {
504 self.bitand(&rhs)
505 }
506}
507
508impl BitAnd<&BitBuffer> for BitBuffer {
509 type Output = BitBuffer;
510
511 #[inline]
512 fn bitand(self, rhs: &BitBuffer) -> Self::Output {
513 bitwise_binary_op_lhs_owned(self, rhs, |a, b| a & b)
514 }
515}
516
517impl BitAnd<BitBuffer> for BitBuffer {
518 type Output = BitBuffer;
519
520 #[inline]
521 fn bitand(self, rhs: BitBuffer) -> Self::Output {
522 bitwise_binary_op_lhs_owned(self, &rhs, |a, b| a & b)
523 }
524}
525
526impl Not for &BitBuffer {
527 type Output = BitBuffer;
528
529 #[inline]
530 fn not(self) -> Self::Output {
531 bitwise_unary_op_copy(self, |a| !a)
534 }
535}
536
537impl Not for BitBuffer {
538 type Output = BitBuffer;
539
540 #[inline]
541 fn not(self) -> Self::Output {
542 bitwise_unary_op(self, |a| !a)
543 }
544}
545
546impl BitXor for &BitBuffer {
547 type Output = BitBuffer;
548
549 #[inline]
550 fn bitxor(self, rhs: Self) -> Self::Output {
551 bitwise_binary_op(self, rhs, |a, b| a ^ b)
552 }
553}
554
555impl BitXor<&BitBuffer> for BitBuffer {
556 type Output = BitBuffer;
557
558 #[inline]
559 fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
560 bitwise_binary_op_lhs_owned(self, rhs, |a, b| a ^ b)
561 }
562}
563
564impl BitBuffer {
565 pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
570 bitwise_binary_op(self, rhs, |a, b| a & !b)
571 }
572
573 pub fn into_bitand_not(self, rhs: &BitBuffer) -> BitBuffer {
575 bitwise_binary_op_lhs_owned(self, rhs, |a, b| a & !b)
576 }
577
578 #[inline]
588 pub fn iter_bits<F>(&self, mut f: F)
589 where
590 F: FnMut(usize, bool),
591 {
592 let total_bits = self.len;
593 if total_bits == 0 {
594 return;
595 }
596
597 let chunks = self.chunks();
599 let chunks_count = total_bits / 64;
600 let remainder = total_bits % 64;
601
602 for (chunk_idx, chunk) in chunks.iter().enumerate() {
603 let base = chunk_idx * 64;
604 for bit_idx in 0..64 {
605 f(base + bit_idx, (chunk >> bit_idx) & 1 == 1);
606 }
607 }
608
609 if remainder != 0 {
610 let rem_chunk = chunks.remainder_bits();
611 let base = chunks_count * 64;
612 for bit_idx in 0..remainder {
613 f(base + bit_idx, (rem_chunk >> bit_idx) & 1 == 1);
614 }
615 }
616 }
617}
618
619impl<'a> IntoIterator for &'a BitBuffer {
620 type Item = bool;
621 type IntoIter = BitIterator<'a>;
622
623 fn into_iter(self) -> Self::IntoIter {
624 self.iter()
625 }
626}
627
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::ByteBuffer;
    use crate::bit::BitBuffer;
    use crate::buffer;

    #[test]
    fn test_bool() {
        // Every byte is 0b1000_0000, so exactly bit 7 of each byte is set.
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        for word in 0..1024 {
            for bit in 0..8 {
                assert_eq!(bools.value(word * 8 + bit), bit == 7, "word {word} bit {bit}");
            }
        }

        let sliced = bools.slice(64..72);

        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        for bit in 0..8 {
            assert_eq!(sliced.value(bit), bit == 7, "bit {bit}");
        }
    }

    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64);
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32);

        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }

        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_equality_across_offsets() {
        let bits = [
            true, false, true, true, false, false, true, false, true, true, false,
        ];
        let base = BitBuffer::from(&bits[..]);
        let shifted = BitBuffer::from_iter([false, false, false].into_iter().chain(bits))
            .slice(3..3 + bits.len());

        assert_eq!(base, shifted);
        assert_ne!(base, shifted.slice(1..));
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.len(), 6);
        assert_eq!(sliced.offset(), 2);
    }

    #[test]
    fn test_byte_aligned_bytes() {
        let bytes: ByteBuffer = buffer![0b1010_0101u8, 0b0000_0011];
        let buf = BitBuffer::new(bytes.clone(), 10);
        assert_eq!(buf.byte_aligned_bytes(), Some(bytes.as_slice()));

        let byte_sliced = buf.slice(8..10);
        assert_eq!(byte_sliced.byte_aligned_bytes(), Some(&[0b0000_0011][..]));

        let bit_sliced = buf.slice(1..9);
        assert!(bit_sliced.byte_aligned_bytes().is_none());
    }

    #[test]
    fn test_from_indices_dense_crosses_words() {
        let len = 130;
        let indices = (0..len).filter(|idx| idx % 3 != 1);
        let buf = BitBuffer::from_indices(len, indices);

        assert_eq!(buf.len(), len);
        for idx in 0..len {
            assert_eq!(buf.value(idx), idx % 3 != 1, "mismatch at {idx}");
        }
    }

    #[test]
    #[should_panic(expected = "index 5 exceeds len 5")]
    fn test_from_indices_out_of_bounds() {
        BitBuffer::from_indices(5, [0, 5]);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, (offset + idx).is_multiple_of(2));
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Alternating all-set / all-unset bytes make a wrong byte offset visible.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index.is_multiple_of(2);

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    #[case(128)]
    fn test_map_cmp_identity(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 3 == 0);
        let mapped = buf.map_cmp(|_idx, bit| bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    fn test_map_cmp_negate(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
        let mapped = buf.map_cmp(|_idx, bit| !bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(!buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_map_cmp_conditional() {
        let len = 100;
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mapped = buf.map_cmp(|idx, bit| bit && idx % 4 == 0);

        for i in 0..len {
            let expected = (i % 2 == 0) && (i % 4 == 0);
            assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
        }
    }
}