vortex_array/pipeline/bits/
view.rs1use std::fmt::{Debug, Formatter};
5
6use bitvec::prelude::*;
7use vortex_error::{VortexError, VortexResult, vortex_err};
8
9use crate::pipeline::{N, N_WORDS};
10
11#[derive(Clone, Copy)]
18pub struct BitView<'a> {
19 bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
20 true_count: usize,
24}
25
26impl Debug for BitView<'_> {
27 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("BitView")
29 .field("true_count", &self.true_count)
30 .field("bits", &self.as_raw())
31 .finish()
32 }
33}
34
35impl BitView<'static> {
36 pub fn all_true() -> Self {
37 static ALL_TRUE: [usize; N_WORDS] = [usize::MAX; N_WORDS];
38 unsafe {
39 BitView::new_unchecked(
40 std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
41 &ALL_TRUE,
42 ),
43 N,
44 )
45 }
46 }
47
48 pub fn all_false() -> Self {
49 static ALL_FALSE: [usize; N_WORDS] = [0; N_WORDS];
50 unsafe {
51 BitView::new_unchecked(
52 std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
53 &ALL_FALSE,
54 ),
55 0,
56 )
57 }
58 }
59}
60
61impl<'a> BitView<'a> {
62 pub fn new(bits: &[usize; N_WORDS]) -> Self {
63 let true_count = bits.iter().map(|&word| word.count_ones() as usize).sum();
64 let bits: &BitArray<[usize; N_WORDS], Lsb0> = unsafe {
65 std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(bits)
66 };
67 BitView { bits, true_count }
68 }
69
70 pub(crate) unsafe fn new_unchecked(
71 bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
72 true_count: usize,
73 ) -> Self {
74 BitView { bits, true_count }
75 }
76
77 pub fn true_count(&self) -> usize {
79 self.true_count
80 }
81
82 pub fn iter_ones<F>(&self, mut f: F)
84 where
85 F: FnMut(usize),
86 {
87 match self.true_count {
88 0 => {}
89 N => (0..N).for_each(&mut f),
90 _ => {
91 let mut bit_idx = 0;
92 for mut raw in self.bits.into_inner() {
93 while raw != 0 {
94 let bit_pos = raw.trailing_zeros();
95 f(bit_idx + bit_pos as usize);
96 raw &= raw - 1; }
98 bit_idx += usize::BITS as usize;
99 }
100 }
101 }
102 }
103
104 pub fn try_iter_ones<F>(&self, mut f: F) -> VortexResult<()>
106 where
107 F: FnMut(usize) -> VortexResult<()>,
108 {
109 match self.true_count {
110 0 => Ok(()),
111 N => {
112 for i in 0..N {
113 f(i)?;
114 }
115 Ok(())
116 }
117 _ => {
118 let mut bit_idx = 0;
119 for mut raw in self.bits.into_inner() {
120 while raw != 0 {
121 let bit_pos = raw.trailing_zeros();
122 f(bit_idx + bit_pos as usize)?;
123 raw &= raw - 1; }
125 bit_idx += usize::BITS as usize;
126 }
127 Ok(())
128 }
129 }
130 }
131
132 pub fn iter_zeros<F>(&self, mut f: F)
134 where
135 F: FnMut(usize),
136 {
137 match self.true_count {
138 0 => (0..N).for_each(&mut f),
139 N => {}
140 _ => {
141 let mut bit_idx = 0;
142 for mut raw in self.bits.into_inner() {
143 while raw != usize::MAX {
144 let bit_pos = raw.trailing_ones();
145 f(bit_idx + bit_pos as usize);
146 raw |= 1usize << bit_pos; }
148 bit_idx += usize::BITS as usize;
149 }
150 }
151 }
152 }
153
154 pub fn iter_slices<F>(&self, mut f: F)
159 where
160 F: FnMut((usize, usize)),
161 {
162 match self.true_count {
163 0 => {}
164 N => f((0, N)),
165 _ => {
166 let mut bit_idx = 0;
167 for mut raw in self.bits.into_inner() {
168 let mut offset = 0;
169 while raw != 0 {
170 let zeros = raw.leading_zeros();
172 offset += zeros;
173 raw <<= zeros;
174
175 if offset >= 64 {
176 break;
177 }
178
179 let ones = raw.leading_ones();
181 if ones > 0 {
182 f((bit_idx + offset as usize, ones as usize));
183 offset += ones;
184 raw <<= ones;
185 }
186 }
187 bit_idx += usize::BITS as usize; }
189 }
190 }
191 }
192
193 pub fn as_raw(&self) -> &[usize; N_WORDS] {
194 let raw = self.bits.as_raw_slice();
196 unsafe { &*(raw.as_ptr() as *const [usize; N_WORDS]) }
197 }
198}
199
200impl<'a> From<&'a [usize; N_WORDS]> for BitView<'a> {
201 fn from(value: &'a [usize; N_WORDS]) -> Self {
202 Self::new(value)
203 }
204}
205
206impl<'a> From<&'a BitArray<[usize; N_WORDS], Lsb0>> for BitView<'a> {
207 fn from(bits: &'a BitArray<[usize; N_WORDS], Lsb0>) -> Self {
208 BitView::new(unsafe {
209 std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
210 })
211 }
212}
213
214impl<'a> TryFrom<&'a BitSlice<usize, Lsb0>> for BitView<'a> {
215 type Error = VortexError;
216
217 fn try_from(value: &'a BitSlice<usize, Lsb0>) -> Result<Self, Self::Error> {
218 let bits: &BitArray<[usize; N_WORDS], Lsb0> = value
219 .try_into()
220 .map_err(|e| vortex_err!("Failed to convert BitSlice to BitArray: {}", e))?;
221 Ok(BitView::new(unsafe {
222 std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
223 }))
224 }
225}
226
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::bits::BitVector;

    /// Collects every index reported by `iter_ones`, in call order.
    fn ones_of(view: &BitView<'_>) -> Vec<usize> {
        let mut found = Vec::new();
        view.iter_ones(|idx| found.push(idx));
        found
    }

    /// Collects every index reported by `iter_zeros`, in call order.
    fn zeros_of(view: &BitView<'_>) -> Vec<usize> {
        let mut found = Vec::new();
        view.iter_zeros(|idx| found.push(idx));
        found
    }

    #[test]
    fn test_iter_ones_empty() {
        let words = [0usize; N_WORDS];
        let view = BitView::new(&words);

        assert_eq!(ones_of(&view), Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::all_true();
        let ones = ones_of(&view);

        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let words = [0usize; N_WORDS];
        let view = BitView::new(&words);
        let zeros = zeros_of(&view);

        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::all_true();

        assert_eq!(zeros_of(&view), Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut words = [0usize; N_WORDS];
        words[0] = 1;
        let view = BitView::new(&words);

        assert_eq!(ones_of(&view), vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut words = [usize::MAX; N_WORDS];
        words[0] = usize::MAX ^ 1;
        let view = BitView::new(&words);

        assert_eq!(zeros_of(&view), vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut words = [0usize; N_WORDS];
        words[0] = 0b1010101;
        let view = BitView::new(&words);

        assert_eq!(ones_of(&view), vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut words = [usize::MAX; N_WORDS];
        words[0] = !0b1010101;
        let view = BitView::new(&words);

        assert_eq!(zeros_of(&view), vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_iter_ones_across_words() {
        let mut words = [0usize; N_WORDS];
        words[0] = 1 << 63;
        words[1] = 1;
        words[2] = 1 << 31;
        let view = BitView::new(&words);

        assert_eq!(ones_of(&view), vec![63, 64, 159]);
        assert_eq!(view.true_count(), 3);
    }

    #[test]
    fn test_iter_zeros_across_words() {
        let mut words = [usize::MAX; N_WORDS];
        words[0] = !(1 << 63);
        words[1] = !1;
        words[2] = !(1 << 31);
        let view = BitView::new(&words);

        assert_eq!(zeros_of(&view), vec![63, 64, 159]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut words = [0usize; N_WORDS];
        words[0] = 0b11111111;
        let view = BitView::new(&words);

        // The low eight bits of the first word map to indices 0 through 7.
        assert_eq!(ones_of(&view), vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_iter_ones_and_zeros_complement() {
        let mut words = [0usize; N_WORDS];
        words[0] = 0xAAAAAAAAAAAAAAAA;
        let view = BitView::new(&words);

        let ones = ones_of(&view);
        let zeros = zeros_of(&view);

        // Together the two sets must cover every index.
        let mut combined: Vec<usize> = ones.iter().chain(zeros.iter()).copied().collect();
        combined.sort_unstable();
        assert_eq!(combined, (0..N).collect::<Vec<_>>());

        // No index may show up in both sets.
        assert!(ones.iter().all(|idx| !zeros.contains(idx)));
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::all_false();

        assert_eq!(ones_of(&view), Vec::<usize>::new());
        assert_eq!(zeros_of(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        let view = BitView::all_true();
        let expected: Vec<usize> = (0..N).collect();

        assert_eq!(ones_of(&view), expected);
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        let view = BitView::all_false();

        assert_eq!(ones_of(&view), Vec::<usize>::new());
        assert_eq!(zeros_of(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        let indices: Vec<usize> = vec![0, 10, 20, 63, 64, 100, 500, 1023];

        // Set exactly the listed bits, one at a time.
        let mut words = [0usize; N_WORDS];
        for &idx in &indices {
            words[idx / 64] |= 1usize << (idx % 64);
        }
        let view = BitView::new(&words);

        assert_eq!(ones_of(&view), indices);
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        let ranges: Vec<(usize, usize)> = vec![(0, 10), (100, 110), (500, 510)];

        // Set every bit inside each half-open `start..end` range.
        let mut words = [0usize; N_WORDS];
        for &(start, end) in &ranges {
            for idx in start..end {
                words[idx / 64] |= 1usize << (idx % 64);
            }
        }
        let view = BitView::new(&words);

        let expected: Vec<usize> = ranges
            .iter()
            .flat_map(|&(start, end)| start..end)
            .collect();

        assert_eq!(ones_of(&view), expected);
        assert_eq!(view.true_count(), expected.len());
    }

    #[test]
    fn test_mask_and_bitview_iter_match() {
        let mut words = [0usize; N_WORDS];
        words[0] = 0xAAAAAAAAAAAAAAAA;
        words[1] = 0xFF00FF00FF00FF00;
        let view = BitView::new(&words);

        let set_bits = ones_of(&view);
        let mask = Mask::from_indices(N, set_bits.clone());

        mask.iter_bools(|bools| {
            let mask_bools: Vec<bool> = bools.collect();

            // Every position of the mask must agree with the indices the view reported.
            for i in 0..N {
                let expected = set_bits.contains(&i);
                assert_eq!(mask_bools[i], expected, "Mismatch at index {}", i);
            }
        });
    }

    #[test]
    fn test_mask_and_bitview_all_true() {
        let mask = Mask::AllTrue(5);
        let vector = BitVector::true_until(5);
        let view = vector.as_view();

        let mut from_view = Vec::new();
        view.iter_ones(|idx| from_view.push(idx));

        let from_mask = mask.iter_bools(|bools| {
            bools
                .enumerate()
                .filter_map(|(idx, is_set)| is_set.then_some(idx))
                .collect::<Vec<_>>()
        });

        assert_eq!(from_view, from_mask);
    }

    #[test]
    fn test_bitview_zeros_complement_mask() {
        let mut words = [0usize; N_WORDS];
        words[0] = 0b11110000111100001111000011110000;
        let view = BitView::new(&words);

        let ones_mask = Mask::from_indices(N, ones_of(&view));
        let zeros_mask = Mask::from_indices(N, zeros_of(&view));

        let ones_bools: Vec<bool> = ones_mask.iter_bools(|bools| bools.collect());
        let zeros_bools: Vec<bool> = zeros_mask.iter_bools(|bools| bools.collect());

        for i in 0..N {
            assert_ne!(
                ones_bools[i], zeros_bools[i],
                "Index {} should be in exactly one set",
                i
            );
        }
    }
}