1use std::borrow::Cow;
5use std::fmt::{Debug, Formatter};
6use std::marker::PhantomData;
7
8use vortex_error::VortexResult;
9
10use crate::{BitBuffer, BitBufferMut};
11
/// A fixed-size view over `NB` bytes (`NB * 8` bits) that caches how many bits are set.
///
/// The bytes are either borrowed from a caller-provided array or owned by the view.
pub struct BitView<'a, const NB: usize> {
    /// Cached number of set bits in `bits`, computed by (or supplied to) the constructors.
    true_count: usize,
    /// The viewed bytes, borrowed or owned.
    bits: Cow<'a, [u8; NB]>,
}
29
30impl<const NB: usize> Debug for BitView<'_, NB> {
31 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct(&format!("BitView[{}]", NB * 8))
33 .field("true_count", &self.true_count)
34 .field("bits", &self.as_raw())
35 .finish()
36 }
37}
38
39impl<const NB: usize> BitView<'static, NB> {
40 const ALL_TRUE: [u8; NB] = [u8::MAX; NB];
41 const ALL_FALSE: [u8; NB] = [0; NB];
42
43 pub const fn all_true() -> Self {
45 unsafe { BitView::new_unchecked(&Self::ALL_TRUE, NB * 8) }
46 }
47
48 pub const fn all_false() -> Self {
50 unsafe { BitView::new_unchecked(&Self::ALL_FALSE, 0) }
51 }
52}
53
54impl<'a, const NB: usize> BitView<'a, NB> {
55 const N: usize = NB * 8;
56 const N_WORDS: usize = NB * 8 / (usize::BITS as usize);
57
58 const _ASSERT_MULTIPLE_OF_8: () = assert!(
59 NB % 8 == 0,
60 "NB must be a multiple of 8 for N to be a multiple of 64"
61 );
62
63 pub fn new(bits: &'a [u8; NB]) -> Self {
65 let ptr = bits.as_ptr().cast::<usize>();
66 let true_count = (0..Self::N_WORDS)
67 .map(|idx| unsafe { ptr.add(idx).read_unaligned().count_ones() as usize })
68 .sum();
69 BitView {
70 bits: Cow::Borrowed(bits),
71 true_count,
72 }
73 }
74
75 pub fn new_owned(bits: [u8; NB]) -> Self {
77 let ptr = bits.as_ptr().cast::<usize>();
78 let true_count = (0..Self::N_WORDS)
79 .map(|idx| unsafe { ptr.add(idx).read_unaligned().count_ones() as usize })
80 .sum();
81 BitView {
82 bits: Cow::Owned(bits),
83 true_count,
84 }
85 }
86
87 pub(crate) const unsafe fn new_unchecked(bits: &'a [u8; NB], true_count: usize) -> Self {
93 BitView {
94 bits: Cow::Borrowed(bits),
95 true_count,
96 }
97 }
98
99 pub fn from_slice(bits: &'a [u8]) -> Self {
105 assert_eq!(bits.len(), NB);
106 let bits_array = unsafe { &*(bits.as_ptr() as *const [u8; NB]) };
107 BitView::new(bits_array)
108 }
109
110 pub fn with_prefix(n_true: usize) -> Self {
113 assert!(n_true <= Self::N);
114
115 let mut bits = [0u8; NB];
117
118 let n_full_words = n_true / (usize::BITS as usize);
120 let remaining_bits = n_true % (usize::BITS as usize);
121
122 let ptr = bits.as_mut_ptr().cast::<usize>();
123
124 for word_idx in 0..n_full_words {
126 unsafe { ptr.add(word_idx).write_unaligned(usize::MAX) };
127 }
128
129 if remaining_bits > 0 {
131 let mask = (1usize << remaining_bits) - 1;
132 unsafe { ptr.add(n_full_words).write_unaligned(mask) };
133 }
134
135 Self {
136 bits: Cow::Owned(bits),
137 true_count: n_true,
138 }
139 }
140
141 pub fn true_count(&self) -> usize {
143 self.true_count
144 }
145
146 fn iter_words(&self) -> impl Iterator<Item = usize> + '_ {
152 let ptr = self.bits.as_ptr().cast::<usize>();
153 (0..Self::N_WORDS).map(move |idx| unsafe { ptr.add(idx).read_unaligned() })
155 }
156
157 pub fn iter_ones<F>(&self, mut f: F)
159 where
160 F: FnMut(usize),
161 {
162 match self.true_count {
163 0 => {}
164 n if n == Self::N => (0..Self::N).for_each(&mut f),
165 _ => {
166 let mut bit_idx = 0;
167 for mut raw in self.iter_words() {
168 while raw != 0 {
169 let bit_pos = raw.trailing_zeros();
170 f(bit_idx + bit_pos as usize);
171 raw &= raw - 1; }
173 bit_idx += usize::BITS as usize;
174 }
175 }
176 }
177 }
178
179 pub fn try_iter_ones<F>(&self, mut f: F) -> VortexResult<()>
181 where
182 F: FnMut(usize) -> VortexResult<()>,
183 {
184 match self.true_count {
185 0 => Ok(()),
186 n if n == Self::N => {
187 for i in 0..Self::N {
188 f(i)?;
189 }
190 Ok(())
191 }
192 _ => {
193 let mut bit_idx = 0;
194 for mut raw in self.iter_words() {
195 while raw != 0 {
196 let bit_pos = raw.trailing_zeros();
197 f(bit_idx + bit_pos as usize)?;
198 raw &= raw - 1; }
200 bit_idx += usize::BITS as usize;
201 }
202 Ok(())
203 }
204 }
205 }
206
207 pub fn iter_zeros<F>(&self, mut f: F)
209 where
210 F: FnMut(usize),
211 {
212 match self.true_count {
213 0 => (0..Self::N).for_each(&mut f),
214 n if n == Self::N => {}
215 _ => {
216 let mut bit_idx = 0;
217 for mut raw in self.iter_words() {
218 while raw != usize::MAX {
219 let bit_pos = raw.trailing_ones();
220 f(bit_idx + bit_pos as usize);
221 raw |= 1usize << bit_pos; }
223 bit_idx += usize::BITS as usize;
224 }
225 }
226 }
227 }
228
229 pub fn iter_slices<F>(&self, mut f: F)
236 where
237 F: FnMut(BitSlice),
238 {
239 if self.true_count == 0 {
240 return;
241 }
242
243 let mut abs_bit_offset: usize = 0; let mut slice_start_bit: usize = 0; let mut slice_length: usize = 0; for mut word in self.iter_words() {
248 match word {
249 0 => {
250 if slice_length > 0 {
252 f(BitSlice {
253 start: slice_start_bit,
254 len: slice_length,
255 });
256 slice_length = 0;
257 }
258 }
259 usize::MAX => {
260 if slice_length == 0 {
262 slice_start_bit = abs_bit_offset;
263 }
264 slice_length += usize::BITS as usize;
266 }
267 _ => {
268 while word != 0 {
269 let zeros = word.trailing_zeros() as usize;
271
272 if slice_length > 0 && zeros > 0 {
274 f(BitSlice {
275 start: slice_start_bit,
276 len: slice_length,
277 });
278 slice_length = 0; }
280
281 word >>= zeros;
283
284 if word == 0 {
285 break;
286 }
287
288 let ones = word.trailing_ones() as usize;
290
291 if slice_length == 0 {
293 let current_word_idx = abs_bit_offset + zeros;
295 slice_start_bit = current_word_idx;
296 }
297
298 slice_length += ones;
300
301 word >>= ones;
303 }
304 }
305 }
306
307 abs_bit_offset += usize::BITS as usize;
308 }
309
310 if slice_length > 0 {
311 f(BitSlice {
312 start: slice_start_bit,
313 len: slice_length,
314 });
315 }
316 }
317
318 pub fn as_raw(&self) -> &[u8; NB] {
320 self.bits.as_ref()
321 }
322}
323
/// A run of consecutive set bits, as reported by [`BitView::iter_slices`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BitSlice {
    /// Index of the first set bit in the run.
    pub start: usize,
    /// Number of consecutive set bits in the run.
    pub len: usize,
}
333
334impl BitBuffer {
335 pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
347 assert_eq!(
348 self.offset(),
349 0,
350 "BitView iteration requires zero bit offset"
351 );
352 BitViewIterator::new(self.inner().as_ref())
353 }
354}
355
356impl BitBufferMut {
357 pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
369 assert_eq!(
370 self.offset(),
371 0,
372 "BitView iteration requires zero bit offset"
373 );
374 BitViewIterator::new(self.inner().as_ref())
375 }
376}
377
378pub(super) struct BitViewIterator<'a, const NB: usize> {
380 bits: &'a [u8],
381 view_idx: usize,
383 n_views: usize,
385 _phantom: PhantomData<[u8; NB]>,
387}
388
389impl<'a, const NB: usize> BitViewIterator<'a, NB> {
390 pub fn new(bits: &'a [u8]) -> Self {
392 let n_views = bits.len().div_ceil(NB);
393 BitViewIterator {
394 bits,
395 view_idx: 0,
396 n_views,
397 _phantom: PhantomData,
398 }
399 }
400}
401
402impl<'a, const NB: usize> Iterator for BitViewIterator<'a, NB> {
403 type Item = BitView<'a, NB>;
404
405 fn next(&mut self) -> Option<Self::Item> {
406 if self.view_idx == self.n_views {
407 return None;
408 }
409
410 let start_byte = self.view_idx * NB;
411 let end_byte = start_byte + NB;
412
413 let bits = if end_byte <= self.bits.len() {
414 BitView::from_slice(&self.bits[start_byte..end_byte])
416 } else {
417 let remaining_bytes = self.bits.len() - start_byte;
419 let mut remaining = [0u8; NB];
420 remaining[..remaining_bytes].copy_from_slice(&self.bits[start_byte..]);
421 BitView::new_owned(remaining)
422 };
423
424 self.view_idx += 1;
425 Some(bits)
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Bytes per view used throughout these tests.
    const NB: usize = 128;
    /// Bits per view (`NB * 8`).
    const N: usize = NB * 8;

    /// Sets bit `idx` of `bits` using LSB-first order within each byte.
    fn set_bit(bits: &mut [u8; NB], idx: usize) {
        bits[idx / 8] |= 1u8 << (idx % 8);
    }

    /// Collects the indices reported by `iter_ones`.
    fn ones_of(view: &BitView<'_, NB>) -> Vec<usize> {
        let mut out = Vec::new();
        view.iter_ones(|idx| out.push(idx));
        out
    }

    /// Collects the indices reported by `iter_zeros`.
    fn zeros_of(view: &BitView<'_, NB>) -> Vec<usize> {
        let mut out = Vec::new();
        view.iter_zeros(|idx| out.push(idx));
        out
    }

    /// Collects `(start, len)` pairs reported by `iter_slices`.
    fn slices_of(view: &BitView<'_, NB>) -> Vec<(usize, usize)> {
        let mut out = Vec::new();
        view.iter_slices(|slice| out.push((slice.start, slice.len)));
        out
    }

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0; NB];
        let view = BitView::<NB>::new(&bits);

        assert_eq!(ones_of(&view), Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::<NB>::all_true();

        let ones = ones_of(&view);
        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0; NB];
        let view = BitView::<NB>::new(&bits);

        let zeros = zeros_of(&view);
        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::<NB>::all_true();

        assert_eq!(zeros_of(&view), Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0; NB];
        // Lowest bit of the first byte.
        bits[0] = 1;
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [u8::MAX; NB];
        // Clear only the lowest bit of the first byte.
        bits[0] = u8::MAX ^ 1;
        let view = BitView::new(&bits);

        assert_eq!(zeros_of(&view), vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0; NB];
        bits[0] = 0b1010101;
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [u8::MAX; NB];
        bits[0] = !0b1010101;
        let view = BitView::new(&bits);

        assert_eq!(zeros_of(&view), vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0; NB];
        // All eight bits of the first byte map to indices 0..8.
        bits[0] = 0b11111111;
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::<NB>::all_false();

        assert_eq!(ones_of(&view), Vec::<usize>::new());
        assert_eq!(zeros_of(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        let view = BitView::<NB>::all_true();

        let expected_indices: Vec<usize> = (0..N).collect();
        assert_eq!(ones_of(&view), expected_indices);
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        let view = BitView::<NB>::all_false();

        assert_eq!(ones_of(&view), Vec::<usize>::new());
        assert_eq!(zeros_of(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        let indices = vec![0, 10, 20, 63, 64, 100, 500, 1023];

        let mut bits = [0; NB];
        for &idx in &indices {
            set_bit(&mut bits, idx);
        }
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), indices);
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        let slices = vec![(0, 10), (100, 110), (500, 510)];

        let mut bits = [0; NB];
        for &(start, end) in &slices {
            for idx in start..end {
                set_bit(&mut bits, idx);
            }
        }
        let view = BitView::new(&bits);

        let expected_indices: Vec<usize> = slices
            .iter()
            .flat_map(|&(start, end)| start..end)
            .collect();

        assert_eq!(ones_of(&view), expected_indices);
        assert_eq!(view.true_count(), expected_indices.len());
    }

    #[test]
    fn test_with_prefix() {
        let empty = BitView::<NB>::with_prefix(0);
        assert_eq!(empty.true_count(), 0);
        assert!(slices_of(&empty).is_empty());

        // Include `N` itself: a fully set view must still be a single slice.
        for i in 1..=N {
            let view = BitView::<NB>::with_prefix(i);
            assert_eq!(view.true_count(), i);
            assert_eq!(slices_of(&view), vec![(0, i)]);
        }
    }

    #[test]
    fn test_new_owned_matches_new() {
        let mut bits = [0u8; NB];
        for idx in [0, 63, 64, 511, 1023] {
            set_bit(&mut bits, idx);
        }
        let owned = BitView::<NB>::new_owned(bits);
        let borrowed = BitView::new(&bits);

        assert_eq!(owned.true_count(), borrowed.true_count());
        assert_eq!(owned.as_raw(), borrowed.as_raw());
        assert_eq!(ones_of(&owned), ones_of(&borrowed));
    }

    #[test]
    #[should_panic]
    fn test_from_slice_rejects_wrong_length() {
        let short = [0u8; NB - 1];
        let _ = BitView::<NB>::from_slice(&short);
    }

    #[test]
    fn test_try_iter_ones_matches_iter_ones() {
        let mut bits = [0u8; NB];
        for idx in [3, 64, 200, 1023] {
            set_bit(&mut bits, idx);
        }
        let view = BitView::new(&bits);

        let mut seen = Vec::new();
        let result = view.try_iter_ones(|idx| {
            seen.push(idx);
            Ok(())
        });

        assert!(result.is_ok());
        assert_eq!(seen, ones_of(&view));
    }

    #[test]
    fn test_bit_view_iterator_pads_partial_tail() {
        // One full view plus a single trailing byte.
        let bytes = vec![u8::MAX; NB + 1];
        let views: Vec<_> = BitViewIterator::<NB>::new(&bytes).collect();

        assert_eq!(views.len(), 2);
        assert_eq!(views[0].true_count(), N);
        assert_eq!(views[1].true_count(), 8);
        assert_eq!(ones_of(&views[1]), (0..8).collect::<Vec<_>>());
    }
}