1use std::borrow::Cow;
5use std::fmt::Debug;
6use std::fmt::Formatter;
7use std::marker::PhantomData;
8
9use vortex_error::VortexResult;
10
11use crate::BitBuffer;
12use crate::BitBufferMut;
13
/// A fixed-size window of `NB` bytes (`NB * 8` bits) of a bitmap, together with a cached count
/// of its set bits.
///
/// The bytes are held in a [`Cow`], so a view can either borrow existing storage or own a
/// materialised copy. Bits are addressed least-significant-bit first: bit `i` lives in byte
/// `i / 8` at bit position `i % 8`.
pub struct BitView<'a, const NB: usize> {
    /// The raw bytes of the view, borrowed from the caller or owned by the view.
    bits: Cow<'a, [u8; NB]>,
    /// Number of set bits in `bits`, computed (or supplied) at construction time.
    true_count: usize,
}
31
32impl<const NB: usize> Debug for BitView<'_, NB> {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct(&format!("BitView[{}]", NB * 8))
35 .field("true_count", &self.true_count)
36 .field("bits", &self.as_raw())
37 .finish()
38 }
39}
40
41impl<const NB: usize> BitView<'static, NB> {
42 const ALL_TRUE: [u8; NB] = [u8::MAX; NB];
43 const ALL_FALSE: [u8; NB] = [0; NB];
44
45 pub const fn all_true() -> Self {
47 unsafe { BitView::new_unchecked(&Self::ALL_TRUE, NB * 8) }
48 }
49
50 pub const fn all_false() -> Self {
52 unsafe { BitView::new_unchecked(&Self::ALL_FALSE, 0) }
53 }
54}
55
56impl<'a, const NB: usize> BitView<'a, NB> {
57 pub const N: usize = NB * 8;
59 pub const N_WORDS: usize = NB * 8 / (usize::BITS as usize);
61
62 const _ASSERT_MULTIPLE_OF_8: () = assert!(
63 NB % 8 == 0,
64 "NB must be a multiple of 8 for N to be a multiple of 64"
65 );
66
67 pub fn new(bits: &'a [u8; NB]) -> Self {
69 let ptr = bits.as_ptr().cast::<usize>();
70 let true_count = (0..Self::N_WORDS)
71 .map(|idx| unsafe { ptr.add(idx).read_unaligned().count_ones() as usize })
72 .sum();
73 BitView {
74 bits: Cow::Borrowed(bits),
75 true_count,
76 }
77 }
78
79 pub fn new_owned(bits: [u8; NB]) -> Self {
81 let ptr = bits.as_ptr().cast::<usize>();
82 let true_count = (0..Self::N_WORDS)
83 .map(|idx| unsafe { ptr.add(idx).read_unaligned().count_ones() as usize })
84 .sum();
85 BitView {
86 bits: Cow::Owned(bits),
87 true_count,
88 }
89 }
90
91 pub(crate) const unsafe fn new_unchecked(bits: &'a [u8; NB], true_count: usize) -> Self {
97 BitView {
98 bits: Cow::Borrowed(bits),
99 true_count,
100 }
101 }
102
103 pub fn from_slice(bits: &'a [u8]) -> Self {
109 assert_eq!(bits.len(), NB);
110 let bits_array = unsafe { &*(bits.as_ptr() as *const [u8; NB]) };
111 BitView::new(bits_array)
112 }
113
114 pub fn with_prefix(n_true: usize) -> Self {
117 assert!(n_true <= Self::N);
118
119 let mut bits = [0u8; NB];
121
122 let n_full_words = n_true / (usize::BITS as usize);
124 let remaining_bits = n_true % (usize::BITS as usize);
125
126 let ptr = bits.as_mut_ptr().cast::<usize>();
127
128 for word_idx in 0..n_full_words {
130 unsafe { ptr.add(word_idx).write_unaligned(usize::MAX) };
131 }
132
133 if remaining_bits > 0 {
135 let mask = (1usize << remaining_bits) - 1;
136 unsafe { ptr.add(n_full_words).write_unaligned(mask) };
137 }
138
139 Self {
140 bits: Cow::Owned(bits),
141 true_count: n_true,
142 }
143 }
144
145 pub fn true_count(&self) -> usize {
147 self.true_count
148 }
149
150 pub fn iter_words(&self) -> impl Iterator<Item = usize> + '_ {
156 let ptr = self.bits.as_ptr().cast::<usize>();
157 (0..Self::N_WORDS).map(move |idx| unsafe { ptr.add(idx).read_unaligned() })
159 }
160
161 pub fn iter_ones<F>(&self, mut f: F)
163 where
164 F: FnMut(usize),
165 {
166 match self.true_count {
167 0 => {}
168 n if n == Self::N => (0..Self::N).for_each(&mut f),
169 _ => {
170 let mut bit_idx = 0;
171 for mut raw in self.iter_words() {
172 while raw != 0 {
173 let bit_pos = raw.trailing_zeros();
174 f(bit_idx + bit_pos as usize);
175 raw &= raw - 1; }
177 bit_idx += usize::BITS as usize;
178 }
179 }
180 }
181 }
182
183 pub fn try_iter_ones<F>(&self, mut f: F) -> VortexResult<()>
185 where
186 F: FnMut(usize) -> VortexResult<()>,
187 {
188 match self.true_count {
189 0 => Ok(()),
190 n if n == Self::N => {
191 for i in 0..Self::N {
192 f(i)?;
193 }
194 Ok(())
195 }
196 _ => {
197 let mut bit_idx = 0;
198 for mut raw in self.iter_words() {
199 while raw != 0 {
200 let bit_pos = raw.trailing_zeros();
201 f(bit_idx + bit_pos as usize)?;
202 raw &= raw - 1; }
204 bit_idx += usize::BITS as usize;
205 }
206 Ok(())
207 }
208 }
209 }
210
211 pub fn iter_zeros<F>(&self, mut f: F)
213 where
214 F: FnMut(usize),
215 {
216 match self.true_count {
217 0 => (0..Self::N).for_each(&mut f),
218 n if n == Self::N => {}
219 _ => {
220 let mut bit_idx = 0;
221 for mut raw in self.iter_words() {
222 while raw != usize::MAX {
223 let bit_pos = raw.trailing_ones();
224 f(bit_idx + bit_pos as usize);
225 raw |= 1usize << bit_pos; }
227 bit_idx += usize::BITS as usize;
228 }
229 }
230 }
231 }
232
233 pub fn iter_slices<F>(&self, mut f: F)
240 where
241 F: FnMut(BitSlice),
242 {
243 if self.true_count == 0 {
244 return;
245 }
246
247 let mut abs_bit_offset: usize = 0; let mut slice_start_bit: usize = 0; let mut slice_length: usize = 0; for mut word in self.iter_words() {
252 match word {
253 0 => {
254 if slice_length > 0 {
256 f(BitSlice {
257 start: slice_start_bit,
258 len: slice_length,
259 });
260 slice_length = 0;
261 }
262 }
263 usize::MAX => {
264 if slice_length == 0 {
266 slice_start_bit = abs_bit_offset;
267 }
268 slice_length += usize::BITS as usize;
270 }
271 _ => {
272 while word != 0 {
273 let zeros = word.trailing_zeros() as usize;
275
276 if slice_length > 0 && zeros > 0 {
278 f(BitSlice {
279 start: slice_start_bit,
280 len: slice_length,
281 });
282 slice_length = 0; }
284
285 word >>= zeros;
287
288 if word == 0 {
289 break;
290 }
291
292 let ones = word.trailing_ones() as usize;
294
295 if slice_length == 0 {
297 let current_word_idx = abs_bit_offset + zeros;
299 slice_start_bit = current_word_idx;
300 }
301
302 slice_length += ones;
304
305 word >>= ones;
307 }
308 }
309 }
310
311 abs_bit_offset += usize::BITS as usize;
312 }
313
314 if slice_length > 0 {
315 f(BitSlice {
316 start: slice_start_bit,
317 len: slice_length,
318 });
319 }
320 }
321
322 pub fn as_raw(&self) -> &[u8; NB] {
324 self.bits.as_ref()
325 }
326}
327
/// A run of consecutive set bits, as reported by `BitView::iter_slices`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BitSlice {
    /// Index of the first set bit of the run.
    pub start: usize,
    /// Number of consecutive set bits in the run.
    pub len: usize,
}
337
338impl BitBuffer {
339 pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
351 assert_eq!(
352 self.offset(),
353 0,
354 "BitView iteration requires zero bit offset"
355 );
356 BitViewIterator::new(self.inner().as_ref(), self.len())
357 }
358}
359
360impl BitBufferMut {
361 pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
373 assert_eq!(
374 self.offset(),
375 0,
376 "BitView iteration requires zero bit offset"
377 );
378 BitViewIterator::new(self.inner().as_ref(), self.len())
379 }
380}
381
/// Iterator that yields consecutive `NB`-byte [`BitView`]s over a byte slice.
///
/// Full chunks are borrowed; a trailing partial chunk is copied into a zero-padded owned view.
pub(super) struct BitViewIterator<'a, const NB: usize> {
    /// The bytes being iterated.
    bits: &'a [u8],
    /// Index of the next view to yield.
    view_idx: usize,
    /// Total number of views: `bits.len()` divided by `NB`, rounded up.
    n_views: usize,
    /// Ties the const parameter `NB` to the type without storing an array.
    _phantom: PhantomData<[u8; NB]>,
}
392
393impl<'a, const NB: usize> BitViewIterator<'a, NB> {
394 fn new(bits: &'a [u8], len: usize) -> Self {
396 debug_assert_eq!(len.div_ceil(8), bits.len());
397 let n_views = bits.len().div_ceil(NB);
398 BitViewIterator {
399 bits,
400 view_idx: 0,
401 n_views,
402 _phantom: PhantomData,
403 }
404 }
405}
406
407impl<'a, const NB: usize> Iterator for BitViewIterator<'a, NB> {
408 type Item = BitView<'a, NB>;
409
410 fn next(&mut self) -> Option<Self::Item> {
411 if self.view_idx == self.n_views {
412 return None;
413 }
414
415 let start_byte = self.view_idx * NB;
416 let end_byte = start_byte + NB;
417
418 let bits = if end_byte <= self.bits.len() {
419 BitView::from_slice(&self.bits[start_byte..end_byte])
421 } else {
422 let remaining_bytes = self.bits.len() - start_byte;
424 let mut remaining = [0u8; NB];
425 remaining[..remaining_bytes].copy_from_slice(&self.bits[start_byte..]);
426 BitView::new_owned(remaining)
427 };
428
429 self.view_idx += 1;
430 Some(bits)
431 }
432}
433
/// Unit tests for `BitView` iteration over a 128-byte (1024-bit) view.
#[cfg(test)]
mod tests {
    use super::*;

    /// Bytes per view used throughout these tests.
    const NB: usize = 128;
    /// Bits per view (`NB * 8`).
    const N: usize = NB * 8;

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0; NB];
        let view = BitView::<NB>::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::<NB>::all_true();

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0; NB];
        let view = BitView::<NB>::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::<NB>::all_true();

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0; NB];
        bits[0] = 1; // only bit 0 set
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [u8::MAX; NB];
        bits[0] = u8::MAX ^ 1; // only bit 0 clear
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0; NB];
        bits[0] = 0b1010101; // bits 0, 2, 4 and 6 set
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [u8::MAX; NB];
        bits[0] = !0b1010101; // bits 0, 2, 4 and 6 clear
        let view = BitView::new(&bits);

        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(zeros, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0; NB];
        bits[0] = 0b11111111; // whole first byte set: indices 0..8
        let view = BitView::new(&bits);

        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));

        assert_eq!(ones, vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::<NB>::all_false();

        let mut ones = Vec::new();
        let mut zeros = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        view.iter_zeros(|idx| zeros.push(idx));

        assert_eq!(ones, Vec::<usize>::new());
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        let view = BitView::<NB>::all_true();

        // Collect the indices reported by the view.
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Every index of the view is expected.
        let expected_indices: Vec<usize> = (0..N).collect();

        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        let view = BitView::<NB>::all_false();

        // No set bits are expected...
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // ...and every index should be reported as clear.
        let mut bitview_zeros = Vec::new();
        view.iter_zeros(|idx| bitview_zeros.push(idx));

        assert_eq!(bitview_ones, Vec::<usize>::new());
        assert_eq!(bitview_zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        // Indices spanning the start, word boundaries (63/64) and the last bit of the view.
        let indices = vec![0, 10, 20, 63, 64, 100, 500, 1023];

        // Build the bytes by hand using the LSB-first layout: bit `i` is bit `i % 8` of byte
        // `i / 8`.
        let mut bits = [0; NB];
        for idx in &indices {
            let word_idx = idx / 8;
            let bit_idx = idx % 8;
            bits[word_idx] |= 1u8 << bit_idx;
        }
        let view = BitView::new(&bits);

        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        assert_eq!(bitview_ones, indices);
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        // Half-open `[start, end)` ranges of set bits.
        let slices = vec![(0, 10), (100, 110), (500, 510)];

        // Set every bit of each range using the LSB-first layout.
        let mut bits = [0; NB];
        for (start, end) in &slices {
            for idx in *start..*end {
                let word_idx = idx / 8;
                let bit_idx = idx % 8;
                bits[word_idx] |= 1u8 << bit_idx;
            }
        }
        let view = BitView::new(&bits);

        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));

        // Expected indices are the concatenation of all ranges, in order.
        let mut expected_indices = Vec::new();
        for (start, end) in &slices {
            expected_indices.extend(*start..*end);
        }

        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(view.true_count(), expected_indices.len());
    }

    #[test]
    fn test_with_prefix() {
        assert_eq!(BitView::<NB>::with_prefix(0).true_count(), 0);

        for i in 1..N {
            // A prefix of `i` set bits forms exactly one contiguous run.
            let view = BitView::<NB>::with_prefix(i);

            let mut slices = vec![];
            view.iter_slices(|slice| slices.push(slice));

            assert_eq!(slices.len(), 1);
        }
    }
}