1use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::BitXor;
7use std::ops::Bound;
8use std::ops::Not;
9use std::ops::RangeBounds;
10
11use crate::Alignment;
12use crate::BitBufferMut;
13use crate::Buffer;
14use crate::BufferMut;
15use crate::ByteBuffer;
16use crate::bit::BitChunks;
17use crate::bit::BitIndexIterator;
18use crate::bit::BitIterator;
19use crate::bit::BitSliceIterator;
20use crate::bit::UnalignedBitChunk;
21use crate::bit::get_bit_unchecked;
22use crate::bit::ops::bitwise_binary_op;
23use crate::bit::ops::bitwise_unary_op;
24use crate::buffer;
25
/// An immutable buffer of bits backed by a [`ByteBuffer`].
///
/// Bits are addressed LSB-first within each byte; equality compares logical
/// bit values only (see the `PartialEq` impl), ignoring offset and padding.
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    // Backing bytes; logical bit `i` lives at byte `(offset + i) / 8`.
    buffer: ByteBuffer,
    // Bit offset of the first logical bit within `buffer`. The constructors
    // in this file fold whole bytes into a byte-level slice, keeping this in
    // 0..8 — presumably other construction paths uphold the same invariant.
    offset: usize,
    // Number of logical bits visible through this buffer.
    len: usize,
}
37
38impl PartialEq for BitBuffer {
39 fn eq(&self, other: &Self) -> bool {
40 if self.len != other.len {
41 return false;
42 }
43
44 self.chunks()
45 .iter_padded()
46 .zip(other.chunks().iter_padded())
47 .all(|(a, b)| a == b)
48 }
49}
50
51impl BitBuffer {
52 pub fn new(buffer: ByteBuffer, len: usize) -> Self {
56 assert!(
57 buffer.len() * 8 >= len,
58 "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
59 );
60
61 let buffer = buffer.aligned(Alignment::none());
63
64 Self {
65 buffer,
66 len,
67 offset: 0,
68 }
69 }
70
71 pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
76 assert!(
77 len.saturating_add(offset) <= buffer.len().saturating_mul(8),
78 "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
79 buffer.len()
80 );
81
82 let buffer = buffer.aligned(Alignment::none());
84
85 let byte_offset = offset / 8;
87 let offset = offset % 8;
88 let buffer = buffer.slice(byte_offset..);
89
90 Self {
91 buffer,
92 offset,
93 len,
94 }
95 }
96
97 pub fn new_set(len: usize) -> Self {
99 let words = len.div_ceil(8);
100 let buffer = buffer![0xFF; words];
101
102 Self {
103 buffer,
104 len,
105 offset: 0,
106 }
107 }
108
109 pub fn new_unset(len: usize) -> Self {
111 let words = len.div_ceil(8);
112 let buffer = Buffer::zeroed(words);
113
114 Self {
115 buffer,
116 len,
117 offset: 0,
118 }
119 }
120
121 pub fn from_indices(len: usize, indices: &[usize]) -> BitBuffer {
123 BitBufferMut::from_indices(len, indices).freeze()
124 }
125
126 pub fn empty() -> Self {
128 Self::new_set(0)
129 }
130
131 pub fn full(value: bool, len: usize) -> Self {
133 if value {
134 Self::new_set(len)
135 } else {
136 Self::new_unset(len)
137 }
138 }
139
140 #[inline]
142 pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, f: F) -> Self {
143 BitBufferMut::collect_bool(len, f).freeze()
144 }
145
146 pub fn map_cmp<F>(&self, mut f: F) -> Self
151 where
152 F: FnMut(usize, bool) -> bool,
153 {
154 let len = self.len;
155 let mut buffer: BufferMut<u64> = BufferMut::with_capacity(len.div_ceil(64));
156
157 let chunks_count = len / 64;
158 let remainder = len % 64;
159 let chunks = self.chunks();
160
161 for (chunk_idx, src_chunk) in chunks.iter().enumerate() {
162 let mut packed = 0u64;
163 for bit_idx in 0..64 {
164 let i = bit_idx + chunk_idx * 64;
165 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
166 packed |= (f(i, bit_value) as u64) << bit_idx;
167 }
168
169 unsafe { buffer.push_unchecked(packed) }
171 }
172
173 if remainder != 0 {
174 let src_chunk = chunks.remainder_bits();
175 let mut packed = 0u64;
176 for bit_idx in 0..remainder {
177 let i = bit_idx + chunks_count * 64;
178 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
179 packed |= (f(i, bit_value) as u64) << bit_idx;
180 }
181
182 unsafe { buffer.push_unchecked(packed) }
184 }
185
186 buffer.truncate(len.div_ceil(8));
187
188 Self {
189 buffer: buffer.freeze().into_byte_buffer(),
190 offset: 0,
191 len,
192 }
193 }
194
195 pub fn clear(&mut self) {
197 self.buffer.clear();
198 self.len = 0;
199 self.offset = 0;
200 }
201
202 #[inline]
207 pub fn len(&self) -> usize {
208 self.len
209 }
210
211 #[inline]
213 pub fn is_empty(&self) -> bool {
214 self.len() == 0
215 }
216
217 #[inline(always)]
219 pub fn offset(&self) -> usize {
220 self.offset
221 }
222
223 #[inline(always)]
225 pub fn inner(&self) -> &ByteBuffer {
226 &self.buffer
227 }
228
229 #[inline]
235 pub fn value(&self, index: usize) -> bool {
236 assert!(index < self.len);
237 unsafe { self.value_unchecked(index) }
238 }
239
240 #[inline]
245 pub unsafe fn value_unchecked(&self, index: usize) -> bool {
246 unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
247 }
248
249 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
254 let start = match range.start_bound() {
255 Bound::Included(&s) => s,
256 Bound::Excluded(&s) => s + 1,
257 Bound::Unbounded => 0,
258 };
259 let end = match range.end_bound() {
260 Bound::Included(&e) => e + 1,
261 Bound::Excluded(&e) => e,
262 Bound::Unbounded => self.len,
263 };
264
265 assert!(start <= end);
266 assert!(start <= self.len);
267 assert!(end <= self.len);
268 let len = end - start;
269
270 Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
271 }
272
273 pub fn shrink_offset(self) -> Self {
275 let word_start = self.offset / 8;
276 let word_end = (self.offset + self.len).div_ceil(8);
277
278 let buffer = self.buffer.slice(word_start..word_end);
279
280 let bit_offset = self.offset % 8;
281 let len = self.len;
282 BitBuffer::new_with_offset(buffer, len, bit_offset)
283 }
284
285 pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
287 UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
288 }
289
290 pub fn chunks(&self) -> BitChunks<'_> {
294 BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
295 }
296
297 pub fn true_count(&self) -> usize {
299 self.unaligned_chunks().count_ones()
300 }
301
302 pub fn false_count(&self) -> usize {
304 self.len - self.true_count()
305 }
306
307 pub fn iter(&self) -> BitIterator<'_> {
309 BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
310 }
311
312 pub fn set_indices(&self) -> BitIndexIterator<'_> {
314 BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
315 }
316
317 pub fn set_slices(&self) -> BitSliceIterator<'_> {
319 BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
320 }
321
322 pub fn sliced(&self) -> Self {
324 if self.offset.is_multiple_of(8) {
325 return Self::new(
326 self.buffer.slice(self.offset / 8..self.len.div_ceil(8)),
327 self.len,
328 );
329 }
330 bitwise_unary_op(self, |a| a)
331 }
332}
333
334impl BitBuffer {
337 pub fn into_inner(self) -> (usize, usize, ByteBuffer) {
339 (self.offset, self.len, self.buffer)
340 }
341
342 pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
344 match self.buffer.try_into_mut() {
345 Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
346 Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
347 }
348 }
349
350 pub fn into_mut(self) -> BitBufferMut {
355 let (offset, len, inner) = self.into_inner();
356 BitBufferMut::from_buffer(inner.into_mut(), offset, len)
358 }
359}
360
361impl From<&[bool]> for BitBuffer {
362 fn from(value: &[bool]) -> Self {
363 BitBufferMut::from(value).freeze()
364 }
365}
366
367impl From<Vec<bool>> for BitBuffer {
368 fn from(value: Vec<bool>) -> Self {
369 BitBufferMut::from(value).freeze()
370 }
371}
372
373impl FromIterator<bool> for BitBuffer {
374 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
375 BitBufferMut::from_iter(iter).freeze()
376 }
377}
378
379impl BitOr for BitBuffer {
380 type Output = Self;
381
382 fn bitor(self, rhs: Self) -> Self::Output {
383 BitOr::bitor(&self, &rhs)
384 }
385}
386
387impl BitOr for &BitBuffer {
388 type Output = BitBuffer;
389
390 fn bitor(self, rhs: Self) -> Self::Output {
391 bitwise_binary_op(self, rhs, |a, b| a | b)
392 }
393}
394
395impl BitOr<&BitBuffer> for BitBuffer {
396 type Output = BitBuffer;
397
398 fn bitor(self, rhs: &BitBuffer) -> Self::Output {
399 (&self).bitor(rhs)
400 }
401}
402
403impl BitAnd for &BitBuffer {
404 type Output = BitBuffer;
405
406 fn bitand(self, rhs: Self) -> Self::Output {
407 bitwise_binary_op(self, rhs, |a, b| a & b)
408 }
409}
410
411impl BitAnd<BitBuffer> for &BitBuffer {
412 type Output = BitBuffer;
413
414 fn bitand(self, rhs: BitBuffer) -> Self::Output {
415 self.bitand(&rhs)
416 }
417}
418
419impl BitAnd<&BitBuffer> for BitBuffer {
420 type Output = BitBuffer;
421
422 fn bitand(self, rhs: &BitBuffer) -> Self::Output {
423 (&self).bitand(rhs)
424 }
425}
426
427impl BitAnd<BitBuffer> for BitBuffer {
428 type Output = BitBuffer;
429
430 fn bitand(self, rhs: BitBuffer) -> Self::Output {
431 (&self).bitand(&rhs)
432 }
433}
434
435impl Not for &BitBuffer {
436 type Output = BitBuffer;
437
438 fn not(self) -> Self::Output {
439 bitwise_unary_op(self, |a| !a)
440 }
441}
442
443impl Not for BitBuffer {
444 type Output = BitBuffer;
445
446 fn not(self) -> Self::Output {
447 (&self).not()
448 }
449}
450
451impl BitXor for &BitBuffer {
452 type Output = BitBuffer;
453
454 fn bitxor(self, rhs: Self) -> Self::Output {
455 bitwise_binary_op(self, rhs, |a, b| a ^ b)
456 }
457}
458
459impl BitXor<&BitBuffer> for BitBuffer {
460 type Output = BitBuffer;
461
462 fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
463 (&self).bitxor(rhs)
464 }
465}
466
impl BitBuffer {
    /// Returns `self & !rhs`, i.e. the bits set in `self` but not in `rhs`.
    pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
        bitwise_binary_op(self, rhs, |a, b| a & !b)
    }

    /// Invokes `f(bit_index, bit_value)` for every logical bit, in order.
    ///
    /// Walks the backing bytes directly in three phases: a leading partial
    /// byte (when the bit offset is not byte-aligned), then whole bytes, then
    /// a trailing partial byte. The `bit_index` passed to `f` is relative to
    /// this buffer's view (starting at 0), not to the backing bytes.
    #[inline]
    pub fn iter_bits<F>(&self, mut f: F)
    where
        F: FnMut(usize, bool),
    {
        let total_bits = self.len;
        if total_bits == 0 {
            return;
        }

        // Bits are addressed LSB-first within each byte.
        let is_bit_set = |byte: u8, bit_idx: usize| (byte & (1 << bit_idx)) != 0;
        let bit_offset = self.offset % 8;
        // Start at the first byte that contains any of our bits.
        let mut buffer_ptr = unsafe { self.buffer.as_ptr().add(self.offset / 8) };
        let mut callback_idx = 0;

        // Phase 1: leading partial byte — consume the bits of the first byte
        // starting at `bit_offset` (at most `total_bits` of them).
        if bit_offset > 0 {
            let bits_in_first_byte = (8 - bit_offset).min(total_bits);
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..bits_in_first_byte {
                f(callback_idx, is_bit_set(byte, bit_offset + bit_idx));
                callback_idx += 1;
            }

            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Phase 2: whole bytes — 8 callbacks per byte.
        let complete_bytes = (total_bits - callback_idx) / 8;
        for _ in 0..complete_bytes {
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..8 {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Phase 3: trailing partial byte, if any bits remain.
        let remaining_bits = total_bits - callback_idx;
        if remaining_bits > 0 {
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..remaining_bits {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
        }
    }
}
537
538impl<'a> IntoIterator for &'a BitBuffer {
539 type Item = bool;
540 type IntoIter = BitIterator<'a>;
541
542 fn into_iter(self) -> Self::IntoIter {
543 self.iter()
544 }
545}
546
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::ByteBuffer;
    use crate::bit::BitBuffer;
    use crate::buffer;

    #[test]
    fn test_bool() {
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        for word in 0..1024 {
            for bit in 0..8 {
                if bit == 7 {
                    assert!(bools.value(word * 8 + bit));
                } else {
                    assert!(!bools.value(word * 8 + bit));
                }
            }
        }

        let sliced = bools.slice(64..72);

        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        for bit in 0..8 {
            if bit == 7 {
                assert!(sliced.value(bit));
            } else {
                assert!(!sliced.value(bit));
            }
        }
    }

    // Renamed from `test_padded_equaltiy` (typo) and reformatted — the three
    // setup statements had been mangled onto a single line.
    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64);
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32);

        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }

        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.len(), 6);
        // Whole bytes are folded into the buffer slice, so bit 10 => offset 2.
        assert_eq!(sliced.offset(), 2);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, (offset + idx).is_multiple_of(2));
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Bit pattern varies by byte, so a miscomputed byte offset is visible.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index.is_multiple_of(2);

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    #[case(128)]
    fn test_map_cmp_identity(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 3 == 0);
        let mapped = buf.map_cmp(|_idx, bit| bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    fn test_map_cmp_negate(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
        let mapped = buf.map_cmp(|_idx, bit| !bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(!buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_map_cmp_conditional() {
        let len = 100;
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mapped = buf.map_cmp(|idx, bit| bit && idx % 4 == 0);

        for i in 0..len {
            let expected = (i % 2 == 0) && (i % 4 == 0);
            assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
        }
    }
}