vortex_array/pipeline/bits/
view.rs1use std::fmt::{Debug, Formatter};
5
6use bitvec::prelude::*;
7use vortex_error::{VortexError, vortex_err};
8
9use crate::pipeline::{N, N_WORDS};
10
11#[derive(Clone, Copy)]
18pub struct BitView<'a> {
19 bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
20 true_count: usize,
21}
22
23impl Debug for BitView<'_> {
24 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25 f.debug_struct("BitView")
26 .field("true_count", &self.true_count)
27 .field("bits", &self.as_raw())
28 .finish()
29 }
30}
31
32impl BitView<'static> {
33 pub fn all_true() -> Self {
34 static ALL_TRUE: [usize; N_WORDS] = [usize::MAX; N_WORDS];
35 unsafe {
36 BitView::new_unchecked(
37 std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
38 &ALL_TRUE,
39 ),
40 N,
41 )
42 }
43 }
44
45 pub fn all_false() -> Self {
46 static ALL_FALSE: [usize; N_WORDS] = [0; N_WORDS];
47 unsafe {
48 BitView::new_unchecked(
49 std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(
50 &ALL_FALSE,
51 ),
52 0,
53 )
54 }
55 }
56}
57
58impl<'a> BitView<'a> {
59 pub fn new(bits: &[usize; N_WORDS]) -> Self {
60 let true_count = bits.iter().map(|&word| word.count_ones() as usize).sum();
61 let bits: &BitArray<[usize; N_WORDS], Lsb0> = unsafe {
62 std::mem::transmute::<&[usize; N_WORDS], &BitArray<[usize; N_WORDS], Lsb0>>(bits)
63 };
64 BitView { bits, true_count }
65 }
66
67 pub(crate) unsafe fn new_unchecked(
68 bits: &'a BitArray<[usize; N_WORDS], Lsb0>,
69 true_count: usize,
70 ) -> Self {
71 BitView { bits, true_count }
72 }
73
74 pub fn true_count(&self) -> usize {
76 self.true_count
77 }
78
79 pub fn iter_ones<F>(&self, mut f: F)
81 where
82 F: FnMut(usize),
83 {
84 match self.true_count {
85 0 => {}
86 N => (0..N).for_each(&mut f),
87 _ => {
88 let mut bit_idx = 0;
89 for mut raw in self.bits.into_inner() {
90 while raw != 0 {
91 let bit_pos = raw.trailing_zeros();
92 f(bit_idx + bit_pos as usize);
93 raw &= raw - 1; }
95 bit_idx += usize::BITS as usize;
96 }
97 }
98 }
99 }
100
101 pub fn iter_zeros<F>(&self, mut f: F)
103 where
104 F: FnMut(usize),
105 {
106 match self.true_count {
107 0 => (0..N).for_each(&mut f),
108 N => {}
109 _ => {
110 let mut bit_idx = 0;
111 for mut raw in self.bits.into_inner() {
112 while raw != usize::MAX {
113 let bit_pos = raw.trailing_ones();
114 f(bit_idx + bit_pos as usize);
115 raw |= 1usize << bit_pos; }
117 bit_idx += usize::BITS as usize;
118 }
119 }
120 }
121 }
122
123 pub fn iter_slices<F>(&self, mut f: F)
128 where
129 F: FnMut((usize, usize)),
130 {
131 match self.true_count {
132 0 => {}
133 N => f((0, N)),
134 _ => {
135 let mut bit_idx = 0;
136 for mut raw in self.bits.into_inner() {
137 let mut offset = 0;
138 while raw != 0 {
139 let zeros = raw.leading_zeros();
141 offset += zeros;
142 raw <<= zeros;
143
144 if offset >= 64 {
145 break;
146 }
147
148 let ones = raw.leading_ones();
150 if ones > 0 {
151 f((bit_idx + offset as usize, ones as usize));
152 offset += ones;
153 raw <<= ones;
154 }
155 }
156 bit_idx += usize::BITS as usize; }
158 }
159 }
160 }
161
162 pub fn as_raw(&self) -> &[usize; N_WORDS] {
163 let raw = self.bits.as_raw_slice();
165 unsafe { &*(raw.as_ptr() as *const [usize; N_WORDS]) }
166 }
167}
168
169impl<'a> From<&'a [usize; N_WORDS]> for BitView<'a> {
170 fn from(value: &'a [usize; N_WORDS]) -> Self {
171 Self::new(value)
172 }
173}
174
175impl<'a> From<&'a BitArray<[usize; N_WORDS], Lsb0>> for BitView<'a> {
176 fn from(bits: &'a BitArray<[usize; N_WORDS], Lsb0>) -> Self {
177 BitView::new(unsafe {
178 std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
179 })
180 }
181}
182
183impl<'a> TryFrom<&'a BitSlice<usize, Lsb0>> for BitView<'a> {
184 type Error = VortexError;
185
186 fn try_from(value: &'a BitSlice<usize, Lsb0>) -> Result<Self, Self::Error> {
187 let bits: &BitArray<[usize; N_WORDS], Lsb0> = value
188 .try_into()
189 .map_err(|e| vortex_err!("Failed to convert BitSlice to BitArray: {}", e))?;
190 Ok(BitView::new(unsafe {
191 std::mem::transmute::<&BitArray<[usize; N_WORDS]>, &[usize; N_WORDS]>(bits)
192 }))
193 }
194}
195
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::bits::BitVector;

    /// Number of bits held by one backing word.
    const WORD_BITS: usize = usize::BITS as usize;

    // NOTE(review): several fixtures below (`1 << 63`, `0xAAAA_AAAA_AAAA_AAAA`, ...) are
    // written for a 64-bit `usize`.

    /// Collects every index reported by `BitView::iter_ones`.
    fn ones_of(view: &BitView<'_>) -> Vec<usize> {
        let mut out = Vec::new();
        view.iter_ones(|idx| out.push(idx));
        out
    }

    /// Collects every index reported by `BitView::iter_zeros`.
    fn zeros_of(view: &BitView<'_>) -> Vec<usize> {
        let mut out = Vec::new();
        view.iter_zeros(|idx| out.push(idx));
        out
    }

    /// Collects the positions at which `mask` yields `true`.
    fn mask_ones(mask: &Mask) -> Vec<usize> {
        mask.iter_bools(|iter| {
            iter.enumerate()
                .filter(|(_, b)| *b)
                .map(|(i, _)| i)
                .collect::<Vec<_>>()
        })
    }

    /// Sets the bit with logical index `idx` (least-significant-bit-first within a word).
    fn set_bit(bits: &mut [usize; N_WORDS], idx: usize) {
        bits[idx / WORD_BITS] |= 1usize << (idx % WORD_BITS);
    }

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0usize; N_WORDS];
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::all_true();

        let ones = ones_of(&view);

        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0usize; N_WORDS];
        let view = BitView::new(&bits);

        let zeros = zeros_of(&view);

        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::all_true();

        assert_eq!(zeros_of(&view), Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 1;
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = usize::MAX ^ 1;
        let view = BitView::new(&bits);

        assert_eq!(zeros_of(&view), vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b1010101;
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = !0b1010101;
        let view = BitView::new(&bits);

        assert_eq!(zeros_of(&view), vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_iter_ones_across_words() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 1 << 63;
        bits[1] = 1;
        bits[2] = 1 << 31;
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![63, 64, 159]);
        assert_eq!(view.true_count(), 3);
    }

    #[test]
    fn test_iter_zeros_across_words() {
        let mut bits = [usize::MAX; N_WORDS];
        bits[0] = !(1 << 63);
        bits[1] = !1;
        bits[2] = !(1 << 31);
        let view = BitView::new(&bits);

        assert_eq!(zeros_of(&view), vec![63, 64, 159]);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b11111111;
        let view = BitView::new(&bits);

        assert_eq!(ones_of(&view), vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_iter_ones_and_zeros_complement() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0xAAAAAAAAAAAAAAAA;
        let view = BitView::new(&bits);

        let ones = ones_of(&view);
        let zeros = zeros_of(&view);

        // Together the two sets must cover every index exactly once.
        let mut all_indices = ones.clone();
        all_indices.extend(&zeros);
        all_indices.sort_unstable();

        assert_eq!(all_indices, (0..N).collect::<Vec<_>>());

        for one_idx in &ones {
            assert!(!zeros.contains(one_idx));
        }
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::all_false();

        assert_eq!(ones_of(&view), Vec::<usize>::new());
        assert_eq!(zeros_of(&view), (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        let mask = Mask::new_true(N);
        let view = BitView::all_true();

        let bitview_ones = ones_of(&view);
        let expected_indices: Vec<usize> = (0..N).collect();

        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(mask_ones(&mask), bitview_ones);
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        let mask = Mask::new_false(N);
        let view = BitView::all_false();

        let bitview_ones = ones_of(&view);
        let bitview_zeros = zeros_of(&view);

        assert_eq!(bitview_ones, Vec::<usize>::new());
        assert_eq!(bitview_zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(mask_ones(&mask), bitview_ones);
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        let indices = vec![0, 10, 20, 63, 64, 100, 500, 1023];
        let mask = Mask::from_indices(N, indices.clone());

        let mut bits = [0usize; N_WORDS];
        for &idx in &indices {
            set_bit(&mut bits, idx);
        }
        let view = BitView::new(&bits);

        let bitview_ones = ones_of(&view);

        assert_eq!(bitview_ones, indices);
        assert_eq!(mask_ones(&mask), bitview_ones);
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        let slices = vec![(0, 10), (100, 110), (500, 510)];
        let mask = Mask::from_slices(N, slices.clone());

        let mut bits = [0usize; N_WORDS];
        for &(start, end) in &slices {
            for idx in start..end {
                set_bit(&mut bits, idx);
            }
        }
        let view = BitView::new(&bits);

        let bitview_ones = ones_of(&view);

        let mut expected_indices = Vec::new();
        for &(start, end) in &slices {
            expected_indices.extend(start..end);
        }

        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(mask_ones(&mask), bitview_ones);
        assert_eq!(view.true_count(), expected_indices.len());
    }

    #[test]
    fn test_mask_and_bitview_iter_match() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0xAAAAAAAAAAAAAAAA;
        bits[1] = 0xFF00FF00FF00FF00;
        let view = BitView::new(&bits);

        let bitview_ones = ones_of(&view);

        let mask = Mask::from_indices(N, bitview_ones.clone());
        let mask_bools = mask.iter_bools(|iter| iter.collect::<Vec<bool>>());

        assert_eq!(mask_bools.len(), N);
        for (i, &actual) in mask_bools.iter().enumerate() {
            let expected = bitview_ones.contains(&i);
            assert_eq!(actual, expected, "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_mask_and_bitview_all_true() {
        let mask = Mask::AllTrue(5);

        let vector = BitVector::true_until(5);
        let view = vector.as_view();

        assert_eq!(ones_of(&view), mask_ones(&mask));
    }

    #[test]
    fn test_bitview_zeros_complement_mask() {
        let mut bits = [0usize; N_WORDS];
        bits[0] = 0b11110000111100001111000011110000;

        let view = BitView::new(&bits);

        let ones_mask = Mask::from_indices(N, ones_of(&view));
        let zeros_mask = Mask::from_indices(N, zeros_of(&view));

        let ones_bools = ones_mask.iter_bools(|iter| iter.collect::<Vec<bool>>());
        let zeros_bools = zeros_mask.iter_bools(|iter| iter.collect::<Vec<bool>>());

        assert_eq!(ones_bools.len(), N);
        assert_eq!(zeros_bools.len(), N);
        for (i, (one, zero)) in ones_bools.iter().zip(&zeros_bools).enumerate() {
            assert_ne!(one, zero, "Index {} should be in exactly one set", i);
        }
    }
}