1#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10mod iter_bools;
11mod mask_mut;
12
13#[cfg(feature = "arrow")]
14mod arrow;
15#[cfg(test)]
16mod tests;
17
18use std::cmp::Ordering;
19use std::fmt::Debug;
20use std::fmt::Formatter;
21use std::ops::Bound;
22use std::ops::RangeBounds;
23use std::sync::Arc;
24use std::sync::OnceLock;
25
26use itertools::Itertools;
27pub use mask_mut::*;
28use vortex_buffer::BitBuffer;
29use vortex_buffer::BitBufferMut;
30use vortex_buffer::set_bit_unchecked;
31use vortex_error::VortexResult;
32use vortex_error::vortex_panic;
33
/// A tri-state answer used by [`Mask`] accessors: everything is true,
/// everything is false, or an explicit payload describes the mixed case.
pub enum AllOr<T> {
    /// Every element is true / included.
    All,
    /// Every element is false / excluded.
    None,
    /// A mixed result; inspect the payload.
    Some(T),
}
43
44impl<T> AllOr<T> {
45 #[inline]
47 pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
48 where
49 F: FnOnce() -> T,
50 G: FnOnce() -> T,
51 {
52 match self {
53 Self::Some(v) => v,
54 AllOr::All => all_true(),
55 AllOr::None => all_false(),
56 }
57 }
58}
59
60impl<T> AllOr<&T> {
61 #[inline]
63 pub fn cloned(self) -> AllOr<T>
64 where
65 T: Clone,
66 {
67 match self {
68 Self::All => AllOr::All,
69 Self::None => AllOr::None,
70 Self::Some(v) => AllOr::Some(v.clone()),
71 }
72 }
73}
74
75impl<T> Debug for AllOr<T>
76where
77 T: Debug,
78{
79 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
80 match self {
81 Self::All => f.write_str("All"),
82 Self::None => f.write_str("None"),
83 Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
84 }
85 }
86}
87
88impl<T> PartialEq for AllOr<T>
89where
90 T: PartialEq,
91{
92 fn eq(&self, other: &Self) -> bool {
93 match (self, other) {
94 (Self::All, Self::All) => true,
95 (Self::None, Self::None) => true,
96 (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
97 _ => false,
98 }
99 }
100}
101
// Full equivalence holds whenever the payload type is itself `Eq`.
impl<T> Eq for AllOr<T> where T: Eq {}
103
/// A boolean selection mask over a fixed number of rows.
///
/// The all-true and all-false cases carry only a length so they can be built
/// and queried in O(1); mixed masks share an [`Arc<MaskValues>`] holding the
/// bit buffer plus lazily computed index/slice views.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// Every position is true; the payload is the mask length.
    AllTrue(usize),
    /// Every position is false; the payload is the mask length.
    AllFalse(usize),
    /// Mixed true/false values backed by a bit buffer.
    Values(Arc<MaskValues>),
}
119
120impl Debug for Mask {
121 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122 match self {
123 Self::AllTrue(len) => write!(f, "All true({len})"),
124 Self::AllFalse(len) => write!(f, "All false({len})"),
125 Self::Values(mask) => write!(f, "{mask:?}"),
126 }
127 }
128}
129
130impl Default for Mask {
131 fn default() -> Self {
132 Self::new_true(0)
133 }
134}
135
/// Backing storage for a mixed-value [`Mask`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    // Raw bits; bit `i` is the mask value at position `i`.
    buffer: BitBuffer,

    // Sorted positions of the set bits, materialized on first request.
    // Skipped by serde and rebuilt lazily after deserialization.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    // Half-open (start, end) runs of set bits, materialized on first request.
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,

    // Number of set bits in `buffer`, cached at construction.
    true_count: usize,
    // true_count / buffer.len(), cached at construction.
    density: f64,
}
153
impl Debug for MaskValues {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "true_count={}, ", self.true_count)?;
        write!(f, "density={}, ", self.density)?;
        // Only print the lazy views that are already materialized; Debug
        // must not trigger their computation as a side effect.
        if let Some(v) = self.indices.get() {
            write!(f, "indices={v:?}, ")?;
        }
        if let Some(v) = self.slices.get() {
            write!(f, "slices={v:?}, ")?;
        }
        if f.alternate() {
            f.write_str("\n")?;
        }
        write!(f, "{}", self.buffer)
    }
}
170
171impl Mask {
172 pub fn new(length: usize, value: bool) -> Self {
174 if value {
175 Self::AllTrue(length)
176 } else {
177 Self::AllFalse(length)
178 }
179 }
180
    /// Creates a mask of `length` elements that are all true.
    #[inline]
    pub fn new_true(length: usize) -> Self {
        Self::AllTrue(length)
    }
186
    /// Creates a mask of `length` elements that are all false.
    #[inline]
    pub fn new_false(length: usize) -> Self {
        Self::AllFalse(length)
    }
192
193 pub fn from_buffer(buffer: BitBuffer) -> Self {
195 let len = buffer.len();
196 let true_count = buffer.true_count();
197
198 if true_count == 0 {
199 return Self::AllFalse(len);
200 }
201 if true_count == len {
202 return Self::AllTrue(len);
203 }
204
205 Self::Values(Arc::new(MaskValues {
206 buffer,
207 indices: Default::default(),
208 slices: Default::default(),
209 true_count,
210 density: true_count as f64 / len as f64,
211 }))
212 }
213
214 pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
217 let true_count = indices.len();
218 assert!(indices.is_sorted(), "Mask indices must be sorted");
219 assert!(
220 indices.last().is_none_or(|&idx| idx < len),
221 "Mask indices must be in bounds (len={len})"
222 );
223
224 if true_count == 0 {
225 return Self::AllFalse(len);
226 }
227 if true_count == len {
228 return Self::AllTrue(len);
229 }
230
231 let mut buf = BitBufferMut::new_unset(len);
232 indices.iter().for_each(|&idx| buf.set(idx));
234 debug_assert_eq!(buf.len(), len);
235
236 Self::Values(Arc::new(MaskValues {
237 buffer: buf.freeze(),
238 indices: OnceLock::from(indices),
239 slices: Default::default(),
240 true_count,
241 density: true_count as f64 / len as f64,
242 }))
243 }
244
245 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
247 let mut buf = BitBufferMut::new_set(len);
248
249 let mut false_count: usize = 0;
250 indices.into_iter().for_each(|idx| {
251 buf.unset(idx);
252 false_count += 1;
253 });
254 debug_assert_eq!(buf.len(), len);
255 let true_count = len - false_count;
256
257 if false_count == 0 {
259 return Self::AllTrue(len);
260 }
261 if false_count == len {
262 return Self::AllFalse(len);
263 }
264
265 Self::Values(Arc::new(MaskValues {
266 buffer: buf.freeze(),
267 indices: Default::default(),
268 slices: Default::default(),
269 true_count,
270 density: true_count as f64 / len as f64,
271 }))
272 }
273
    /// Builds a mask of `len` elements whose true values are the given sorted,
    /// non-overlapping, non-empty half-open `(start, end)` runs.
    ///
    /// # Panics
    ///
    /// Panics if any slice is empty, out of bounds, unsorted, or overlapping.
    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
        Self::check_slices(len, &vec);
        Self::from_slices_unchecked(len, vec)
    }
280
    // Like `from_slices`, but only re-validates the slices in debug builds;
    // release-mode callers are trusted to have validated already.
    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
        #[cfg(debug_assertions)]
        Self::check_slices(len, &slices);

        // Slices are disjoint, so their lengths sum to the set-bit count.
        let true_count = slices.iter().map(|(b, e)| e - b).sum();
        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        let mut buf = BitBufferMut::new_unset(len);
        for (start, end) in slices.iter().copied() {
            (start..end).for_each(|idx| buf.set(idx));
        }
        debug_assert_eq!(buf.len(), len);

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: Default::default(),
            // Seed the lazy slice cache with the caller-provided runs.
            slices: OnceLock::from(slices),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }
307
308 #[inline(always)]
309 fn check_slices(len: usize, vec: &[(usize, usize)]) {
310 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
311 for (first, second) in vec.iter().tuple_windows() {
312 assert!(
313 first.0 < second.0,
314 "Slices must be sorted, got {first:?} and {second:?}"
315 );
316 assert!(
317 first.1 <= second.0,
318 "Slices must be non-overlapping, got {first:?} and {second:?}"
319 );
320 }
321 }
322
323 pub fn from_intersection_indices(
325 len: usize,
326 lhs: impl Iterator<Item = usize>,
327 rhs: impl Iterator<Item = usize>,
328 ) -> Self {
329 let mut intersection = Vec::with_capacity(len);
330 let mut lhs = lhs.peekable();
331 let mut rhs = rhs.peekable();
332 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
333 match l.cmp(&r) {
334 Ordering::Less => {
335 lhs.next();
336 }
337 Ordering::Greater => {
338 rhs.next();
339 }
340 Ordering::Equal => {
341 intersection.push(l);
342 lhs.next();
343 rhs.next();
344 }
345 }
346 }
347 Self::from_indices(len, intersection)
348 }
349
    /// Resets the mask to an empty (zero-length) mask.
    pub fn clear(&mut self) {
        *self = Self::new_false(0);
    }
354
    /// Number of elements covered by the mask (true and false alike).
    #[inline]
    pub fn len(&self) -> usize {
        match self {
            Self::AllTrue(len) => *len,
            Self::AllFalse(len) => *len,
            Self::Values(values) => values.len(),
        }
    }
364
365 #[inline]
367 pub fn is_empty(&self) -> bool {
368 match self {
369 Self::AllTrue(len) => *len == 0,
370 Self::AllFalse(len) => *len == 0,
371 Self::Values(values) => values.is_empty(),
372 }
373 }
374
    /// Number of true values in the mask.
    #[inline]
    pub fn true_count(&self) -> usize {
        match &self {
            Self::AllTrue(len) => *len,
            Self::AllFalse(_) => 0,
            Self::Values(values) => values.true_count,
        }
    }
384
    /// Number of false values in the mask.
    #[inline]
    pub fn false_count(&self) -> usize {
        match &self {
            Self::AllTrue(_) => 0,
            Self::AllFalse(len) => *len,
            Self::Values(values) => values.buffer.len() - values.true_count,
        }
    }
394
395 #[inline]
397 pub fn all_true(&self) -> bool {
398 match &self {
399 Self::AllTrue(_) => true,
400 Self::AllFalse(0) => true,
401 Self::AllFalse(_) => false,
402 Self::Values(values) => values.buffer.len() == values.true_count,
403 }
404 }
405
    /// Returns true when every element is false (vacuously true for an empty
    /// mask).
    #[inline]
    pub fn all_false(&self) -> bool {
        self.true_count() == 0
    }
411
    /// Fraction of true values, in `[0.0, 1.0]`.
    #[inline]
    pub fn density(&self) -> f64 {
        match &self {
            Self::AllTrue(_) => 1.0,
            Self::AllFalse(_) => 0.0,
            Self::Values(values) => values.density,
        }
    }
421
    /// Returns the mask value at position `idx`.
    ///
    /// NOTE(review): the `AllTrue`/`AllFalse` arms do not bounds-check `idx`,
    /// so an out-of-range index silently returns the constant value there;
    /// the `Values` arm defers to the buffer — confirm its bounds behavior.
    #[inline]
    pub fn value(&self, idx: usize) -> bool {
        match self {
            Mask::AllTrue(_) => true,
            Mask::AllFalse(_) => false,
            Mask::Values(values) => values.buffer.value(idx),
        }
    }
435
    /// Position of the first true value, or `None` when the mask is all-false
    /// or empty.
    pub fn first(&self) -> Option<usize> {
        match &self {
            Self::AllTrue(len) => (*len > 0).then_some(0),
            Self::AllFalse(_) => None,
            Self::Values(values) => {
                // Prefer whichever lazy view is already materialized before
                // falling back to scanning the buffer.
                if let Some(indices) = values.indices.get() {
                    return indices.first().copied();
                }
                if let Some(slices) = values.slices.get() {
                    return slices.first().map(|(start, _)| *start);
                }
                values.buffer.set_indices().next()
            }
        }
    }
452
    /// Position of the last true value, or `None` when the mask is all-false
    /// or empty.
    pub fn last(&self) -> Option<usize> {
        match &self {
            Self::AllTrue(len) => (*len > 0).then_some(*len - 1),
            Self::AllFalse(_) => None,
            Self::Values(values) => {
                // Prefer whichever lazy view is already materialized before
                // falling back to scanning the buffer.
                if let Some(indices) = values.indices.get() {
                    return indices.last().copied();
                }
                if let Some(slices) = values.slices.get() {
                    // Slices are half-open, so the last set bit is end - 1.
                    return slices.last().map(|(_, end)| end - 1);
                }
                values.buffer.set_slices().last().map(|(_, end)| end - 1)
            }
        }
    }
469
    /// Position of the `n`-th (0-based) true value.
    ///
    /// # Panics
    ///
    /// Panics if `n >= self.true_count()`.
    pub fn rank(&self, n: usize) -> usize {
        if n >= self.true_count() {
            vortex_panic!(
                "Rank {n} out of bounds for mask with true count {}",
                self.true_count()
            );
        }
        match &self {
            // In an all-true mask the n-th set bit is at position n.
            Self::AllTrue(_) => n,
            Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
            Self::Values(values) => values.indices()[n],
        }
    }
485
    /// Returns a new mask covering the given sub-range of this mask.
    ///
    /// # Panics
    ///
    /// Panics if the range is inverted or extends past `self.len()`.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        // Normalize the generic bounds into a half-open [start, end) pair.
        let start = match range.start_bound() {
            Bound::Included(&s) => s,
            Bound::Excluded(&s) => s + 1,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&e) => e + 1,
            Bound::Excluded(&e) => e,
            Bound::Unbounded => self.len(),
        };

        assert!(start <= end);
        assert!(start <= self.len());
        assert!(end <= self.len());
        let len = end - start;

        match &self {
            Self::AllTrue(_) => Self::new_true(len),
            Self::AllFalse(_) => Self::new_false(len),
            // Re-derive the variant: a slice of a mixed mask may itself be
            // all-true or all-false.
            Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
        }
    }
510
    /// Borrows the backing bit buffer, or reports the buffer-less all-true /
    /// all-false cases.
    #[inline]
    pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(&values.buffer),
        }
    }
520
    /// Materializes the mask as a bit buffer, allocating constant buffers for
    /// the all-true / all-false cases and cloning the existing one otherwise.
    #[inline]
    pub fn to_bit_buffer(&self) -> BitBuffer {
        match self {
            Self::AllTrue(l) => BitBuffer::new_set(*l),
            Self::AllFalse(l) => BitBuffer::new_unset(*l),
            Self::Values(values) => values.bit_buffer().clone(),
        }
    }
531
    /// Consumes the mask and returns its bits as a buffer.
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        match self {
            Self::AllTrue(l) => BitBuffer::new_set(l),
            Self::AllFalse(l) => BitBuffer::new_unset(l),
            // If this is the only reference to the values we can take the
            // buffer without copying; otherwise clone from the shared Arc.
            Self::Values(values) => Arc::try_unwrap(values)
                .map(|v| v.into_bit_buffer())
                .unwrap_or_else(|v| v.bit_buffer().clone()),
        }
    }
544
    /// Sorted positions of the true values, or the buffer-less all-true /
    /// all-false answer. May lazily compute and cache the index list.
    #[inline]
    pub fn indices(&self) -> AllOr<&[usize]> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.indices()),
        }
    }
554
    /// Half-open `(start, end)` runs of true values, or the buffer-less
    /// all-true / all-false answer. May lazily compute and cache the runs.
    #[inline]
    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.slices()),
        }
    }
564
    /// Chooses an iteration strategy for mixed masks: slice runs when the
    /// density is at least `threshold`, individual indices otherwise.
    #[inline]
    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
        }
    }
574
575 #[inline]
577 pub fn values(&self) -> Option<&MaskValues> {
578 if let Self::Values(values) = self {
579 Some(values)
580 } else {
581 None
582 }
583 }
584
    /// For each position in `indices`, returns the number of true values at
    /// positions strictly before it.
    ///
    /// NOTE(review): the single forward pass assumes `indices` is sorted
    /// ascending — an unsorted input silently reuses a stale running count;
    /// confirm callers uphold this.
    ///
    /// # Panics
    ///
    /// Panics if an index exceeds the mask length.
    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
        match self {
            // All-true: the count of trues before position i is exactly i.
            Self::AllTrue(_) => indices.to_vec(),
            Self::AllFalse(_) => vec![0; indices.len()],
            Self::Values(values) => {
                let mut bool_iter = values.bit_buffer().iter();
                let mut valid_counts = Vec::with_capacity(indices.len());
                let mut valid_count = 0;
                let mut idx = 0;
                for &next_idx in indices {
                    // Advance the cursor to next_idx, accumulating set bits.
                    while idx < next_idx {
                        idx += 1;
                        valid_count += bool_iter
                            .next()
                            .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
                            as usize;
                    }
                    valid_counts.push(valid_count);
                }

                valid_counts
            }
        }
    }
613
    /// Keeps only the first `limit` true values; every later position becomes
    /// false. Masks that cannot exceed the limit are returned unchanged.
    pub fn limit(self, limit: usize) -> Self {
        // A mask of len <= limit has at most `limit` true values already.
        if self.len() <= limit {
            return self;
        }

        match self {
            Mask::AllTrue(len) => {
                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
            }
            Mask::AllFalse(_) => self,
            Mask::Values(ref mask_values) => {
                if limit >= mask_values.true_count() {
                    return self;
                }

                let existing_buffer = mask_values.bit_buffer();

                let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
                debug_assert!(limit < mask_values.len());

                // Copy only the first `limit` set bits into the new buffer.
                let ptr = new_buffer_builder.as_mut_ptr();
                for index in existing_buffer.set_indices().take(limit) {
                    // SAFETY: `index` comes from set_indices() of a buffer
                    // with the same length as `new_buffer_builder`, so it is
                    // in bounds for the destination.
                    unsafe { set_bit_unchecked(ptr, index) }
                }

                Self::from(new_buffer_builder.freeze())
            }
        }
    }
649
    /// Concatenates the masks end-to-end into a single mask.
    ///
    /// Every path currently returns `Ok`; the `VortexResult` wrapper is part
    /// of the existing signature.
    pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
        let masks: Vec<_> = masks.collect();
        let len = masks.iter().map(|t| t.len()).sum();

        // Fast paths: concatenating only all-true (or only all-false) masks
        // stays buffer-less. `all()` on an empty input is vacuously true.
        if masks.iter().all(|t| t.all_true()) {
            return Ok(Mask::AllTrue(len));
        }

        if masks.iter().all(|t| t.all_false()) {
            return Ok(Mask::AllFalse(len));
        }

        let mut builder = BitBufferMut::with_capacity(len);

        for mask in masks {
            match mask {
                Mask::AllTrue(n) => builder.append_n(true, *n),
                Mask::AllFalse(n) => builder.append_n(false, *n),
                Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
            }
        }

        Ok(Mask::from_buffer(builder.freeze()))
    }
675}
676
677impl MaskValues {
    /// Number of elements covered (true and false alike).
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len()
    }
683
    /// Returns true when the mask covers zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
689
    /// Fraction of true values, cached at construction.
    #[inline]
    pub fn density(&self) -> f64 {
        self.density
    }
695
    /// Number of true values, cached at construction.
    #[inline]
    pub fn true_count(&self) -> usize {
        self.true_count
    }
701
    /// Borrows the backing bit buffer.
    #[inline]
    pub fn bit_buffer(&self) -> &BitBuffer {
        &self.buffer
    }
707
    /// Consumes the values and returns the backing bit buffer.
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        self.buffer
    }
713
    /// Returns the mask value at `index`, deferring to the buffer.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        self.buffer.value(index)
    }
719
    /// Sorted positions of all true values, computed on first call and cached.
    pub fn indices(&self) -> &[usize] {
        self.indices.get_or_init(|| {
            // Degenerate cases avoid scanning the buffer entirely.
            if self.true_count == 0 {
                return vec![];
            }

            if self.true_count == self.len() {
                return (0..self.len()).collect();
            }

            // An already-materialized slice view expands more cheaply than a
            // fresh bit scan.
            if let Some(slices) = self.slices.get() {
                let mut indices = Vec::with_capacity(self.true_count);
                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
                debug_assert!(indices.is_sorted());
                assert_eq!(indices.len(), self.true_count);
                return indices;
            }

            let mut indices = Vec::with_capacity(self.true_count);
            indices.extend(self.buffer.set_indices());
            debug_assert!(indices.is_sorted());
            assert_eq!(indices.len(), self.true_count);
            indices
        })
    }
746
    /// Half-open `(start, end)` runs of true values, computed on first call
    /// and cached.
    #[inline]
    pub fn slices(&self) -> &[(usize, usize)] {
        self.slices.get_or_init(|| {
            // Fully-set buffers collapse to a single run without scanning.
            if self.true_count == self.len() {
                return vec![(0, self.len())];
            }

            self.buffer.set_slices().collect()
        })
    }
758
    /// Chooses an iteration strategy: dense masks (density >= threshold)
    /// iterate by slice runs, sparse masks by individual indices.
    #[inline]
    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
        if self.density >= threshold {
            MaskIter::Slices(self.slices())
        } else {
            MaskIter::Indices(self.indices())
        }
    }
768
769 pub(crate) fn into_buffer(self) -> BitBuffer {
771 self.buffer
772 }
773}
774
/// Iteration strategy chosen by [`MaskValues::threshold_iter`] based on mask
/// density.
pub enum MaskIter<'a> {
    /// Sparse masks: positions of the individual true values.
    Indices(&'a [usize]),
    /// Dense masks: half-open `(start, end)` runs of true values.
    Slices(&'a [(usize, usize)]),
}
782
impl From<BitBuffer> for Mask {
    // Delegates to `from_buffer`, which collapses degenerate buffers.
    fn from(value: BitBuffer) -> Self {
        Self::from_buffer(value)
    }
}
788
impl FromIterator<bool> for Mask {
    // Collects the booleans into a bit buffer, then collapses degenerate
    // all-true / all-false results via `from_buffer`.
    #[inline]
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        Self::from_buffer(BitBuffer::from_iter(iter))
    }
}
795
impl FromIterator<Mask> for Mask {
    // NOTE(review): this mirrors `Mask::concat` for owned masks; keep the two
    // implementations in sync.
    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
        // Empty masks contribute nothing; dropping them up front keeps the
        // all-true / all-false fast paths meaningful.
        let masks = iter
            .into_iter()
            .filter(|m| !m.is_empty())
            .collect::<Vec<_>>();
        let total_length = masks.iter().map(|v| v.len()).sum();

        // `all()` on an empty input is vacuously true, yielding AllTrue(0).
        if masks.iter().all(|v| v.all_true()) {
            return Self::AllTrue(total_length);
        }
        if masks.iter().all(|v| v.all_false()) {
            return Self::AllFalse(total_length);
        }

        let mut buffer = BitBufferMut::with_capacity(total_length);
        for mask in masks {
            match mask {
                Mask::AllTrue(count) => buffer.append_n(true, count),
                Mask::AllFalse(count) => buffer.append_n(false, count),
                Mask::Values(values) => {
                    buffer.append_buffer(values.bit_buffer());
                }
            };
        }
        Self::from_buffer(buffer.freeze())
    }
}
826}