1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::ops::BitAnd;
8use std::ops::BitOr;
9use std::ops::BitXor;
10use std::ops::Bound;
11use std::ops::Not;
12use std::ops::RangeBounds;
13
14use crate::Alignment;
15use crate::BitBufferMut;
16use crate::Buffer;
17use crate::BufferMut;
18use crate::ByteBuffer;
19use crate::bit::BitChunks;
20use crate::bit::BitIndexIterator;
21use crate::bit::BitIterator;
22use crate::bit::BitSliceIterator;
23use crate::bit::UnalignedBitChunk;
24use crate::bit::count_ones::count_ones;
25use crate::bit::get_bit_unchecked;
26use crate::bit::ops::bitwise_binary_op;
27use crate::bit::ops::bitwise_unary_op;
28use crate::buffer;
29
/// A bit-packed buffer of booleans backed by a [`ByteBuffer`].
///
/// Logical bit `i` (for `i < len`) is stored at bit position `offset + i` of `buffer`, with bits
/// numbered least-significant first inside each byte.
///
/// Equality (`PartialEq`) compares logical bits only; the bit offset and any bits stored past
/// `len` do not take part in the comparison.
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    // Backing bytes; must cover at least `offset + len` bits (checked by the constructors).
    buffer: ByteBuffer,
    // Bit offset of logical bit 0 within `buffer`. The constructors in this module normalise it
    // into `0..8` by slicing whole bytes off the front of `buffer`.
    offset: usize,
    // Number of logical bits.
    len: usize,
}
41
42const LIMIT_LEN: usize = 16;
43impl Display for BitBuffer {
44 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
45 let limit = f.precision().unwrap_or(LIMIT_LEN);
46 let buf: Vec<bool> = self.into_iter().take(limit).collect();
47 f.debug_struct("BitBuffer")
48 .field("len", &self.len)
49 .field("buffer", &buf)
50 .finish()
51 }
52}
53
54impl PartialEq for BitBuffer {
55 fn eq(&self, other: &Self) -> bool {
56 if self.len != other.len {
57 return false;
58 }
59
60 self.chunks()
61 .iter_padded()
62 .zip(other.chunks().iter_padded())
63 .all(|(a, b)| a == b)
64 }
65}
66
67impl BitBuffer {
68 pub fn new(buffer: ByteBuffer, len: usize) -> Self {
72 assert!(
73 buffer.len() * 8 >= len,
74 "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
75 );
76
77 let buffer = buffer.aligned(Alignment::none());
79
80 Self {
81 buffer,
82 len,
83 offset: 0,
84 }
85 }
86
87 pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
92 assert!(
93 len.saturating_add(offset) <= buffer.len().saturating_mul(8),
94 "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
95 buffer.len()
96 );
97
98 let buffer = buffer.aligned(Alignment::none());
100
101 let byte_offset = offset / 8;
103 let offset = offset % 8;
104 let buffer = if byte_offset != 0 {
105 buffer.slice(byte_offset..)
106 } else {
107 buffer
108 };
109
110 Self {
111 buffer,
112 offset,
113 len,
114 }
115 }
116
117 pub fn new_set(len: usize) -> Self {
119 let words = len.div_ceil(8);
120 let buffer = buffer![0xFF; words];
121
122 Self {
123 buffer,
124 len,
125 offset: 0,
126 }
127 }
128
129 pub fn new_unset(len: usize) -> Self {
131 let words = len.div_ceil(8);
132 let buffer = Buffer::zeroed(words);
133
134 Self {
135 buffer,
136 len,
137 offset: 0,
138 }
139 }
140
141 pub fn from_indices(len: usize, indices: &[usize]) -> BitBuffer {
143 BitBufferMut::from_indices(len, indices).freeze()
144 }
145
146 pub fn empty() -> Self {
148 Self::new_set(0)
149 }
150
151 pub fn full(value: bool, len: usize) -> Self {
153 if value {
154 Self::new_set(len)
155 } else {
156 Self::new_unset(len)
157 }
158 }
159
160 #[inline]
162 pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, f: F) -> Self {
163 BitBufferMut::collect_bool(len, f).freeze()
164 }
165
166 pub fn map_cmp<F>(&self, mut f: F) -> Self
171 where
172 F: FnMut(usize, bool) -> bool,
173 {
174 let len = self.len;
175 let mut buffer: BufferMut<u64> = BufferMut::with_capacity(len.div_ceil(64));
176
177 let chunks_count = len / 64;
178 let remainder = len % 64;
179 let chunks = self.chunks();
180
181 for (chunk_idx, src_chunk) in chunks.iter().enumerate() {
182 let mut packed = 0u64;
183 for bit_idx in 0..64 {
184 let i = bit_idx + chunk_idx * 64;
185 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
186 packed |= (f(i, bit_value) as u64) << bit_idx;
187 }
188
189 unsafe { buffer.push_unchecked(packed) }
191 }
192
193 if remainder != 0 {
194 let src_chunk = chunks.remainder_bits();
195 let mut packed = 0u64;
196 for bit_idx in 0..remainder {
197 let i = bit_idx + chunks_count * 64;
198 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
199 packed |= (f(i, bit_value) as u64) << bit_idx;
200 }
201
202 unsafe { buffer.push_unchecked(packed) }
204 }
205
206 buffer.truncate(len.div_ceil(8));
207
208 Self {
209 buffer: buffer.freeze().into_byte_buffer(),
210 offset: 0,
211 len,
212 }
213 }
214
215 pub fn clear(&mut self) {
217 self.buffer.clear();
218 self.len = 0;
219 self.offset = 0;
220 }
221
222 #[inline]
227 pub fn len(&self) -> usize {
228 self.len
229 }
230
231 #[inline]
233 pub fn is_empty(&self) -> bool {
234 self.len() == 0
235 }
236
237 #[inline(always)]
239 pub fn offset(&self) -> usize {
240 self.offset
241 }
242
243 #[inline(always)]
245 pub fn inner(&self) -> &ByteBuffer {
246 &self.buffer
247 }
248
249 #[inline]
255 pub fn value(&self, index: usize) -> bool {
256 assert!(index < self.len);
257 unsafe { self.value_unchecked(index) }
258 }
259
260 #[inline]
265 pub unsafe fn value_unchecked(&self, index: usize) -> bool {
266 unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
267 }
268
269 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
274 let start = match range.start_bound() {
275 Bound::Included(&s) => s,
276 Bound::Excluded(&s) => s + 1,
277 Bound::Unbounded => 0,
278 };
279 let end = match range.end_bound() {
280 Bound::Included(&e) => e + 1,
281 Bound::Excluded(&e) => e,
282 Bound::Unbounded => self.len,
283 };
284
285 assert!(start <= end);
286 assert!(start <= self.len);
287 assert!(end <= self.len);
288 let len = end - start;
289
290 Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
291 }
292
293 pub fn shrink_offset(self) -> Self {
295 let word_start = self.offset / 8;
296 let word_end = (self.offset + self.len).div_ceil(8);
297
298 let buffer = self.buffer.slice(word_start..word_end);
299
300 let bit_offset = self.offset % 8;
301 let len = self.len;
302 BitBuffer::new_with_offset(buffer, len, bit_offset)
303 }
304
305 pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
307 UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
308 }
309
310 pub fn chunks(&self) -> BitChunks<'_> {
314 BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
315 }
316
317 pub fn true_count(&self) -> usize {
319 count_ones(self.buffer.as_slice(), self.offset, self.len)
320 }
321
322 pub fn false_count(&self) -> usize {
324 self.len - self.true_count()
325 }
326
327 pub fn iter(&self) -> BitIterator<'_> {
329 BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
330 }
331
332 pub fn set_indices(&self) -> BitIndexIterator<'_> {
334 BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
335 }
336
337 pub fn set_slices(&self) -> BitSliceIterator<'_> {
339 BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
340 }
341
342 pub fn sliced(&self) -> Self {
344 if self.offset.is_multiple_of(8) {
345 return Self::new(
346 self.buffer.slice(self.offset / 8..self.len.div_ceil(8)),
347 self.len,
348 );
349 }
350
351 bitwise_unary_op(self.clone(), |a| a)
352 }
353}
354
355impl BitBuffer {
358 pub fn into_inner(self) -> (usize, usize, ByteBuffer) {
360 (self.offset, self.len, self.buffer)
361 }
362
363 pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
365 match self.buffer.try_into_mut() {
366 Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
367 Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
368 }
369 }
370}
371
372impl From<&[bool]> for BitBuffer {
373 fn from(value: &[bool]) -> Self {
374 BitBufferMut::from(value).freeze()
375 }
376}
377
378impl From<Vec<bool>> for BitBuffer {
379 fn from(value: Vec<bool>) -> Self {
380 BitBufferMut::from(value).freeze()
381 }
382}
383
384impl FromIterator<bool> for BitBuffer {
385 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
386 BitBufferMut::from_iter(iter).freeze()
387 }
388}
389
390impl BitOr for BitBuffer {
391 type Output = Self;
392
393 #[inline]
394 fn bitor(self, rhs: Self) -> Self::Output {
395 BitOr::bitor(&self, &rhs)
396 }
397}
398
399impl BitOr for &BitBuffer {
400 type Output = BitBuffer;
401
402 #[inline]
403 fn bitor(self, rhs: Self) -> Self::Output {
404 bitwise_binary_op(self, rhs, |a, b| a | b)
405 }
406}
407
408impl BitOr<&BitBuffer> for BitBuffer {
409 type Output = BitBuffer;
410
411 #[inline]
412 fn bitor(self, rhs: &BitBuffer) -> Self::Output {
413 (&self).bitor(rhs)
414 }
415}
416
417impl BitAnd for &BitBuffer {
418 type Output = BitBuffer;
419
420 #[inline]
421 fn bitand(self, rhs: Self) -> Self::Output {
422 bitwise_binary_op(self, rhs, |a, b| a & b)
423 }
424}
425
426impl BitAnd<BitBuffer> for &BitBuffer {
427 type Output = BitBuffer;
428
429 #[inline]
430 fn bitand(self, rhs: BitBuffer) -> Self::Output {
431 self.bitand(&rhs)
432 }
433}
434
435impl BitAnd<&BitBuffer> for BitBuffer {
436 type Output = BitBuffer;
437
438 #[inline]
439 fn bitand(self, rhs: &BitBuffer) -> Self::Output {
440 (&self).bitand(rhs)
441 }
442}
443
444impl BitAnd<BitBuffer> for BitBuffer {
445 type Output = BitBuffer;
446
447 #[inline]
448 fn bitand(self, rhs: BitBuffer) -> Self::Output {
449 (&self).bitand(&rhs)
450 }
451}
452
453impl Not for &BitBuffer {
454 type Output = BitBuffer;
455
456 #[inline]
457 fn not(self) -> Self::Output {
458 !self.clone()
459 }
460}
461
462impl Not for BitBuffer {
463 type Output = BitBuffer;
464
465 #[inline]
466 fn not(self) -> Self::Output {
467 bitwise_unary_op(self, |a| !a)
468 }
469}
470
471impl BitXor for &BitBuffer {
472 type Output = BitBuffer;
473
474 #[inline]
475 fn bitxor(self, rhs: Self) -> Self::Output {
476 bitwise_binary_op(self, rhs, |a, b| a ^ b)
477 }
478}
479
480impl BitXor<&BitBuffer> for BitBuffer {
481 type Output = BitBuffer;
482
483 #[inline]
484 fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
485 (&self).bitxor(rhs)
486 }
487}
488
impl BitBuffer {
    /// Returns `self & !rhs`: the bits set in `self` and unset in `rhs`.
    pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
        bitwise_binary_op(self, rhs, |a, b| a & !b)
    }

    /// Calls `f(i, bit)` for every logical bit `i` in `0..len()`, in ascending order.
    ///
    /// The backing bytes are walked directly in three phases:
    /// 1. a leading partial byte when `offset % 8 != 0`;
    /// 2. whole bytes;
    /// 3. a trailing partial byte.
    ///
    /// Bits within a byte are read least-significant first.
    #[inline]
    pub fn iter_bits<F>(&self, mut f: F)
    where
        F: FnMut(usize, bool),
    {
        let total_bits = self.len;
        if total_bits == 0 {
            return;
        }

        let is_bit_set = |byte: u8, bit_idx: usize| (byte & (1 << bit_idx)) != 0;
        let bit_offset = self.offset % 8;
        // SAFETY: the constructors assert the buffer holds at least `offset + len` bits. The
        // reads below touch only the bytes covering bits `offset..offset + len`, starting at
        // byte `offset / 8`, and the pointer never advances past one-past-the-last byte read.
        // NOTE(review): values created through the serde `Deserialize` derive bypass those
        // constructor checks — confirm deserialized inputs are validated elsewhere.
        let mut buffer_ptr = unsafe { self.buffer.as_ptr().add(self.offset / 8) };
        // Index of the next logical bit to report; also counts the bits reported so far.
        let mut callback_idx = 0;

        if bit_offset > 0 {
            // Leading partial byte: only bits `bit_offset..8` belong to this buffer, or fewer if
            // the whole buffer ends inside this byte.
            let bits_in_first_byte = (8 - bit_offset).min(total_bits);
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..bits_in_first_byte {
                f(callback_idx, is_bit_set(byte, bit_offset + bit_idx));
                callback_idx += 1;
            }

            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Whole bytes: every bit of each byte belongs to the buffer.
        let complete_bytes = (total_bits - callback_idx) / 8;
        for _ in 0..complete_bytes {
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..8 {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Trailing partial byte: the low `remaining_bits` bits of the last byte.
        let remaining_bits = total_bits - callback_idx;
        if remaining_bits > 0 {
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..remaining_bits {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
        }
    }
}
559
560impl<'a> IntoIterator for &'a BitBuffer {
561 type Item = bool;
562 type IntoIter = BitIterator<'a>;
563
564 fn into_iter(self) -> Self::IntoIter {
565 self.iter()
566 }
567}
568
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::ByteBuffer;
    use crate::bit::BitBuffer;
    use crate::buffer;

    #[test]
    fn test_bool() {
        // Every byte has only its most significant bit (bit 7) set.
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        for word in 0..1024 {
            for bit in 0..8 {
                if bit == 7 {
                    assert!(bools.value(word * 8 + bit));
                } else {
                    assert!(!bools.value(word * 8 + bit));
                }
            }
        }

        let sliced = bools.slice(64..72);

        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        for bit in 0..8 {
            if bit == 7 {
                assert!(sliced.value(bit));
            } else {
                assert!(!sliced.value(bit));
            }
        }
    }

    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64);
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32);

        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }

        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_eq_ignores_offset() {
        // The same logical bits stored at bit offset 3 and at offset 0 must compare equal.
        let bits: Vec<bool> = (0..20).map(|i| i % 3 == 0).collect();
        let offset_view = BitBuffer::from(bits.clone()).slice(3..20);
        let fresh = BitBuffer::from(bits[3..].to_vec());

        assert_eq!(offset_view.offset(), 3);
        assert_eq!(fresh.offset(), 0);
        assert_eq!(offset_view, fresh);
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.len(), 6);
        // Bit 10 is bit 2 of byte 1; the whole first byte is sliced away.
        assert_eq!(sliced.offset(), 2);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, (offset + idx).is_multiple_of(2));
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Alternate whole bytes of set / unset bits so that reading the wrong byte is detected.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index.is_multiple_of(2);

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    #[case(128)]
    fn test_map_cmp_identity(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 3 == 0);
        let mapped = buf.map_cmp(|_idx, bit| bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    fn test_map_cmp_negate(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
        let mapped = buf.map_cmp(|_idx, bit| !bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(!buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_map_cmp_conditional() {
        let len = 100;
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        // Keep only the set bits whose index is also a multiple of four.
        let mapped = buf.map_cmp(|idx, bit| bit && idx % 4 == 0);

        for i in 0..len {
            let expected = (i % 2 == 0) && (i % 4 == 0);
            assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
        }
    }
}
784}