1#![deny(missing_docs)]
6mod bitops;
7mod eq;
8mod intersect_by_rank;
9mod iter_bools;
10
11use std::cmp::Ordering;
12use std::fmt::{Debug, Formatter};
13use std::sync::{Arc, OnceLock};
14
15use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
16use itertools::Itertools;
17use vortex_error::{VortexResult, vortex_err};
18
/// A tri-state answer used by [`Mask`] accessors: every position, no position, or an explicit
/// description `T` of the selected positions.
pub enum AllOr<T> {
    /// Every position is selected; no explicit `T` is materialized.
    All,
    /// No position is selected; no explicit `T` is materialized.
    None,
    /// Only some positions are selected, described by the contained value.
    Some(T),
}
28
29impl<T> AllOr<T> {
30 pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
32 where
33 F: FnOnce() -> T,
34 G: FnOnce() -> T,
35 {
36 match self {
37 Self::Some(v) => v,
38 AllOr::All => all_true(),
39 AllOr::None => all_false(),
40 }
41 }
42}
43
44impl<T> AllOr<&T> {
45 pub fn cloned(self) -> AllOr<T>
47 where
48 T: Clone,
49 {
50 match self {
51 Self::All => AllOr::All,
52 Self::None => AllOr::None,
53 Self::Some(v) => AllOr::Some(v.clone()),
54 }
55 }
56}
57
58impl<T> Debug for AllOr<T>
59where
60 T: Debug,
61{
62 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
63 match self {
64 Self::All => f.write_str("All"),
65 Self::None => f.write_str("None"),
66 Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
67 }
68 }
69}
70
71impl<T> PartialEq for AllOr<T>
72where
73 T: PartialEq,
74{
75 fn eq(&self, other: &Self) -> bool {
76 match (self, other) {
77 (Self::All, Self::All) => true,
78 (Self::None, Self::None) => true,
79 (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
80 _ => false,
81 }
82 }
83}
84
85impl<T> Eq for AllOr<T> where T: Eq {}
86
/// A boolean selection over a fixed number of positions.
///
/// Fully-selected and fully-unselected masks are stored as just a length. Any other selection
/// is a shared, buffer-backed [`MaskValues`]. Most constructors (e.g. `from_buffer`) collapse
/// all-set and all-unset inputs into the first two variants.
#[derive(Clone, Debug)]
pub enum Mask {
    /// Every one of the given number of positions is selected.
    AllTrue(usize),
    /// None of the given number of positions is selected.
    AllFalse(usize),
    /// An explicit selection backed by a boolean buffer, shared through an [`Arc`].
    Values(Arc<MaskValues>),
}
100
/// The buffer-backed representation of a [`Mask`].
///
/// The boolean buffer is stored together with its precomputed true count and density. The
/// set-bit indices and the runs of consecutive set bits are computed lazily on first request
/// and then cached.
#[derive(Debug)]
pub struct MaskValues {
    buffer: BooleanBuffer,

    // Sorted positions of the set bits. Filled lazily by `indices()`, or eagerly by
    // `Mask::from_indices`.
    indices: OnceLock<Vec<usize>>,
    // Half-open `(start, end)` runs of set bits. Filled lazily by `slices()`, or eagerly by
    // `Mask::from_slices`.
    slices: OnceLock<Vec<(usize, usize)>>,

    // Number of set bits in `buffer`.
    true_count: usize,
    // `true_count / buffer.len()`, precomputed so `threshold_iter` can branch on it cheaply.
    density: f64,
}
116
117impl MaskValues {
118 #[inline]
120 pub fn len(&self) -> usize {
121 self.buffer.len()
122 }
123
124 #[inline]
126 pub fn is_empty(&self) -> bool {
127 self.buffer.is_empty()
128 }
129
130 pub fn true_count(&self) -> usize {
132 self.true_count
133 }
134
135 pub fn boolean_buffer(&self) -> &BooleanBuffer {
137 &self.buffer
138 }
139
140 pub fn value(&self, index: usize) -> bool {
142 self.buffer.value(index)
143 }
144
145 pub fn indices(&self) -> &[usize] {
147 self.indices.get_or_init(|| {
148 if self.true_count == 0 {
149 return vec![];
150 }
151
152 if self.true_count == self.len() {
153 return (0..self.len()).collect();
154 }
155
156 if let Some(slices) = self.slices.get() {
157 let mut indices = Vec::with_capacity(self.true_count);
158 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
159 debug_assert!(indices.is_sorted());
160 assert_eq!(indices.len(), self.true_count);
161 return indices;
162 }
163
164 let mut indices = Vec::with_capacity(self.true_count);
165 indices.extend(self.buffer.set_indices());
166 debug_assert!(indices.is_sorted());
167 assert_eq!(indices.len(), self.true_count);
168 indices
169 })
170 }
171
172 #[allow(clippy::cast_possible_truncation)]
174 pub fn slices(&self) -> &[(usize, usize)] {
175 self.slices.get_or_init(|| {
176 if self.true_count == self.len() {
177 return vec![(0, self.len())];
178 }
179
180 self.buffer.set_slices().collect()
181 })
182 }
183
184 pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
186 if self.density >= threshold {
187 MaskIter::Slices(self.slices())
188 } else {
189 MaskIter::Indices(self.indices())
190 }
191 }
192}
193
194impl Mask {
195 pub fn new_true(length: usize) -> Self {
197 Self::AllTrue(length)
198 }
199
200 pub fn new_false(length: usize) -> Self {
202 Self::AllFalse(length)
203 }
204
205 pub fn from_buffer(buffer: BooleanBuffer) -> Self {
207 let len = buffer.len();
208 let true_count = buffer.count_set_bits();
209
210 if true_count == 0 {
211 return Self::AllFalse(len);
212 }
213 if true_count == len {
214 return Self::AllTrue(len);
215 }
216
217 Self::Values(Arc::new(MaskValues {
218 buffer,
219 indices: Default::default(),
220 slices: Default::default(),
221 true_count,
222 density: true_count as f64 / len as f64,
223 }))
224 }
225
226 pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
229 let true_count = indices.len();
230 assert!(indices.is_sorted(), "Mask indices must be sorted");
231 assert!(
232 indices.last().is_none_or(|&idx| idx < len),
233 "Mask indices must be in bounds (len={len})"
234 );
235
236 if true_count == 0 {
237 return Self::AllFalse(len);
238 }
239 if true_count == len {
240 return Self::AllTrue(len);
241 }
242
243 let mut buf = BooleanBufferBuilder::new(len);
244 buf.append_n(len, false);
246 indices.iter().for_each(|idx| buf.set_bit(*idx, true));
247 debug_assert_eq!(buf.len(), len);
248
249 Self::Values(Arc::new(MaskValues {
250 buffer: buf.finish(),
251 indices: OnceLock::from(indices),
252 slices: Default::default(),
253 true_count,
254 density: true_count as f64 / len as f64,
255 }))
256 }
257
258 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
260 let mut buf = BooleanBufferBuilder::new(len);
261 buf.append_n(len, true);
262
263 let mut false_count: usize = 0;
264 indices.into_iter().for_each(|idx| {
265 buf.set_bit(idx, false);
266 false_count += 1;
267 });
268 debug_assert_eq!(buf.len(), len);
269 let true_count = len - false_count;
270
271 Self::Values(Arc::new(MaskValues {
272 buffer: buf.finish(),
273 indices: Default::default(),
274 slices: Default::default(),
275 true_count,
276 density: true_count as f64 / len as f64,
277 }))
278 }
279
280 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
283 Self::check_slices(len, &vec);
284 Self::from_slices_unchecked(len, vec)
285 }
286
287 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
288 #[cfg(debug_assertions)]
289 Self::check_slices(len, &slices);
290
291 let true_count = slices.iter().map(|(b, e)| e - b).sum();
292 if true_count == 0 {
293 return Self::AllFalse(len);
294 }
295 if true_count == len {
296 return Self::AllTrue(len);
297 }
298
299 let mut buf = BooleanBufferBuilder::new(len);
300 for (start, end) in slices.iter().copied() {
301 buf.append_n(start - buf.len(), false);
302 buf.append_n(end - start, true);
303 }
304 if let Some((_, end)) = slices.last() {
305 buf.append_n(len - end, false);
306 }
307 debug_assert_eq!(buf.len(), len);
308
309 Self::Values(Arc::new(MaskValues {
310 buffer: buf.finish(),
311 indices: Default::default(),
312 slices: OnceLock::from(slices),
313 true_count,
314 density: true_count as f64 / len as f64,
315 }))
316 }
317
318 #[inline(always)]
319 fn check_slices(len: usize, vec: &[(usize, usize)]) {
320 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
321 for (first, second) in vec.iter().tuple_windows() {
322 assert!(
323 first.0 < second.0,
324 "Slices must be sorted, got {first:?} and {second:?}"
325 );
326 assert!(
327 first.1 <= second.0,
328 "Slices must be non-overlapping, got {first:?} and {second:?}"
329 );
330 }
331 }
332
333 pub fn from_intersection_indices(
335 len: usize,
336 lhs: impl Iterator<Item = usize>,
337 rhs: impl Iterator<Item = usize>,
338 ) -> Self {
339 let mut intersection = Vec::with_capacity(len);
340 let mut lhs = lhs.peekable();
341 let mut rhs = rhs.peekable();
342 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
343 match l.cmp(&r) {
344 Ordering::Less => {
345 lhs.next();
346 }
347 Ordering::Greater => {
348 rhs.next();
349 }
350 Ordering::Equal => {
351 intersection.push(l);
352 lhs.next();
353 rhs.next();
354 }
355 }
356 }
357 Self::from_indices(len, intersection)
358 }
359
360 #[inline]
362 pub fn len(&self) -> usize {
363 match self {
364 Self::AllTrue(len) => *len,
365 Self::AllFalse(len) => *len,
366 Self::Values(values) => values.len(),
367 }
368 }
369
370 #[inline]
372 pub fn is_empty(&self) -> bool {
373 match self {
374 Self::AllTrue(len) => *len == 0,
375 Self::AllFalse(len) => *len == 0,
376 Self::Values(values) => values.is_empty(),
377 }
378 }
379
380 #[inline]
382 pub fn true_count(&self) -> usize {
383 match &self {
384 Self::AllTrue(len) => *len,
385 Self::AllFalse(_) => 0,
386 Self::Values(values) => values.true_count,
387 }
388 }
389
390 #[inline]
392 pub fn false_count(&self) -> usize {
393 match &self {
394 Self::AllTrue(_) => 0,
395 Self::AllFalse(len) => *len,
396 Self::Values(values) => values.buffer.len() - values.true_count,
397 }
398 }
399
400 #[inline]
402 pub fn all_true(&self) -> bool {
403 match &self {
404 Self::AllTrue(_) => true,
405 Self::AllFalse(0) => true,
406 Self::AllFalse(_) => false,
407 Self::Values(values) => values.buffer.len() == values.true_count,
408 }
409 }
410
411 #[inline]
413 pub fn all_false(&self) -> bool {
414 self.true_count() == 0
415 }
416
417 #[inline]
419 pub fn density(&self) -> f64 {
420 match &self {
421 Self::AllTrue(_) => 1.0,
422 Self::AllFalse(_) => 0.0,
423 Self::Values(values) => values.density,
424 }
425 }
426
427 pub fn value(&self, idx: usize) -> bool {
433 match self {
434 Mask::AllTrue(_) => true,
435 Mask::AllFalse(_) => false,
436 Mask::Values(values) => values.buffer.value(idx),
437 }
438 }
439
440 pub fn first(&self) -> Option<usize> {
442 match &self {
443 Self::AllTrue(len) => (*len > 0).then_some(0),
444 Self::AllFalse(_) => None,
445 Self::Values(values) => {
446 if let Some(indices) = values.indices.get() {
447 return indices.first().copied();
448 }
449 if let Some(slices) = values.slices.get() {
450 return slices.first().map(|(start, _)| *start);
451 }
452 values.buffer.set_indices().next()
453 }
454 }
455 }
456
457 pub fn slice(&self, offset: usize, length: usize) -> Self {
459 assert!(offset + length <= self.len());
460 match &self {
461 Self::AllTrue(_) => Self::new_true(length),
462 Self::AllFalse(_) => Self::new_false(length),
463 Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)),
464 }
465 }
466
467 pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
469 match &self {
470 Self::AllTrue(_) => AllOr::All,
471 Self::AllFalse(_) => AllOr::None,
472 Self::Values(values) => AllOr::Some(&values.buffer),
473 }
474 }
475
476 pub fn to_boolean_buffer(&self) -> BooleanBuffer {
479 match self {
480 Self::AllTrue(l) => BooleanBuffer::new_set(*l),
481 Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
482 Self::Values(values) => values.boolean_buffer().clone(),
483 }
484 }
485
486 pub fn to_null_buffer(&self) -> Option<NullBuffer> {
488 match self {
489 Mask::AllTrue(_) => None,
490 Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
491 Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
492 }
493 }
494
495 pub fn indices(&self) -> AllOr<&[usize]> {
497 match &self {
498 Self::AllTrue(_) => AllOr::All,
499 Self::AllFalse(_) => AllOr::None,
500 Self::Values(values) => AllOr::Some(values.indices()),
501 }
502 }
503
504 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
506 match &self {
507 Self::AllTrue(_) => AllOr::All,
508 Self::AllFalse(_) => AllOr::None,
509 Self::Values(values) => AllOr::Some(values.slices()),
510 }
511 }
512
513 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
515 match &self {
516 Self::AllTrue(_) => AllOr::All,
517 Self::AllFalse(_) => AllOr::None,
518 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
519 }
520 }
521
522 pub fn values(&self) -> Option<&MaskValues> {
524 match self {
525 Self::Values(values) => Some(values),
526 _ => None,
527 }
528 }
529
530 pub fn valid_counts_for_indices(&self, indices: &[usize]) -> VortexResult<Vec<usize>> {
535 Ok(match self {
536 Self::AllTrue(_) => indices.to_vec(),
537 Self::AllFalse(_) => vec![0; indices.len()],
538 Self::Values(values) => {
539 let mut bool_iter = values.boolean_buffer().iter();
540 let mut valid_counts = Vec::with_capacity(indices.len());
541 let mut valid_count = 0;
542 let mut idx = 0;
543 for &next_idx in indices {
544 while idx < next_idx {
545 idx += 1;
546 valid_count += bool_iter
547 .next()
548 .ok_or_else(|| vortex_err!("Row indices exceed array length"))?
549 as usize;
550 }
551 valid_counts.push(valid_count);
552 }
553
554 valid_counts
555 }
556 })
557 }
558
559 pub fn limit(self, limit: usize) -> Self {
561 if self.len() <= limit {
562 return self;
563 }
564
565 match self {
566 Mask::AllTrue(len) => {
567 Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
568 }
569 Mask::AllFalse(_) => self,
570 Mask::Values(ref mask_values) => {
571 if limit >= mask_values.true_count() {
572 return self;
573 }
574
575 let existing_buffer = mask_values.boolean_buffer();
576
577 let mut new_buffer_builder = BooleanBufferBuilder::new(mask_values.len());
578 new_buffer_builder.append_n(mask_values.len(), false);
579
580 for index in existing_buffer.set_indices().take(limit) {
581 new_buffer_builder.set_bit(index, true);
582 }
583
584 Self::from(new_buffer_builder.finish())
585 }
586 }
587 }
588}
589
/// A borrowed view of the selected positions of a [`MaskValues`], in one of two encodings.
pub enum MaskIter<'a> {
    /// Selected positions in ascending order.
    Indices(&'a [usize]),
    /// Half-open `(start, end)` runs of consecutive selected positions.
    Slices(&'a [(usize, usize)]),
}
597
598impl From<BooleanBuffer> for Mask {
599 fn from(value: BooleanBuffer) -> Self {
600 Self::from_buffer(value)
601 }
602}
603
604impl FromIterator<bool> for Mask {
605 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
606 Self::from_buffer(BooleanBuffer::from_iter(iter))
607 }
608}
609
610impl FromIterator<Mask> for Mask {
611 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
612 let masks = iter
613 .into_iter()
614 .filter(|m| !m.is_empty())
615 .collect::<Vec<_>>();
616 let total_length = masks.iter().map(|v| v.len()).sum();
617
618 if masks.iter().all(|v| v.all_true()) {
620 return Self::AllTrue(total_length);
621 }
622 if masks.iter().all(|v| v.all_false()) {
624 return Self::AllFalse(total_length);
625 }
626
627 let mut buffer = BooleanBufferBuilder::new(total_length);
629 for mask in masks {
630 match mask {
631 Mask::AllTrue(count) => buffer.append_n(count, true),
632 Mask::AllFalse(count) => buffer.append_n(count, false),
633 Mask::Values(values) => {
634 buffer.append_buffer(values.boolean_buffer());
635 }
636 };
637 }
638 Self::from_buffer(buffer.finish())
639 }
640}
641
#[cfg(test)]
mod test {
    use super::*;

    /// Shorthand for building an Arrow boolean buffer from literal bits.
    fn bools<const N: usize>(bits: [bool; N]) -> BooleanBuffer {
        BooleanBuffer::from_iter(bits)
    }

    #[test]
    fn mask_all_true() {
        let everything = Mask::new_true(5);
        assert_eq!(everything.len(), 5);
        assert_eq!(everything.true_count(), 5);
        assert_eq!(everything.density(), 1.0);
        assert_eq!(everything.indices(), AllOr::All);
        assert_eq!(everything.slices(), AllOr::All);
        assert_eq!(everything.boolean_buffer(), AllOr::All);
    }

    #[test]
    fn mask_all_false() {
        let nothing = Mask::new_false(5);
        assert_eq!(nothing.len(), 5);
        assert_eq!(nothing.true_count(), 0);
        assert_eq!(nothing.density(), 0.0);
        assert_eq!(nothing.indices(), AllOr::None);
        assert_eq!(nothing.slices(), AllOr::None);
        assert_eq!(nothing.boolean_buffer(), AllOr::None);
    }

    #[test]
    fn mask_from() {
        let expected_bits = bools([true, false, true, true, false]);
        let candidates = [
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_slices(5, vec![(0, 1), (2, 4)]),
            Mask::from_buffer(expected_bits.clone()),
        ];

        for candidate in &candidates {
            assert_eq!(candidate.len(), 5);
            assert_eq!(candidate.true_count(), 3);
            assert_eq!(candidate.density(), 0.6);
            assert_eq!(candidate.indices(), AllOr::Some(&[0, 2, 3][..]));
            assert_eq!(candidate.slices(), AllOr::Some(&[(0, 1), (2, 4)][..]));
            assert_eq!(candidate.boolean_buffer(), AllOr::Some(&expected_bits));
        }
    }

    #[test]
    fn limit_all_true_mask() {
        let full = Mask::new_true(4);

        let truncated = full.clone().limit(2);
        assert_eq!(full.len(), truncated.len());
        assert_eq!(truncated.true_count(), 2);
        assert_eq!(
            truncated.boolean_buffer(),
            AllOr::Some(&bools([true, true, false, false]))
        );

        let untouched = full.clone().limit(5);
        assert_eq!(untouched, full);
    }

    #[test]
    fn limit_mask_values() {
        let source = Mask::from_iter([true, true, false, true, false, true]);

        let first_two = source.clone().limit(2);
        assert_eq!(
            first_two.boolean_buffer(),
            AllOr::Some(&bools([true, true, false, false, false, false]))
        );
        assert_eq!(first_two.true_count(), 2);

        let first_three = source.limit(3);
        assert_eq!(
            first_three.boolean_buffer(),
            AllOr::Some(&bools([true, true, false, true, false, false]))
        );
        assert_eq!(first_three.true_count(), 3);

        let source = Mask::from_iter([true, true, false, true, false, true]);
        let unlimited = source.clone().limit(100);
        assert_eq!(source, unlimited);
    }

    #[test]
    fn length_zero_masks() {
        let empties = [
            Mask::new_false(0),
            Mask::new_true(0),
            Mask::from_buffer(BooleanBuffer::new_set(0)),
            Mask::from_buffer(BooleanBuffer::new_unset(0)),
        ];

        // An empty mask is vacuously both all-true and all-false, however it was built.
        for empty in &empties {
            assert!(empty.all_false());
            assert!(empty.all_true());
        }
    }
}
749}