1use std::ops::{BitAnd, BitOr, BitXor, Not, RangeBounds};
5
6use crate::bit::ops::{bitwise_binary_op, bitwise_unary_op};
7use crate::bit::{
8 BitChunks, BitIndexIterator, BitIterator, BitSliceIterator, UnalignedBitChunk,
9 get_bit_unchecked,
10};
11use crate::{Alignment, BitBufferMut, Buffer, BufferMut, ByteBuffer, buffer};
12
/// An immutable, packed sequence of `len` bits stored in a [`ByteBuffer`].
///
/// Logical bit `i` is bit `offset + i` of `buffer`, counting LSB-first within each byte.
/// `Eq` is derived but `PartialEq` is implemented by hand, so equality compares only the
/// logical bits — not `offset`, and not bytes of `buffer` outside the logical range.
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    /// Backing bytes; may contain bytes/bits outside `offset..offset + len`.
    buffer: ByteBuffer,
    /// Bit position of logical bit 0 within `buffer`. Not normalised: `slice` adds to it, so
    /// it can be 8 or larger.
    offset: usize,
    /// Number of logical bits.
    len: usize,
}
24
25impl PartialEq for BitBuffer {
26 fn eq(&self, other: &Self) -> bool {
27 if self.len != other.len {
28 return false;
29 }
30
31 self.chunks()
32 .iter_padded()
33 .zip(other.chunks().iter_padded())
34 .all(|(a, b)| a == b)
35 }
36}
37
38impl BitBuffer {
39 pub fn new(buffer: ByteBuffer, len: usize) -> Self {
43 assert!(
44 buffer.len() * 8 >= len,
45 "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
46 );
47
48 let buffer = buffer.aligned(Alignment::none());
50
51 Self {
52 buffer,
53 len,
54 offset: 0,
55 }
56 }
57
58 pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
63 assert!(
64 len.saturating_add(offset) <= buffer.len().saturating_mul(8),
65 "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
66 buffer.len()
67 );
68
69 let buffer = buffer.aligned(Alignment::none());
71
72 Self {
73 buffer,
74 offset,
75 len,
76 }
77 }
78
79 pub fn new_set(len: usize) -> Self {
81 let words = len.div_ceil(8);
82 let buffer = buffer![0xFF; words];
83
84 Self {
85 buffer,
86 len,
87 offset: 0,
88 }
89 }
90
91 pub fn new_unset(len: usize) -> Self {
93 let words = len.div_ceil(8);
94 let buffer = Buffer::zeroed(words);
95
96 Self {
97 buffer,
98 len,
99 offset: 0,
100 }
101 }
102
103 pub fn empty() -> Self {
105 Self::new_set(0)
106 }
107
108 pub fn full(value: bool, len: usize) -> Self {
110 if value {
111 Self::new_set(len)
112 } else {
113 Self::new_unset(len)
114 }
115 }
116
117 pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, mut f: F) -> Self {
119 let mut buffer = BufferMut::with_capacity(len.div_ceil(64) * 8);
120
121 let chunks = len / 64;
122 let remainder = len % 64;
123 for chunk in 0..chunks {
124 let mut packed = 0;
125 for bit_idx in 0..64 {
126 let i = bit_idx + chunk * 64;
127 packed |= (f(i) as u64) << bit_idx;
128 }
129
130 unsafe { buffer.push_unchecked(packed) }
132 }
133
134 if remainder != 0 {
135 let mut packed = 0;
136 for bit_idx in 0..remainder {
137 let i = bit_idx + chunks * 64;
138 packed |= (f(i) as u64) << bit_idx;
139 }
140
141 unsafe { buffer.push_unchecked(packed) }
143 }
144
145 buffer.truncate(len.div_ceil(8));
146
147 Self::new(buffer.freeze().into_byte_buffer(), len)
148 }
149
150 #[inline]
155 pub fn len(&self) -> usize {
156 self.len
157 }
158
159 #[inline]
161 pub fn is_empty(&self) -> bool {
162 self.len() == 0
163 }
164
165 #[inline(always)]
167 pub fn offset(&self) -> usize {
168 self.offset
169 }
170
171 #[inline(always)]
173 pub fn inner(&self) -> &ByteBuffer {
174 &self.buffer
175 }
176
177 #[inline]
183 pub fn value(&self, index: usize) -> bool {
184 assert!(index < self.len);
185 unsafe { self.value_unchecked(index) }
186 }
187
188 #[inline]
193 pub unsafe fn value_unchecked(&self, index: usize) -> bool {
194 unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
195 }
196
197 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
202 let start = match range.start_bound() {
203 std::ops::Bound::Included(&s) => s,
204 std::ops::Bound::Excluded(&s) => s + 1,
205 std::ops::Bound::Unbounded => 0,
206 };
207 let end = match range.end_bound() {
208 std::ops::Bound::Included(&e) => e + 1,
209 std::ops::Bound::Excluded(&e) => e,
210 std::ops::Bound::Unbounded => self.len,
211 };
212
213 assert!(start <= end);
214 assert!(start <= self.len);
215 assert!(end <= self.len);
216 let len = end - start;
217
218 Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
219 }
220
221 pub fn shrink_offset(self) -> Self {
223 let bit_offset = self.offset % 8;
224 let len = self.len;
225 let buffer = self.into_inner();
226 BitBuffer::new_with_offset(buffer, len, bit_offset)
227 }
228
229 pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
231 UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
232 }
233
234 pub fn chunks(&self) -> BitChunks<'_> {
238 BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
239 }
240
241 pub fn true_count(&self) -> usize {
243 self.unaligned_chunks().count_ones()
244 }
245
246 pub fn false_count(&self) -> usize {
248 self.len - self.true_count()
249 }
250
251 pub fn iter(&self) -> BitIterator<'_> {
253 BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
254 }
255
256 pub fn set_indices(&self) -> BitIndexIterator<'_> {
258 BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
259 }
260
261 pub fn set_slices(&self) -> BitSliceIterator<'_> {
263 BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
264 }
265
266 pub fn sliced(&self) -> Self {
268 if self.offset % 8 == 0 {
269 return Self::new(
270 self.buffer.slice(self.offset / 8..self.len.div_ceil(8)),
271 self.len,
272 );
273 }
274 bitwise_unary_op(self, |a| a)
275 }
276}
277
278impl BitBuffer {
281 pub fn into_inner(self) -> ByteBuffer {
284 let word_start = self.offset / 8;
285 let word_end = (self.offset + self.len).div_ceil(8);
286
287 self.buffer.slice(word_start..word_end)
288 }
289
290 pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
292 match self.buffer.try_into_mut() {
293 Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
294 Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
295 }
296 }
297
298 pub fn into_mut(self) -> BitBufferMut {
303 let offset = self.offset;
304 let len = self.len;
305 let inner = self.into_inner().into_mut();
307 BitBufferMut::from_buffer(inner, offset, len)
308 }
309}
310
311impl From<&[bool]> for BitBuffer {
312 fn from(value: &[bool]) -> Self {
313 BitBufferMut::from(value).freeze()
314 }
315}
316
317impl From<Vec<bool>> for BitBuffer {
318 fn from(value: Vec<bool>) -> Self {
319 BitBufferMut::from(value).freeze()
320 }
321}
322
323impl FromIterator<bool> for BitBuffer {
324 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
325 BitBufferMut::from_iter(iter).freeze()
326 }
327}
328
329impl BitOr for &BitBuffer {
330 type Output = BitBuffer;
331
332 fn bitor(self, rhs: Self) -> Self::Output {
333 bitwise_binary_op(self, rhs, |a, b| a | b)
334 }
335}
336
337impl BitOr<&BitBuffer> for BitBuffer {
338 type Output = BitBuffer;
339
340 fn bitor(self, rhs: &BitBuffer) -> Self::Output {
341 (&self).bitor(rhs)
342 }
343}
344
345impl BitAnd for &BitBuffer {
346 type Output = BitBuffer;
347
348 fn bitand(self, rhs: Self) -> Self::Output {
349 bitwise_binary_op(self, rhs, |a, b| a & b)
350 }
351}
352
353impl BitAnd<BitBuffer> for &BitBuffer {
354 type Output = BitBuffer;
355
356 fn bitand(self, rhs: BitBuffer) -> Self::Output {
357 self.bitand(&rhs)
358 }
359}
360
361impl BitAnd<&BitBuffer> for BitBuffer {
362 type Output = BitBuffer;
363
364 fn bitand(self, rhs: &BitBuffer) -> Self::Output {
365 (&self).bitand(rhs)
366 }
367}
368
369impl Not for &BitBuffer {
370 type Output = BitBuffer;
371
372 fn not(self) -> Self::Output {
373 bitwise_unary_op(self, |a| !a)
374 }
375}
376
377impl Not for BitBuffer {
378 type Output = BitBuffer;
379
380 fn not(self) -> Self::Output {
381 (&self).not()
382 }
383}
384
385impl BitXor for &BitBuffer {
386 type Output = BitBuffer;
387
388 fn bitxor(self, rhs: Self) -> Self::Output {
389 bitwise_binary_op(self, rhs, |a, b| a ^ b)
390 }
391}
392
393impl BitXor<&BitBuffer> for BitBuffer {
394 type Output = BitBuffer;
395
396 fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
397 (&self).bitxor(rhs)
398 }
399}
400
401impl BitBuffer {
402 pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
407 bitwise_binary_op(self, rhs, |a, b| a & !b)
408 }
409
410 #[inline]
420 pub fn iter_bits<F>(&self, mut f: F)
421 where
422 F: FnMut(usize, bool),
423 {
424 let total_bits = self.len;
425 if total_bits == 0 {
426 return;
427 }
428
429 let is_bit_set = |byte: u8, bit_idx: usize| (byte & (1 << bit_idx)) != 0;
430 let bit_offset = self.offset % 8;
431 let mut buffer_ptr = unsafe { self.buffer.as_ptr().add(self.offset / 8) };
432 let mut callback_idx = 0;
433
434 if bit_offset > 0 {
436 let bits_in_first_byte = (8 - bit_offset).min(total_bits);
437 let byte = unsafe { *buffer_ptr };
438
439 for bit_idx in 0..bits_in_first_byte {
440 f(callback_idx, is_bit_set(byte, bit_offset + bit_idx));
441 callback_idx += 1;
442 }
443
444 buffer_ptr = unsafe { buffer_ptr.add(1) };
445 }
446
447 let complete_bytes = (total_bits - callback_idx) / 8;
449 for _ in 0..complete_bytes {
450 let byte = unsafe { *buffer_ptr };
451
452 for bit_idx in 0..8 {
453 f(callback_idx, is_bit_set(byte, bit_idx));
454 callback_idx += 1;
455 }
456 buffer_ptr = unsafe { buffer_ptr.add(1) };
457 }
458
459 let remaining_bits = total_bits - callback_idx;
461 if remaining_bits > 0 {
462 let byte = unsafe { *buffer_ptr };
463
464 for bit_idx in 0..remaining_bits {
465 f(callback_idx, is_bit_set(byte, bit_idx));
466 callback_idx += 1;
467 }
468 }
469 }
470}
471
472impl<'a> IntoIterator for &'a BitBuffer {
473 type Item = bool;
474 type IntoIter = BitIterator<'a>;
475
476 fn into_iter(self) -> Self::IntoIter {
477 self.iter()
478 }
479}
480
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::bit::BitBuffer;
    use crate::{ByteBuffer, buffer};

    #[test]
    fn test_bool() {
        // Every byte is 0b1000_0000: only bit 7 (the last bit, LSB-first) of each byte is set.
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);

        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);

        for word in 0..1024 {
            for bit in 0..8 {
                if bit == 7 {
                    assert!(bools.value(word * 8 + bit));
                } else {
                    assert!(!bools.value(word * 8 + bit));
                }
            }
        }

        let sliced = bools.slice(64..72);

        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);

        for bit in 0..8 {
            if bit == 7 {
                assert!(sliced.value(bit));
            } else {
                assert!(!sliced.value(bit));
            }
        }
    }

    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64);
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32);

        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }

        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }

        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.offset(), 10);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);

        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            assert_eq!(is_set, (offset + idx) % 2 == 0);
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        // Whole bytes alternate between all-set and all-unset, so reading the wrong byte is
        // detected even when the in-byte bit position is right.
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);

        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);

        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });

        assert_eq!(collected.len(), len);

        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index % 2 == 0;

            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }
}