1#![deny(missing_docs)]
3mod bitops;
4mod eq;
5mod intersect_by_rank;
6mod iter_bools;
7
8use std::cmp::Ordering;
9use std::fmt::{Debug, Formatter};
10use std::sync::{Arc, OnceLock};
11
12use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
13use itertools::Itertools;
14use vortex_error::{VortexResult, vortex_err};
15
/// A three-state result: every element qualifies, no element qualifies, or an explicit
/// description `T` of the elements that do.
///
/// Accessors on [`Mask`] return this so callers can skip materializing data for the
/// all-true and all-false cases.
pub enum AllOr<T> {
    /// Every element is selected.
    All,
    /// No element is selected.
    None,
    /// Only the elements described by the wrapped value are selected.
    Some(T),
}
25
26impl<T> AllOr<T> {
27 pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
29 where
30 F: FnOnce() -> T,
31 G: FnOnce() -> T,
32 {
33 match self {
34 Self::Some(v) => v,
35 AllOr::All => all_true(),
36 AllOr::None => all_false(),
37 }
38 }
39}
40
41impl<T> AllOr<&T> {
42 pub fn cloned(self) -> AllOr<T>
44 where
45 T: Clone,
46 {
47 match self {
48 Self::All => AllOr::All,
49 Self::None => AllOr::None,
50 Self::Some(v) => AllOr::Some(v.clone()),
51 }
52 }
53}
54
55impl<T> Debug for AllOr<T>
56where
57 T: Debug,
58{
59 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60 match self {
61 Self::All => f.write_str("All"),
62 Self::None => f.write_str("None"),
63 Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
64 }
65 }
66}
67
68impl<T> PartialEq for AllOr<T>
69where
70 T: PartialEq,
71{
72 fn eq(&self, other: &Self) -> bool {
73 match (self, other) {
74 (Self::All, Self::All) => true,
75 (Self::None, Self::None) => true,
76 (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
77 _ => false,
78 }
79 }
80}
81
82impl<T> Eq for AllOr<T> where T: Eq {}
83
84#[derive(Clone, Debug)]
89pub enum Mask {
90 AllTrue(usize),
92 AllFalse(usize),
94 Values(Arc<MaskValues>),
96}
97
98#[derive(Debug)]
100pub struct MaskValues {
101 buffer: BooleanBuffer,
102
103 indices: OnceLock<Vec<usize>>,
106 slices: OnceLock<Vec<(usize, usize)>>,
107
108 true_count: usize,
110 density: f64,
112}
113
114impl MaskValues {
115 #[inline]
117 #[allow(clippy::len_without_is_empty)]
118 pub fn len(&self) -> usize {
119 self.buffer.len()
120 }
121
122 pub fn true_count(&self) -> usize {
124 self.true_count
125 }
126
127 pub fn boolean_buffer(&self) -> &BooleanBuffer {
129 &self.buffer
130 }
131
132 pub fn value(&self, index: usize) -> bool {
134 self.buffer.value(index)
135 }
136
137 pub fn indices(&self) -> &[usize] {
139 self.indices.get_or_init(|| {
140 if self.true_count == 0 {
141 return vec![];
142 }
143
144 if self.true_count == self.len() {
145 return (0..self.len()).collect();
146 }
147
148 if let Some(slices) = self.slices.get() {
149 let mut indices = Vec::with_capacity(self.true_count);
150 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
151 debug_assert!(indices.is_sorted());
152 assert_eq!(indices.len(), self.true_count);
153 return indices;
154 }
155
156 let mut indices = Vec::with_capacity(self.true_count);
157 indices.extend(self.buffer.set_indices());
158 debug_assert!(indices.is_sorted());
159 assert_eq!(indices.len(), self.true_count);
160 indices
161 })
162 }
163
164 #[allow(clippy::cast_possible_truncation)]
166 pub fn slices(&self) -> &[(usize, usize)] {
167 self.slices.get_or_init(|| {
168 if self.true_count == self.len() {
169 return vec![(0, self.len())];
170 }
171
172 self.buffer.set_slices().collect()
173 })
174 }
175
176 pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
178 if self.density >= threshold {
179 MaskIter::Slices(self.slices())
180 } else {
181 MaskIter::Indices(self.indices())
182 }
183 }
184}
185
186impl Mask {
187 pub fn new_true(length: usize) -> Self {
189 Self::AllTrue(length)
190 }
191
192 pub fn new_false(length: usize) -> Self {
194 Self::AllFalse(length)
195 }
196
197 pub fn from_buffer(buffer: BooleanBuffer) -> Self {
199 let len = buffer.len();
200 let true_count = buffer.count_set_bits();
201
202 if true_count == 0 {
203 return Self::AllFalse(len);
204 }
205 if true_count == len {
206 return Self::AllTrue(len);
207 }
208
209 Self::Values(Arc::new(MaskValues {
210 buffer,
211 indices: Default::default(),
212 slices: Default::default(),
213 true_count,
214 density: true_count as f64 / len as f64,
215 }))
216 }
217
218 pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
221 let true_count = indices.len();
222 assert!(indices.is_sorted(), "Mask indices must be sorted");
223 assert!(
224 indices.last().is_none_or(|&idx| idx < len),
225 "Mask indices must be in bounds (len={len})"
226 );
227
228 if true_count == 0 {
229 return Self::AllFalse(len);
230 }
231 if true_count == len {
232 return Self::AllTrue(len);
233 }
234
235 let mut buf = BooleanBufferBuilder::new(len);
236 buf.append_n(len, false);
238 indices.iter().for_each(|idx| buf.set_bit(*idx, true));
239 debug_assert_eq!(buf.len(), len);
240
241 Self::Values(Arc::new(MaskValues {
242 buffer: buf.finish(),
243 indices: OnceLock::from(indices),
244 slices: Default::default(),
245 true_count,
246 density: true_count as f64 / len as f64,
247 }))
248 }
249
250 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
252 let mut buf = BooleanBufferBuilder::new(len);
253 buf.append_n(len, true);
254
255 let mut false_count: usize = 0;
256 indices.into_iter().for_each(|idx| {
257 buf.set_bit(idx, false);
258 false_count += 1;
259 });
260 debug_assert_eq!(buf.len(), len);
261 let true_count = len - false_count;
262
263 Self::Values(Arc::new(MaskValues {
264 buffer: buf.finish(),
265 indices: Default::default(),
266 slices: Default::default(),
267 true_count,
268 density: true_count as f64 / len as f64,
269 }))
270 }
271
272 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
275 Self::check_slices(len, &vec);
276 Self::from_slices_unchecked(len, vec)
277 }
278
279 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
280 #[cfg(debug_assertions)]
281 Self::check_slices(len, &slices);
282
283 let true_count = slices.iter().map(|(b, e)| e - b).sum();
284 if true_count == 0 {
285 return Self::AllFalse(len);
286 }
287 if true_count == len {
288 return Self::AllTrue(len);
289 }
290
291 let mut buf = BooleanBufferBuilder::new(len);
292 for (start, end) in slices.iter().copied() {
293 buf.append_n(start - buf.len(), false);
294 buf.append_n(end - start, true);
295 }
296 if let Some((_, end)) = slices.last() {
297 buf.append_n(len - end, false);
298 }
299 debug_assert_eq!(buf.len(), len);
300
301 Self::Values(Arc::new(MaskValues {
302 buffer: buf.finish(),
303 indices: Default::default(),
304 slices: OnceLock::from(slices),
305 true_count,
306 density: true_count as f64 / len as f64,
307 }))
308 }
309
310 #[inline(always)]
311 fn check_slices(len: usize, vec: &[(usize, usize)]) {
312 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
313 for (first, second) in vec.iter().tuple_windows() {
314 assert!(
315 first.0 < second.0,
316 "Slices must be sorted, got {first:?} and {second:?}"
317 );
318 assert!(
319 first.1 <= second.0,
320 "Slices must be non-overlapping, got {first:?} and {second:?}"
321 );
322 }
323 }
324
325 pub fn from_intersection_indices(
327 len: usize,
328 lhs: impl Iterator<Item = usize>,
329 rhs: impl Iterator<Item = usize>,
330 ) -> Self {
331 let mut intersection = Vec::with_capacity(len);
332 let mut lhs = lhs.peekable();
333 let mut rhs = rhs.peekable();
334 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
335 match l.cmp(&r) {
336 Ordering::Less => {
337 lhs.next();
338 }
339 Ordering::Greater => {
340 rhs.next();
341 }
342 Ordering::Equal => {
343 intersection.push(l);
344 lhs.next();
345 rhs.next();
346 }
347 }
348 }
349 Self::from_indices(len, intersection)
350 }
351
352 #[inline]
354 #[allow(clippy::len_without_is_empty)]
356 pub fn len(&self) -> usize {
357 match &self {
358 Self::AllTrue(len) => *len,
359 Self::AllFalse(len) => *len,
360 Self::Values(values) => values.len(),
361 }
362 }
363
364 #[inline]
366 pub fn true_count(&self) -> usize {
367 match &self {
368 Self::AllTrue(len) => *len,
369 Self::AllFalse(_) => 0,
370 Self::Values(values) => values.true_count,
371 }
372 }
373
374 #[inline]
376 pub fn false_count(&self) -> usize {
377 match &self {
378 Self::AllTrue(_) => 0,
379 Self::AllFalse(len) => *len,
380 Self::Values(values) => values.buffer.len() - values.true_count,
381 }
382 }
383
384 #[inline]
386 pub fn all_true(&self) -> bool {
387 match &self {
388 Self::AllTrue(_) => true,
389 Self::AllFalse(_) => false,
390 Self::Values(values) => values.buffer.len() == values.true_count,
391 }
392 }
393
394 #[inline]
396 pub fn all_false(&self) -> bool {
397 self.true_count() == 0
398 }
399
400 #[inline]
402 pub fn density(&self) -> f64 {
403 match &self {
404 Self::AllTrue(_) => 1.0,
405 Self::AllFalse(_) => 0.0,
406 Self::Values(values) => values.density,
407 }
408 }
409
410 pub fn value(&self, idx: usize) -> bool {
416 match self {
417 Mask::AllTrue(_) => true,
418 Mask::AllFalse(_) => false,
419 Mask::Values(values) => values.buffer.value(idx),
420 }
421 }
422
423 pub fn first(&self) -> Option<usize> {
425 match &self {
426 Self::AllTrue(len) => (*len > 0).then_some(0),
427 Self::AllFalse(_) => None,
428 Self::Values(values) => {
429 if let Some(indices) = values.indices.get() {
430 return indices.first().copied();
431 }
432 if let Some(slices) = values.slices.get() {
433 return slices.first().map(|(start, _)| *start);
434 }
435 values.buffer.set_indices().next()
436 }
437 }
438 }
439
440 pub fn slice(&self, offset: usize, length: usize) -> Self {
442 assert!(offset + length <= self.len());
443 match &self {
444 Self::AllTrue(_) => Self::new_true(length),
445 Self::AllFalse(_) => Self::new_false(length),
446 Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)),
447 }
448 }
449
450 pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
452 match &self {
453 Self::AllTrue(_) => AllOr::All,
454 Self::AllFalse(_) => AllOr::None,
455 Self::Values(values) => AllOr::Some(&values.buffer),
456 }
457 }
458
459 pub fn to_boolean_buffer(&self) -> BooleanBuffer {
462 match self {
463 Self::AllTrue(l) => BooleanBuffer::new_set(*l),
464 Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
465 Self::Values(values) => values.boolean_buffer().clone(),
466 }
467 }
468
469 pub fn to_null_buffer(&self) -> Option<NullBuffer> {
471 match self {
472 Mask::AllTrue(_) => None,
473 Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
474 Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
475 }
476 }
477
478 pub fn indices(&self) -> AllOr<&[usize]> {
480 match &self {
481 Self::AllTrue(_) => AllOr::All,
482 Self::AllFalse(_) => AllOr::None,
483 Self::Values(values) => AllOr::Some(values.indices()),
484 }
485 }
486
487 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
489 match &self {
490 Self::AllTrue(_) => AllOr::All,
491 Self::AllFalse(_) => AllOr::None,
492 Self::Values(values) => AllOr::Some(values.slices()),
493 }
494 }
495
496 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
498 match &self {
499 Self::AllTrue(_) => AllOr::All,
500 Self::AllFalse(_) => AllOr::None,
501 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
502 }
503 }
504
505 pub fn values(&self) -> Option<&MaskValues> {
507 match self {
508 Self::Values(values) => Some(values),
509 _ => None,
510 }
511 }
512
513 pub fn valid_counts_for_indices(&self, indices: &[usize]) -> VortexResult<Vec<usize>> {
518 Ok(match self {
519 Self::AllTrue(_) => indices.to_vec(),
520 Self::AllFalse(_) => vec![0; indices.len()],
521 Self::Values(values) => {
522 let mut bool_iter = values.boolean_buffer().iter();
523 let mut valid_counts = Vec::with_capacity(indices.len());
524 let mut valid_count = 0;
525 let mut idx = 0;
526 for &next_idx in indices {
527 while idx < next_idx {
528 idx += 1;
529 valid_count += bool_iter
530 .next()
531 .ok_or_else(|| vortex_err!("Row indices exceed array length"))?
532 as usize;
533 }
534 valid_counts.push(valid_count);
535 }
536
537 valid_counts
538 }
539 })
540 }
541
542 pub fn limit(self, limit: usize) -> Self {
544 if self.len() <= limit {
545 return self;
546 }
547
548 match self {
549 Mask::AllTrue(len) => {
550 Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
551 }
552 Mask::AllFalse(_) => self,
553 Mask::Values(ref mask_values) => {
554 if limit >= mask_values.true_count() {
555 return self;
556 }
557
558 let existing_buffer = mask_values.boolean_buffer();
559
560 let mut new_buffer_builder = BooleanBufferBuilder::new(mask_values.len());
561 new_buffer_builder.append_n(mask_values.len(), false);
562
563 for index in existing_buffer.set_indices().take(limit) {
564 new_buffer_builder.set_bit(index, true);
565 }
566
567 Self::from(new_buffer_builder.finish())
568 }
569 }
570 }
571}
572
/// A borrowed view of the true positions of a mask, in one of two shapes.
///
/// Produced by `threshold_iter`, which picks the shape based on the mask's density.
pub enum MaskIter<'a> {
    /// The individual positions of the true values.
    Indices(&'a [usize]),
    /// The runs of true values as `(start, end)` pairs.
    Slices(&'a [(usize, usize)]),
}
580
581impl From<BooleanBuffer> for Mask {
582 fn from(value: BooleanBuffer) -> Self {
583 Self::from_buffer(value)
584 }
585}
586
587impl FromIterator<bool> for Mask {
588 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
589 Self::from_buffer(BooleanBuffer::from_iter(iter))
590 }
591}
592
593impl FromIterator<Mask> for Mask {
594 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
595 let masks = iter.into_iter().collect::<Vec<_>>();
596 let total_length = masks.iter().map(|v| v.len()).sum();
597
598 if masks.iter().all(|v| v.all_true()) {
600 return Self::AllTrue(total_length);
601 }
602 if masks.iter().all(|v| v.all_false()) {
604 return Self::AllFalse(total_length);
605 }
606
607 let mut buffer = BooleanBufferBuilder::new(total_length);
609 for mask in masks {
610 match mask {
611 Mask::AllTrue(count) => buffer.append_n(count, true),
612 Mask::AllFalse(count) => buffer.append_n(count, false),
613 Mask::Values(values) => {
614 buffer.append_buffer(values.boolean_buffer());
615 }
616 };
617 }
618 Self::from_buffer(buffer.finish())
619 }
620}
621
#[cfg(test)]
mod test {
    use super::*;

    // An all-true mask answers every accessor with the `All` shortcut.
    #[test]
    fn mask_all_true() {
        let all_set = Mask::new_true(5);
        assert_eq!(all_set.len(), 5);
        assert_eq!(all_set.true_count(), 5);
        assert_eq!(all_set.density(), 1.0);
        assert_eq!(all_set.indices(), AllOr::All);
        assert_eq!(all_set.slices(), AllOr::All);
        assert_eq!(all_set.boolean_buffer(), AllOr::All);
    }

    // An all-false mask answers every accessor with the `None` shortcut.
    #[test]
    fn mask_all_false() {
        let all_unset = Mask::new_false(5);
        assert_eq!(all_unset.len(), 5);
        assert_eq!(all_unset.true_count(), 0);
        assert_eq!(all_unset.density(), 0.0);
        assert_eq!(all_unset.indices(), AllOr::None);
        assert_eq!(all_unset.slices(), AllOr::None);
        assert_eq!(all_unset.boolean_buffer(), AllOr::None);
    }

    // The three constructors agree on every view of the same logical mask.
    #[test]
    fn mask_from() {
        let expected_buffer = BooleanBuffer::from_iter([true, false, true, true, false]);
        let expected_indices: &[usize] = &[0, 2, 3];
        let expected_slices: &[(usize, usize)] = &[(0, 1), (2, 4)];

        let masks = [
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_slices(5, vec![(0, 1), (2, 4)]),
            Mask::from_buffer(expected_buffer.clone()),
        ];

        for mask in &masks {
            assert_eq!(mask.len(), 5);
            assert_eq!(mask.true_count(), 3);
            assert_eq!(mask.density(), 0.6);
            assert_eq!(mask.indices(), AllOr::Some(expected_indices));
            assert_eq!(mask.slices(), AllOr::Some(expected_slices));
            assert_eq!(mask.boolean_buffer(), AllOr::Some(&expected_buffer));
        }
    }

    // Limiting an all-true mask keeps a true prefix; a limit beyond the length is a no-op.
    #[test]
    fn limit_all_true_mask() {
        let all_true = Mask::new_true(4);

        let truncated = all_true.clone().limit(2);
        assert_eq!(all_true.len(), truncated.len());
        assert_eq!(truncated.true_count(), 2);
        let expected = BooleanBuffer::from_iter([true, true, false, false]);
        assert_eq!(truncated.boolean_buffer(), AllOr::Some(&expected));

        let untouched = all_true.clone().limit(5);
        assert_eq!(untouched, all_true);
    }

    // Limiting a materialized mask keeps only the first `limit` true positions.
    #[test]
    fn limit_mask_values() {
        let original_mask = Mask::from_iter([true, true, false, true, false, true]);

        let first_two = original_mask.clone().limit(2);
        let expected_two = BooleanBuffer::from_iter([true, true, false, false, false, false]);
        assert_eq!(first_two.boolean_buffer(), AllOr::Some(&expected_two));
        assert_eq!(first_two.true_count(), 2);

        let first_three = original_mask.limit(3);
        let expected_three = BooleanBuffer::from_iter([true, true, false, true, false, false]);
        assert_eq!(first_three.boolean_buffer(), AllOr::Some(&expected_three));
        assert_eq!(first_three.true_count(), 3);

        let original_mask = Mask::from_iter([true, true, false, true, false, true]);
        let unlimited = original_mask.clone().limit(100);
        assert_eq!(original_mask, unlimited);
    }
}