1use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::BitXor;
7use std::ops::Bound;
8use std::ops::Not;
9use std::ops::RangeBounds;
10
11use crate::Alignment;
12use crate::BitBufferMut;
13use crate::Buffer;
14use crate::BufferMut;
15use crate::ByteBuffer;
16use crate::bit::BitChunks;
17use crate::bit::BitIndexIterator;
18use crate::bit::BitIterator;
19use crate::bit::BitSliceIterator;
20use crate::bit::UnalignedBitChunk;
21use crate::bit::get_bit_unchecked;
22use crate::bit::ops::bitwise_binary_op;
23use crate::bit::ops::bitwise_unary_op;
24use crate::buffer;
25
/// An immutable, cheaply-cloneable buffer of bits backed by a shared [`ByteBuffer`].
///
/// Equality is implemented manually (see the `PartialEq` impl) so that padding
/// bits beyond `len` never affect comparisons; `Eq` is derived on top of that.
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    // Underlying byte storage; may contain more bits than `len`.
    buffer: ByteBuffer,
    // Bit offset into `buffer` at which this buffer's bits begin.
    // NOTE(review): every constructor in this file normalizes this to < 8 by
    // folding whole bytes into a slice of `buffer` — confirm before relying on
    // larger offsets.
    offset: usize,
    // Number of addressable bits.
    len: usize,
}
37
38impl PartialEq for BitBuffer {
39 fn eq(&self, other: &Self) -> bool {
40 if self.len != other.len {
41 return false;
42 }
43
44 self.chunks()
45 .iter_padded()
46 .zip(other.chunks().iter_padded())
47 .all(|(a, b)| a == b)
48 }
49}
50
51impl BitBuffer {
52 pub fn new(buffer: ByteBuffer, len: usize) -> Self {
56 assert!(
57 buffer.len() * 8 >= len,
58 "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
59 );
60
61 let buffer = buffer.aligned(Alignment::none());
63
64 Self {
65 buffer,
66 len,
67 offset: 0,
68 }
69 }
70
71 pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
76 assert!(
77 len.saturating_add(offset) <= buffer.len().saturating_mul(8),
78 "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
79 buffer.len()
80 );
81
82 let buffer = buffer.aligned(Alignment::none());
84
85 let byte_offset = offset / 8;
87 let offset = offset % 8;
88 let buffer = buffer.slice(byte_offset..);
89
90 Self {
91 buffer,
92 offset,
93 len,
94 }
95 }
96
97 pub fn new_set(len: usize) -> Self {
99 let words = len.div_ceil(8);
100 let buffer = buffer![0xFF; words];
101
102 Self {
103 buffer,
104 len,
105 offset: 0,
106 }
107 }
108
109 pub fn new_unset(len: usize) -> Self {
111 let words = len.div_ceil(8);
112 let buffer = Buffer::zeroed(words);
113
114 Self {
115 buffer,
116 len,
117 offset: 0,
118 }
119 }
120
121 pub fn empty() -> Self {
123 Self::new_set(0)
124 }
125
126 pub fn full(value: bool, len: usize) -> Self {
128 if value {
129 Self::new_set(len)
130 } else {
131 Self::new_unset(len)
132 }
133 }
134
135 pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, f: F) -> Self {
137 BitBufferMut::collect_bool(len, f).freeze()
138 }
139
140 pub fn map_cmp<F>(&self, mut f: F) -> Self
145 where
146 F: FnMut(usize, bool) -> bool,
147 {
148 let len = self.len;
149 let mut buffer: BufferMut<u64> = BufferMut::with_capacity(len.div_ceil(64));
150
151 let chunks_count = len / 64;
152 let remainder = len % 64;
153 let chunks = self.chunks();
154
155 for (chunk_idx, src_chunk) in chunks.iter().enumerate() {
156 let mut packed = 0u64;
157 for bit_idx in 0..64 {
158 let i = bit_idx + chunk_idx * 64;
159 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
160 packed |= (f(i, bit_value) as u64) << bit_idx;
161 }
162
163 unsafe { buffer.push_unchecked(packed) }
165 }
166
167 if remainder != 0 {
168 let src_chunk = chunks.remainder_bits();
169 let mut packed = 0u64;
170 for bit_idx in 0..remainder {
171 let i = bit_idx + chunks_count * 64;
172 let bit_value = (src_chunk >> bit_idx) & 1 == 1;
173 packed |= (f(i, bit_value) as u64) << bit_idx;
174 }
175
176 unsafe { buffer.push_unchecked(packed) }
178 }
179
180 buffer.truncate(len.div_ceil(8));
181
182 Self {
183 buffer: buffer.freeze().into_byte_buffer(),
184 offset: 0,
185 len,
186 }
187 }
188
189 pub fn clear(&mut self) {
191 self.buffer.clear();
192 self.len = 0;
193 self.offset = 0;
194 }
195
196 #[inline]
201 pub fn len(&self) -> usize {
202 self.len
203 }
204
205 #[inline]
207 pub fn is_empty(&self) -> bool {
208 self.len() == 0
209 }
210
211 #[inline(always)]
213 pub fn offset(&self) -> usize {
214 self.offset
215 }
216
217 #[inline(always)]
219 pub fn inner(&self) -> &ByteBuffer {
220 &self.buffer
221 }
222
223 #[inline]
229 pub fn value(&self, index: usize) -> bool {
230 assert!(index < self.len);
231 unsafe { self.value_unchecked(index) }
232 }
233
234 #[inline]
239 pub unsafe fn value_unchecked(&self, index: usize) -> bool {
240 unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
241 }
242
243 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
248 let start = match range.start_bound() {
249 Bound::Included(&s) => s,
250 Bound::Excluded(&s) => s + 1,
251 Bound::Unbounded => 0,
252 };
253 let end = match range.end_bound() {
254 Bound::Included(&e) => e + 1,
255 Bound::Excluded(&e) => e,
256 Bound::Unbounded => self.len,
257 };
258
259 assert!(start <= end);
260 assert!(start <= self.len);
261 assert!(end <= self.len);
262 let len = end - start;
263
264 Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
265 }
266
267 pub fn shrink_offset(self) -> Self {
269 let word_start = self.offset / 8;
270 let word_end = (self.offset + self.len).div_ceil(8);
271
272 let buffer = self.buffer.slice(word_start..word_end);
273
274 let bit_offset = self.offset % 8;
275 let len = self.len;
276 BitBuffer::new_with_offset(buffer, len, bit_offset)
277 }
278
279 pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
281 UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
282 }
283
284 pub fn chunks(&self) -> BitChunks<'_> {
288 BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
289 }
290
291 pub fn true_count(&self) -> usize {
293 self.unaligned_chunks().count_ones()
294 }
295
296 pub fn false_count(&self) -> usize {
298 self.len - self.true_count()
299 }
300
301 pub fn iter(&self) -> BitIterator<'_> {
303 BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
304 }
305
306 pub fn set_indices(&self) -> BitIndexIterator<'_> {
308 BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
309 }
310
311 pub fn set_slices(&self) -> BitSliceIterator<'_> {
313 BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
314 }
315
316 pub fn sliced(&self) -> Self {
318 if self.offset.is_multiple_of(8) {
319 return Self::new(
320 self.buffer.slice(self.offset / 8..self.len.div_ceil(8)),
321 self.len,
322 );
323 }
324 bitwise_unary_op(self, |a| a)
325 }
326}
327
328impl BitBuffer {
331 pub fn into_inner(self) -> (usize, usize, ByteBuffer) {
333 (self.offset, self.len, self.buffer)
334 }
335
336 pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
338 match self.buffer.try_into_mut() {
339 Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
340 Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
341 }
342 }
343
344 pub fn into_mut(self) -> BitBufferMut {
349 let (offset, len, inner) = self.into_inner();
350 BitBufferMut::from_buffer(inner.into_mut(), offset, len)
352 }
353}
354
355impl From<&[bool]> for BitBuffer {
356 fn from(value: &[bool]) -> Self {
357 BitBufferMut::from(value).freeze()
358 }
359}
360
361impl From<Vec<bool>> for BitBuffer {
362 fn from(value: Vec<bool>) -> Self {
363 BitBufferMut::from(value).freeze()
364 }
365}
366
367impl FromIterator<bool> for BitBuffer {
368 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
369 BitBufferMut::from_iter(iter).freeze()
370 }
371}
372
373impl BitOr for &BitBuffer {
374 type Output = BitBuffer;
375
376 fn bitor(self, rhs: Self) -> Self::Output {
377 bitwise_binary_op(self, rhs, |a, b| a | b)
378 }
379}
380
381impl BitOr<&BitBuffer> for BitBuffer {
382 type Output = BitBuffer;
383
384 fn bitor(self, rhs: &BitBuffer) -> Self::Output {
385 (&self).bitor(rhs)
386 }
387}
388
389impl BitAnd for &BitBuffer {
390 type Output = BitBuffer;
391
392 fn bitand(self, rhs: Self) -> Self::Output {
393 bitwise_binary_op(self, rhs, |a, b| a & b)
394 }
395}
396
397impl BitAnd<BitBuffer> for &BitBuffer {
398 type Output = BitBuffer;
399
400 fn bitand(self, rhs: BitBuffer) -> Self::Output {
401 self.bitand(&rhs)
402 }
403}
404
405impl BitAnd<&BitBuffer> for BitBuffer {
406 type Output = BitBuffer;
407
408 fn bitand(self, rhs: &BitBuffer) -> Self::Output {
409 (&self).bitand(rhs)
410 }
411}
412
413impl Not for &BitBuffer {
414 type Output = BitBuffer;
415
416 fn not(self) -> Self::Output {
417 bitwise_unary_op(self, |a| !a)
418 }
419}
420
421impl Not for BitBuffer {
422 type Output = BitBuffer;
423
424 fn not(self) -> Self::Output {
425 (&self).not()
426 }
427}
428
429impl BitXor for &BitBuffer {
430 type Output = BitBuffer;
431
432 fn bitxor(self, rhs: Self) -> Self::Output {
433 bitwise_binary_op(self, rhs, |a, b| a ^ b)
434 }
435}
436
437impl BitXor<&BitBuffer> for BitBuffer {
438 type Output = BitBuffer;
439
440 fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
441 (&self).bitxor(rhs)
442 }
443}
444
impl BitBuffer {
    /// Computes `self & !rhs` chunk by chunk in a single pass, without
    /// materializing the intermediate negation.
    pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
        bitwise_binary_op(self, rhs, |a, b| a & !b)
    }

    /// Invokes `f(index, bit)` for every bit, in order.
    ///
    /// Walks the backing bytes directly in three phases: a leading partial
    /// byte (when the bit offset is not byte-aligned), then whole bytes, then
    /// a trailing partial byte.
    #[inline]
    pub fn iter_bits<F>(&self, mut f: F)
    where
        F: FnMut(usize, bool),
    {
        let total_bits = self.len;
        if total_bits == 0 {
            return;
        }

        let is_bit_set = |byte: u8, bit_idx: usize| (byte & (1 << bit_idx)) != 0;
        let bit_offset = self.offset % 8;
        // SAFETY: constructors guarantee `offset + len <= buffer.len() * 8`,
        // and `len > 0` here, so byte `offset / 8` is in bounds.
        let mut buffer_ptr = unsafe { self.buffer.as_ptr().add(self.offset / 8) };
        let mut callback_idx = 0;

        // Phase 1: emit the high bits of the first, partially-used byte.
        if bit_offset > 0 {
            let bits_in_first_byte = (8 - bit_offset).min(total_bits);
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..bits_in_first_byte {
                f(callback_idx, is_bit_set(byte, bit_offset + bit_idx));
                callback_idx += 1;
            }

            // May end up one-past-the-end when the buffer fits in a single
            // byte; the pointer is then never dereferenced (both loops below
            // run zero iterations).
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Phase 2: emit all 8 bits of each fully-used byte.
        let complete_bytes = (total_bits - callback_idx) / 8;
        for _ in 0..complete_bytes {
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..8 {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
            buffer_ptr = unsafe { buffer_ptr.add(1) };
        }

        // Phase 3: emit the low bits of the trailing, partially-used byte.
        let remaining_bits = total_bits - callback_idx;
        if remaining_bits > 0 {
            let byte = unsafe { *buffer_ptr };

            for bit_idx in 0..remaining_bits {
                f(callback_idx, is_bit_set(byte, bit_idx));
                callback_idx += 1;
            }
        }
    }
}
515
516impl<'a> IntoIterator for &'a BitBuffer {
517 type Item = bool;
518 type IntoIter = BitIterator<'a>;
519
520 fn into_iter(self) -> Self::IntoIter {
521 self.iter()
522 }
523}
524
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::ByteBuffer;
    use crate::bit::BitBuffer;
    use crate::buffer;

    #[test]
    fn test_bool() {
        // 1024 bytes, each with only the top bit (bit 7) set.
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        for word in 0..1024 {
            for bit in 0..8 {
                if bit == 7 {
                    assert!(bools.value(word * 8 + bit));
                } else {
                    assert!(!bools.value(word * 8 + bit));
                }
            }
        }

        // A byte-aligned slice keeps the same bit pattern.
        let sliced = bools.slice(64..72);

        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        for bit in 0..8 {
            if bit == 7 {
                assert!(sliced.value(bit));
            } else {
                assert!(!sliced.value(bit));
            }
        }
    }

    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64);
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32);

        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }

        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.len(), 6);
        // Bit 10 lands in byte 1 at sub-byte offset 2 after normalization.
        assert_eq!(sliced.offset(), 2);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, (offset + idx).is_multiple_of(2));
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Alternate whole bytes of ones and zeroes so a byte-level offset
        // mistake flips the expected values.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index.is_multiple_of(2);

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    #[case(128)]
    fn test_map_cmp_identity(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 3 == 0);
        let mapped = buf.map_cmp(|_idx, bit| bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(64)]
    #[case(65)]
    #[case(100)]
    fn test_map_cmp_negate(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
        let mapped = buf.map_cmp(|_idx, bit| !bit);

        assert_eq!(buf.len(), mapped.len());
        for i in 0..len {
            assert_eq!(!buf.value(i), mapped.value(i), "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_map_cmp_conditional() {
        let len = 100;
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        // The mapping may use both the index and the current bit value.
        let mapped = buf.map_cmp(|idx, bit| bit && idx % 4 == 0);

        for i in 0..len {
            let expected = (i % 2 == 0) && (i % 4 == 0);
            assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
        }
    }
}