1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::ops::BitAnd;
8use std::ops::BitOr;
9use std::ops::BitXor;
10use std::ops::Bound;
11use std::ops::Not;
12use std::ops::RangeBounds;
13
14use crate::Alignment;
15use crate::BitBufferMut;
16use crate::Buffer;
17use crate::BufferMut;
18use crate::ByteBuffer;
19use crate::bit::BitChunks;
20use crate::bit::BitIndexIterator;
21use crate::bit::BitIterator;
22use crate::bit::BitSliceIterator;
23use crate::bit::UnalignedBitChunk;
24use crate::bit::get_bit_unchecked;
25use crate::bit::ops::bitwise_binary_op;
26use crate::bit::ops::bitwise_unary_op;
27use crate::buffer;
28use crate::trusted_len::TrustedLenExt;
29
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    // Backing bytes; bits are stored LSB-first within each byte (see `iter_bits`).
    buffer: ByteBuffer,
    // Bit offset of the first logical bit within `buffer`. Constructors
    // normalize this to < 8 by slicing whole bytes off the front.
    offset: usize,
    // Logical length in bits. `buffer` may contain trailing bits beyond `len`;
    // `PartialEq` is implemented manually (with padded chunk iteration) so
    // those trailing bits do not affect equality, while `Eq` is still derived.
    len: usize,
}
41
42const LIMIT_LEN: usize = 16;
43impl Display for BitBuffer {
44 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
45 let limit = f.precision().unwrap_or(LIMIT_LEN);
46 let buf: Vec<bool> = self.into_iter().take(limit).collect();
47 f.debug_struct("BitBuffer")
48 .field("len", &self.len)
49 .field("buffer", &buf)
50 .finish()
51 }
52}
53
54impl PartialEq for BitBuffer {
55 fn eq(&self, other: &Self) -> bool {
56 if self.len != other.len {
57 return false;
58 }
59
60 self.chunks()
61 .iter_padded()
62 .zip(other.chunks().iter_padded())
63 .all(|(a, b)| a == b)
64 }
65}
66
impl BitBuffer {
    /// Wraps `buffer` as a bit buffer of `len` bits, starting at bit 0.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` holds fewer than `len` bits.
    pub fn new(buffer: ByteBuffer, len: usize) -> Self {
        assert!(
            buffer.len() * 8 >= len,
            "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
        );

        // NOTE(review): presumably normalizes the buffer to carry no alignment
        // requirement (bit access is byte-granular) — confirm `aligned` semantics.
        let buffer = buffer.aligned(Alignment::none());

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Wraps `buffer` as a bit buffer of `len` bits starting at bit `offset`.
    ///
    /// Whole bytes of the offset are sliced off the front of the buffer so the
    /// stored `offset` is always `< 8`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` holds fewer than `offset + len` bits.
    pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
        // Saturating arithmetic so degenerate huge inputs fail the assert
        // instead of overflowing.
        assert!(
            len.saturating_add(offset) <= buffer.len().saturating_mul(8),
            "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
            buffer.len()
        );

        let buffer = buffer.aligned(Alignment::none());

        // Normalize: drop whole bytes from the offset, keep only the sub-byte part.
        let byte_offset = offset / 8;
        let offset = offset % 8;
        let buffer = if byte_offset != 0 {
            buffer.slice(byte_offset..)
        } else {
            buffer
        };

        Self {
            buffer,
            offset,
            len,
        }
    }

    /// Creates a buffer of `len` bits, all set to `true`.
    pub fn new_set(len: usize) -> Self {
        // `words` is actually a byte count (len.div_ceil(8) bytes of 0xFF).
        let words = len.div_ceil(8);
        let buffer = buffer![0xFF; words];

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Creates a buffer of `len` bits, all set to `false`.
    pub fn new_unset(len: usize) -> Self {
        // Byte count backing `len` bits; zeroed bytes ⇒ all bits unset.
        let words = len.div_ceil(8);
        let buffer = Buffer::zeroed(words);

        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Creates a buffer of `len` bits where exactly the bits at `indices` are set.
    pub fn from_indices(len: usize, indices: &[usize]) -> BitBuffer {
        BitBufferMut::from_indices(len, indices).freeze()
    }

    /// Creates an empty (zero-length) buffer.
    pub fn empty() -> Self {
        Self::new_set(0)
    }

    /// Creates a buffer of `len` bits, all set to `value`.
    pub fn full(value: bool, len: usize) -> Self {
        if value {
            Self::new_set(len)
        } else {
            Self::new_unset(len)
        }
    }

    /// Collects `f(0), f(1), …, f(len - 1)` into a new bit buffer.
    #[inline]
    pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, f: F) -> Self {
        BitBufferMut::collect_bool(len, f).freeze()
    }

    /// Builds a new buffer by applying `f(index, bit)` to every bit of `self`.
    ///
    /// Bits are processed 64 at a time via [`Self::chunks`] and repacked into
    /// `u64` words, so `f` is invoked exactly once per logical bit index.
    pub fn map_cmp<F>(&self, mut f: F) -> Self
    where
        F: FnMut(usize, bool) -> bool,
    {
        let len = self.len;
        // One u64 word per 64 bits, rounded up — matches the number of
        // `push_unchecked` calls below (full chunks + optional remainder).
        let mut buffer: BufferMut<u64> = BufferMut::with_capacity(len.div_ceil(64));

        let chunks_count = len / 64;
        let remainder = len % 64;
        let chunks = self.chunks();

        for (chunk_idx, src_chunk) in chunks.iter().enumerate() {
            let mut packed = 0u64;
            for bit_idx in 0..64 {
                // Global bit index; always < len inside a full chunk.
                let i = bit_idx + chunk_idx * 64;
                let bit_value = (src_chunk >> bit_idx) & 1 == 1;
                packed |= (f(i, bit_value) as u64) << bit_idx;
            }

            // SAFETY: capacity reserved above covers every full chunk plus the
            // remainder word, so this push stays within capacity.
            unsafe { buffer.push_unchecked(packed) }
        }

        if remainder != 0 {
            // Trailing partial chunk: only `remainder` low bits are meaningful.
            let src_chunk = chunks.remainder_bits();
            let mut packed = 0u64;
            for bit_idx in 0..remainder {
                let i = bit_idx + chunks_count * 64;
                let bit_value = (src_chunk >> bit_idx) & 1 == 1;
                packed |= (f(i, bit_value) as u64) << bit_idx;
            }

            // SAFETY: see capacity reservation above.
            unsafe { buffer.push_unchecked(packed) }
        }

        // NOTE(review): `buffer` holds u64 elements, and len.div_ceil(8) is
        // always >= the number of words pushed (len.div_ceil(64)), so this
        // truncate looks like a no-op. It may have been intended to trim the
        // *byte* buffer to len.div_ceil(8) bytes after freezing — confirm.
        buffer.truncate(len.div_ceil(8));

        Self {
            buffer: buffer.freeze().into_byte_buffer(),
            offset: 0,
            len,
        }
    }

    /// Resets to an empty buffer, clearing the backing bytes.
    pub fn clear(&mut self) {
        self.buffer.clear();
        self.len = 0;
        self.offset = 0;
    }

    /// Number of logical bits in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer holds no bits.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Bit offset of the first logical bit within the backing buffer (< 8).
    #[inline(always)]
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Borrow of the backing byte buffer (includes offset/padding bits).
    #[inline(always)]
    pub fn inner(&self) -> &ByteBuffer {
        &self.buffer
    }

    /// Returns the bit at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        assert!(index < self.len);
        // SAFETY: bounds checked by the assert above.
        unsafe { self.value_unchecked(index) }
    }

    /// Returns the bit at `index` without bounds checking.
    ///
    /// # Safety
    ///
    /// Caller must guarantee `index < self.len()`.
    #[inline]
    pub unsafe fn value_unchecked(&self, index: usize) -> bool {
        // SAFETY: caller contract plus the constructor invariant that the
        // buffer holds at least `offset + len` bits.
        unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
    }

    /// Returns a zero-copy sub-range view of this buffer.
    ///
    /// # Panics
    ///
    /// Panics if the range is out of bounds or reversed.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        let start = match range.start_bound() {
            Bound::Included(&s) => s,
            Bound::Excluded(&s) => s + 1,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&e) => e + 1,
            Bound::Excluded(&e) => e,
            Bound::Unbounded => self.len,
        };

        assert!(start <= end);
        assert!(start <= self.len);
        assert!(end <= self.len);
        let len = end - start;

        // Cheap clone of the backing buffer; new_with_offset re-normalizes the
        // combined bit offset.
        Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
    }

    /// Trims the backing buffer to just the bytes covering `[offset, offset + len)`,
    /// keeping only the sub-byte part of the offset.
    pub fn shrink_offset(self) -> Self {
        let word_start = self.offset / 8;
        let word_end = (self.offset + self.len).div_ceil(8);

        let buffer = self.buffer.slice(word_start..word_end);

        let bit_offset = self.offset % 8;
        let len = self.len;
        BitBuffer::new_with_offset(buffer, len, bit_offset)
    }

    /// Unaligned chunk view over the logical bits (used e.g. for popcounts).
    pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
        UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// 64-bit chunk view over the logical bits.
    pub fn chunks(&self) -> BitChunks<'_> {
        BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Number of bits set to `true`.
    pub fn true_count(&self) -> usize {
        self.unaligned_chunks().count_ones()
    }

    /// Number of bits set to `false`.
    pub fn false_count(&self) -> usize {
        self.len - self.true_count()
    }

    /// Iterator over the logical bits as `bool`s.
    pub fn iter(&self) -> BitIterator<'_> {
        BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Iterator over the indices of set bits.
    pub fn set_indices(&self) -> BitIndexIterator<'_> {
        BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Iterator over contiguous runs of set bits.
    pub fn set_slices(&self) -> BitSliceIterator<'_> {
        BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Returns a copy of this buffer with offset 0, re-packing bits if needed.
    pub fn sliced(&self) -> Self {
        if self.offset.is_multiple_of(8) {
            // NOTE(review): constructors keep offset < 8, so this branch means
            // offset == 0 and the end bound is correct. If a nonzero multiple-
            // of-8 offset were ever possible, the end bound would also need to
            // account for `self.offset / 8` — confirm the invariant.
            return Self::new(
                self.buffer.slice(self.offset / 8..self.len.div_ceil(8)),
                self.len,
            );
        }
        // Unaligned: re-pack via padded 64-bit chunks into a fresh buffer.
        let iter = self.chunks().iter_padded();
        // SAFETY: iter_padded yields a known, exact number of chunks, so it is
        // a valid trusted-length iterator.
        let iter = unsafe { iter.trusted_len() };
        let result = Buffer::<u64>::from_trusted_len_iter(iter).into_byte_buffer();

        BitBuffer::new(result, self.len())
    }
}
358
impl BitBuffer {
    /// Deconstructs into `(bit_offset, len_in_bits, backing_bytes)`.
    pub fn into_inner(self) -> (usize, usize, ByteBuffer) {
        (self.offset, self.len, self.buffer)
    }

    /// Attempts a zero-copy conversion into a mutable bit buffer.
    ///
    /// On failure the original (reconstructed) immutable buffer is returned —
    /// presumably `ByteBuffer::try_into_mut` fails when the bytes are shared;
    /// confirm against `ByteBuffer`'s documentation.
    pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
        match self.buffer.try_into_mut() {
            Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
            Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
        }
    }

    /// Converts into a mutable bit buffer unconditionally.
    ///
    /// NOTE(review): behavior when the backing bytes are shared depends on
    /// `ByteBuffer::into_mut` (it may copy or panic) — confirm before relying
    /// on this in shared-ownership paths.
    pub fn into_mut(self) -> BitBufferMut {
        let (offset, len, inner) = self.into_inner();
        BitBufferMut::from_buffer(inner.into_mut(), offset, len)
    }
}
385
386impl From<&[bool]> for BitBuffer {
387 fn from(value: &[bool]) -> Self {
388 BitBufferMut::from(value).freeze()
389 }
390}
391
392impl From<Vec<bool>> for BitBuffer {
393 fn from(value: Vec<bool>) -> Self {
394 BitBufferMut::from(value).freeze()
395 }
396}
397
398impl FromIterator<bool> for BitBuffer {
399 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
400 BitBufferMut::from_iter(iter).freeze()
401 }
402}
403
404impl BitOr for BitBuffer {
405 type Output = Self;
406
407 fn bitor(self, rhs: Self) -> Self::Output {
408 BitOr::bitor(&self, &rhs)
409 }
410}
411
412impl BitOr for &BitBuffer {
413 type Output = BitBuffer;
414
415 fn bitor(self, rhs: Self) -> Self::Output {
416 bitwise_binary_op(self, rhs, |a, b| a | b)
417 }
418}
419
420impl BitOr<&BitBuffer> for BitBuffer {
421 type Output = BitBuffer;
422
423 fn bitor(self, rhs: &BitBuffer) -> Self::Output {
424 (&self).bitor(rhs)
425 }
426}
427
428impl BitAnd for &BitBuffer {
429 type Output = BitBuffer;
430
431 fn bitand(self, rhs: Self) -> Self::Output {
432 bitwise_binary_op(self, rhs, |a, b| a & b)
433 }
434}
435
436impl BitAnd<BitBuffer> for &BitBuffer {
437 type Output = BitBuffer;
438
439 fn bitand(self, rhs: BitBuffer) -> Self::Output {
440 self.bitand(&rhs)
441 }
442}
443
444impl BitAnd<&BitBuffer> for BitBuffer {
445 type Output = BitBuffer;
446
447 fn bitand(self, rhs: &BitBuffer) -> Self::Output {
448 (&self).bitand(rhs)
449 }
450}
451
452impl BitAnd<BitBuffer> for BitBuffer {
453 type Output = BitBuffer;
454
455 fn bitand(self, rhs: BitBuffer) -> Self::Output {
456 (&self).bitand(&rhs)
457 }
458}
459
460impl Not for &BitBuffer {
461 type Output = BitBuffer;
462
463 fn not(self) -> Self::Output {
464 bitwise_unary_op(self, |a| !a)
465 }
466}
467
468impl Not for BitBuffer {
469 type Output = BitBuffer;
470
471 fn not(self) -> Self::Output {
472 (&self).not()
473 }
474}
475
476impl BitXor for &BitBuffer {
477 type Output = BitBuffer;
478
479 fn bitxor(self, rhs: Self) -> Self::Output {
480 bitwise_binary_op(self, rhs, |a, b| a ^ b)
481 }
482}
483
484impl BitXor<&BitBuffer> for BitBuffer {
485 type Output = BitBuffer;
486
487 fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
488 (&self).bitxor(rhs)
489 }
490}
491
impl BitBuffer {
    /// Computes `self & !rhs` word-at-a-time (bits set in `self` but not in `rhs`).
    pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
        bitwise_binary_op(self, rhs, |a, b| a & !b)
    }

    /// Invokes `f(index, bit)` for every logical bit, in order, walking the
    /// backing bytes with a raw pointer.
    ///
    /// Handles three phases: a leading partial byte (when `offset % 8 != 0`),
    /// whole bytes, and a trailing partial byte. Bits within a byte are read
    /// LSB-first.
    #[inline]
    pub fn iter_bits<F>(&self, mut f: F)
    where
        F: FnMut(usize, bool),
    {
        let total_bits = self.len;
        if total_bits == 0 {
            return;
        }

        // LSB-first bit test within a byte.
        let is_bit_set = |byte: u8, bit_idx: usize| (byte & (1 << bit_idx)) != 0;
        let bit_offset = self.offset % 8;
        // SAFETY: constructor invariant guarantees the buffer covers at least
        // `offset + len` bits, so `offset / 8` is a valid byte index.
        let mut buffer_ptr = unsafe { self.buffer.as_ptr().add(self.offset / 8) };
        let mut callback_idx = 0;

        // Phase 1: leading partial byte when the first bit is mid-byte.
        if bit_offset > 0 {
            // Don't read past `total_bits` when everything fits in one byte.
            let bits_in_first_byte = (8 - bit_offset).min(total_bits);
            // SAFETY: pointer is within the buffer (see above).
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..bits_in_first_byte {
                f(callback_idx, is_bit_set(byte, bit_offset + bit_idx));
                callback_idx += 1;
            }

            // SAFETY: advancing one past a byte we just read stays in bounds
            // or one-past-the-end, which is valid for a pointer.
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Phase 2: whole bytes.
        let complete_bytes = (total_bits - callback_idx) / 8;
        for _ in 0..complete_bytes {
            // SAFETY: each full byte lies within the `offset + len` bit span.
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..8 {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Phase 3: trailing partial byte.
        let remaining_bits = total_bits - callback_idx;
        if remaining_bits > 0 {
            // SAFETY: remaining bits lie within the `offset + len` bit span.
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..remaining_bits {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
        }
    }
}
562
impl<'a> IntoIterator for &'a BitBuffer {
    type Item = bool;
    type IntoIter = BitIterator<'a>;

    /// Iterates the logical bits as `bool`s; delegates to [`BitBuffer::iter`].
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
571
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::ByteBuffer;
    use crate::bit::BitBuffer;
    use crate::buffer;

    #[test]
    fn test_bool() {
        // 1024 bytes, each with only the MSB (bit 7, LSB-first numbering) set.
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        // Exactly bit 7 of every byte should be set.
        for word in 0..1024 {
            for bit in 0..8 {
                if bit == 7 {
                    assert!(bools.value(word * 8 + bit));
                } else {
                    assert!(!bools.value(word * 8 + bit));
                }
            }
        }

        // A byte-aligned slice should preserve the same bit pattern.
        let sliced = bools.slice(64..72);

        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        for bit in 0..8 {
            if bit == 7 {
                assert!(sliced.value(bit));
            } else {
                assert!(!sliced.value(bit));
            }
        }
    }

    // Fixed: renamed from `test_padded_equaltiy` (typo) and reformatted the
    // setup, which had three statements mangled onto one line.
    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64);
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32);

        // First half identical, second half differs bit-for-bit.
        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }

        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.len(), 6);
        // Offset 10 = one whole byte (sliced away) + 2 bits.
        assert_eq!(sliced.offset(), 2);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        // Callback indices are relative to the view; bits come from offset + idx.
        for (idx, is_set) in collected {
            assert_eq!(is_set, (offset + idx).is_multiple_of(2));
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Bit pattern varies per *byte*, so an off-by-one byte pointer is detected.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index.is_multiple_of(2);

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    #[case(128)]
    fn test_map_cmp_identity(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 3 == 0);
        let mapped = buf.map_cmp(|_idx, bit| bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    fn test_map_cmp_negate(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
        let mapped = buf.map_cmp(|_idx, bit| !bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(!buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_map_cmp_conditional() {
        let len = 100;
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        // Keep a bit only when it was set AND its index is a multiple of 4.
        let mapped = buf.map_cmp(|idx, bit| bit && idx % 4 == 0);

        for i in 0..len {
            let expected = (i % 2 == 0) && (i % 4 == 0);
            assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
        }
    }
}