1#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10mod iter_bools;
11mod mask_mut;
12
13#[cfg(feature = "arrow")]
14mod arrow;
15#[cfg(test)]
16mod tests;
17
18use std::cmp::Ordering;
19use std::fmt::Debug;
20use std::fmt::Formatter;
21use std::ops::Bound;
22use std::ops::RangeBounds;
23use std::sync::Arc;
24use std::sync::OnceLock;
25
26use itertools::Itertools;
27pub use mask_mut::*;
28use vortex_buffer::BitBuffer;
29use vortex_buffer::BitBufferMut;
30use vortex_buffer::set_bit_unchecked;
31use vortex_error::VortexResult;
32use vortex_error::vortex_panic;
33
/// A three-way answer: everything, nothing, or a concrete value of type `T`.
///
/// `Debug`, `PartialEq` and `Eq` are derived. The derived `Debug` prints `All`, `None` or
/// `Some(<value>)`, and two values are equal only when they are the same variant and, for
/// [`AllOr::Some`], wrap equal payloads.
#[derive(Debug, PartialEq, Eq)]
pub enum AllOr<T> {
    /// Stands for "all"; carries no payload.
    All,
    /// Stands for "none"; carries no payload.
    None,
    /// Carries an explicit value.
    Some(T),
}

impl<T> AllOr<T> {
    /// Returns the wrapped value, or computes a replacement when there is none.
    ///
    /// `all_true` is invoked for [`AllOr::All`] and `all_false` for [`AllOr::None`]; neither
    /// closure runs for [`AllOr::Some`].
    #[inline]
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            Self::All => all_true(),
            Self::None => all_false(),
            Self::Some(value) => value,
        }
    }
}

impl<T> AllOr<&T> {
    /// Turns an `AllOr<&T>` into an owned `AllOr<T>` by cloning the referenced value, if any.
    #[inline]
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            Self::All => AllOr::All,
            Self::None => AllOr::None,
            Self::Some(value) => AllOr::Some(value.clone()),
        }
    }
}
103
/// A boolean mask over a fixed number of positions.
///
/// Uniform masks are stored as a bare length; mixed masks share their data through an
/// [`Arc`], so cloning a `Mask` never copies the bits.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// Every position is true; the payload is the mask length.
    AllTrue(usize),
    /// Every position is false; the payload is the mask length.
    AllFalse(usize),
    /// Explicit per-position values, shared behind an [`Arc`].
    Values(Arc<MaskValues>),
}
119
120impl Debug for Mask {
121 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122 match self {
123 Self::AllTrue(len) => write!(f, "All true({len})"),
124 Self::AllFalse(len) => write!(f, "All false({len})"),
125 Self::Values(mask) => write!(f, "{mask:?}"),
126 }
127 }
128}
129
130impl Default for Mask {
131 fn default() -> Self {
132 Self::new_true(0)
133 }
134}
135
/// The data behind a [`Mask::Values`]: the bits themselves, lazily built index and slice
/// views, and statistics computed when the mask was created.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    /// The mask bits.
    buffer: BitBuffer,

    /// Positions of the set bits; filled on first use (or at construction when already known)
    /// and skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    /// `(start, end)` runs of set bits; filled on first use (or at construction when already
    /// known) and skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,

    /// Number of set bits in `buffer`.
    true_count: usize,
    /// `true_count` divided by the length of `buffer`.
    density: f64,
}
153
154impl Debug for MaskValues {
155 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
156 write!(f, "true_count={}, ", self.true_count)?;
157 write!(f, "density={}, ", self.density)?;
158 if let Some(v) = self.indices.get() {
159 write!(f, "indices={v:?}, ")?;
160 }
161 if let Some(v) = self.slices.get() {
162 write!(f, "slices={v:?}, ")?;
163 }
164 if f.alternate() {
165 f.write_str("\n")?;
166 }
167 write!(f, "{}", self.buffer)
168 }
169}
170
171impl Mask {
172 pub fn new(length: usize, value: bool) -> Self {
174 if value {
175 Self::AllTrue(length)
176 } else {
177 Self::AllFalse(length)
178 }
179 }
180
181 #[inline]
183 pub fn new_true(length: usize) -> Self {
184 Self::AllTrue(length)
185 }
186
187 #[inline]
189 pub fn new_false(length: usize) -> Self {
190 Self::AllFalse(length)
191 }
192
193 pub fn from_buffer(buffer: BitBuffer) -> Self {
195 let len = buffer.len();
196 let true_count = buffer.true_count();
197
198 if true_count == 0 {
199 return Self::AllFalse(len);
200 }
201 if true_count == len {
202 return Self::AllTrue(len);
203 }
204
205 Self::Values(Arc::new(MaskValues {
206 buffer,
207 indices: Default::default(),
208 slices: Default::default(),
209 true_count,
210 density: true_count as f64 / len as f64,
211 }))
212 }
213
214 pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
217 let true_count = indices.len();
218 assert!(indices.is_sorted(), "Mask indices must be sorted");
219 assert!(
220 indices.last().is_none_or(|&idx| idx < len),
221 "Mask indices must be in bounds (len={len})"
222 );
223
224 if true_count == 0 {
225 return Self::AllFalse(len);
226 }
227 if true_count == len {
228 return Self::AllTrue(len);
229 }
230
231 let mut buf = BitBufferMut::new_unset(len);
232 indices.iter().for_each(|&idx| buf.set(idx));
234 debug_assert_eq!(buf.len(), len);
235
236 Self::Values(Arc::new(MaskValues {
237 buffer: buf.freeze(),
238 indices: OnceLock::from(indices),
239 slices: Default::default(),
240 true_count,
241 density: true_count as f64 / len as f64,
242 }))
243 }
244
245 pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
247 let mut buf = BitBufferMut::new_set(len);
248
249 let mut false_count: usize = 0;
250 indices.into_iter().for_each(|idx| {
251 buf.unset(idx);
252 false_count += 1;
253 });
254 debug_assert_eq!(buf.len(), len);
255 let true_count = len - false_count;
256
257 if false_count == 0 {
259 return Self::AllTrue(len);
260 }
261 if false_count == len {
262 return Self::AllFalse(len);
263 }
264
265 Self::Values(Arc::new(MaskValues {
266 buffer: buf.freeze(),
267 indices: Default::default(),
268 slices: Default::default(),
269 true_count,
270 density: true_count as f64 / len as f64,
271 }))
272 }
273
274 pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
277 Self::check_slices(len, &vec);
278 Self::from_slices_unchecked(len, vec)
279 }
280
281 fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
282 #[cfg(debug_assertions)]
283 Self::check_slices(len, &slices);
284
285 let true_count = slices.iter().map(|(b, e)| e - b).sum();
286 if true_count == 0 {
287 return Self::AllFalse(len);
288 }
289 if true_count == len {
290 return Self::AllTrue(len);
291 }
292
293 let mut buf = BitBufferMut::new_unset(len);
294 for (start, end) in slices.iter().copied() {
295 (start..end).for_each(|idx| buf.set(idx));
296 }
297 debug_assert_eq!(buf.len(), len);
298
299 Self::Values(Arc::new(MaskValues {
300 buffer: buf.freeze(),
301 indices: Default::default(),
302 slices: OnceLock::from(slices),
303 true_count,
304 density: true_count as f64 / len as f64,
305 }))
306 }
307
308 #[inline(always)]
309 fn check_slices(len: usize, vec: &[(usize, usize)]) {
310 assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
311 for (first, second) in vec.iter().tuple_windows() {
312 assert!(
313 first.0 < second.0,
314 "Slices must be sorted, got {first:?} and {second:?}"
315 );
316 assert!(
317 first.1 <= second.0,
318 "Slices must be non-overlapping, got {first:?} and {second:?}"
319 );
320 }
321 }
322
323 pub fn from_intersection_indices(
325 len: usize,
326 lhs: impl Iterator<Item = usize>,
327 rhs: impl Iterator<Item = usize>,
328 ) -> Self {
329 let mut intersection = Vec::with_capacity(len);
330 let mut lhs = lhs.peekable();
331 let mut rhs = rhs.peekable();
332 while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
333 match l.cmp(&r) {
334 Ordering::Less => {
335 lhs.next();
336 }
337 Ordering::Greater => {
338 rhs.next();
339 }
340 Ordering::Equal => {
341 intersection.push(l);
342 lhs.next();
343 rhs.next();
344 }
345 }
346 }
347 Self::from_indices(len, intersection)
348 }
349
350 pub fn clear(&mut self) {
352 *self = Self::new_false(0);
353 }
354
355 #[inline]
357 pub fn len(&self) -> usize {
358 match self {
359 Self::AllTrue(len) => *len,
360 Self::AllFalse(len) => *len,
361 Self::Values(values) => values.len(),
362 }
363 }
364
365 #[inline]
367 pub fn is_empty(&self) -> bool {
368 match self {
369 Self::AllTrue(len) => *len == 0,
370 Self::AllFalse(len) => *len == 0,
371 Self::Values(values) => values.is_empty(),
372 }
373 }
374
375 #[inline]
377 pub fn true_count(&self) -> usize {
378 match &self {
379 Self::AllTrue(len) => *len,
380 Self::AllFalse(_) => 0,
381 Self::Values(values) => values.true_count,
382 }
383 }
384
385 #[inline]
387 pub fn false_count(&self) -> usize {
388 match &self {
389 Self::AllTrue(_) => 0,
390 Self::AllFalse(len) => *len,
391 Self::Values(values) => values.buffer.len() - values.true_count,
392 }
393 }
394
395 #[inline]
397 pub fn all_true(&self) -> bool {
398 match &self {
399 Self::AllTrue(_) => true,
400 Self::AllFalse(0) => true,
401 Self::AllFalse(_) => false,
402 Self::Values(values) => values.buffer.len() == values.true_count,
403 }
404 }
405
406 #[inline]
408 pub fn all_false(&self) -> bool {
409 self.true_count() == 0
410 }
411
412 #[inline]
414 pub fn density(&self) -> f64 {
415 match &self {
416 Self::AllTrue(_) => 1.0,
417 Self::AllFalse(_) => 0.0,
418 Self::Values(values) => values.density,
419 }
420 }
421
422 #[inline]
428 pub fn value(&self, idx: usize) -> bool {
429 match self {
430 Mask::AllTrue(_) => true,
431 Mask::AllFalse(_) => false,
432 Mask::Values(values) => values.buffer.value(idx),
433 }
434 }
435
436 pub fn first(&self) -> Option<usize> {
438 match &self {
439 Self::AllTrue(len) => (*len > 0).then_some(0),
440 Self::AllFalse(_) => None,
441 Self::Values(values) => {
442 if let Some(indices) = values.indices.get() {
443 return indices.first().copied();
444 }
445 if let Some(slices) = values.slices.get() {
446 return slices.first().map(|(start, _)| *start);
447 }
448 values.buffer.set_indices().next()
449 }
450 }
451 }
452
453 pub fn rank(&self, n: usize) -> usize {
455 if n >= self.true_count() {
456 vortex_panic!(
457 "Rank {n} out of bounds for mask with true count {}",
458 self.true_count()
459 );
460 }
461 match &self {
462 Self::AllTrue(_) => n,
463 Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
464 Self::Values(values) => values.indices()[n],
466 }
467 }
468
469 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
471 let start = match range.start_bound() {
472 Bound::Included(&s) => s,
473 Bound::Excluded(&s) => s + 1,
474 Bound::Unbounded => 0,
475 };
476 let end = match range.end_bound() {
477 Bound::Included(&e) => e + 1,
478 Bound::Excluded(&e) => e,
479 Bound::Unbounded => self.len(),
480 };
481
482 assert!(start <= end);
483 assert!(start <= self.len());
484 assert!(end <= self.len());
485 let len = end - start;
486
487 match &self {
488 Self::AllTrue(_) => Self::new_true(len),
489 Self::AllFalse(_) => Self::new_false(len),
490 Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
491 }
492 }
493
494 #[inline]
496 pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
497 match &self {
498 Self::AllTrue(_) => AllOr::All,
499 Self::AllFalse(_) => AllOr::None,
500 Self::Values(values) => AllOr::Some(&values.buffer),
501 }
502 }
503
504 #[inline]
507 pub fn to_bit_buffer(&self) -> BitBuffer {
508 match self {
509 Self::AllTrue(l) => BitBuffer::new_set(*l),
510 Self::AllFalse(l) => BitBuffer::new_unset(*l),
511 Self::Values(values) => values.bit_buffer().clone(),
512 }
513 }
514
515 #[inline]
518 pub fn into_bit_buffer(self) -> BitBuffer {
519 match self {
520 Self::AllTrue(l) => BitBuffer::new_set(l),
521 Self::AllFalse(l) => BitBuffer::new_unset(l),
522 Self::Values(values) => Arc::try_unwrap(values)
523 .map(|v| v.into_bit_buffer())
524 .unwrap_or_else(|v| v.bit_buffer().clone()),
525 }
526 }
527
528 #[inline]
530 pub fn indices(&self) -> AllOr<&[usize]> {
531 match &self {
532 Self::AllTrue(_) => AllOr::All,
533 Self::AllFalse(_) => AllOr::None,
534 Self::Values(values) => AllOr::Some(values.indices()),
535 }
536 }
537
538 #[inline]
540 pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
541 match &self {
542 Self::AllTrue(_) => AllOr::All,
543 Self::AllFalse(_) => AllOr::None,
544 Self::Values(values) => AllOr::Some(values.slices()),
545 }
546 }
547
548 #[inline]
550 pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
551 match &self {
552 Self::AllTrue(_) => AllOr::All,
553 Self::AllFalse(_) => AllOr::None,
554 Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
555 }
556 }
557
558 #[inline]
560 pub fn values(&self) -> Option<&MaskValues> {
561 if let Self::Values(values) = self {
562 Some(values)
563 } else {
564 None
565 }
566 }
567
568 pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
573 match self {
574 Self::AllTrue(_) => indices.to_vec(),
575 Self::AllFalse(_) => vec![0; indices.len()],
576 Self::Values(values) => {
577 let mut bool_iter = values.bit_buffer().iter();
578 let mut valid_counts = Vec::with_capacity(indices.len());
579 let mut valid_count = 0;
580 let mut idx = 0;
581 for &next_idx in indices {
582 while idx < next_idx {
583 idx += 1;
584 valid_count += bool_iter
585 .next()
586 .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
587 as usize;
588 }
589 valid_counts.push(valid_count);
590 }
591
592 valid_counts
593 }
594 }
595 }
596
    /// Keeps only the first `limit` true values of the mask and turns every later true value
    /// into false. The length of the mask is unchanged.
    ///
    /// The mask is returned untouched when its length or its true count does not exceed
    /// `limit`.
    pub fn limit(self, limit: usize) -> Self {
        // A mask no longer than `limit` cannot contain more than `limit` true values.
        if self.len() <= limit {
            return self;
        }

        match self {
            Mask::AllTrue(len) => {
                // `limit` trues followed by `len - limit` falses.
                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
            }
            Mask::AllFalse(_) => self,
            Mask::Values(ref mask_values) => {
                if limit >= mask_values.true_count() {
                    return self;
                }

                let existing_buffer = mask_values.bit_buffer();

                // Start from an all-unset buffer of the same length and copy over only the
                // first `limit` set positions.
                let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
                debug_assert!(limit < mask_values.len());

                let ptr = new_buffer_builder.as_mut_ptr();
                for index in existing_buffer.set_indices().take(limit) {
                    // SAFETY: `new_buffer_builder` was created with the same length as
                    // `existing_buffer`. NOTE(review): this relies on `set_indices` yielding
                    // only positions below that length — confirm against `BitBuffer`.
                    unsafe { set_bit_unchecked(ptr, index) }
                }

                Self::from(new_buffer_builder.freeze())
            }
        }
    }
632
633 pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
635 let masks: Vec<_> = masks.collect();
636 let len = masks.iter().map(|t| t.len()).sum();
637
638 if masks.iter().all(|t| t.all_true()) {
639 return Ok(Mask::AllTrue(len));
640 }
641
642 if masks.iter().all(|t| t.all_false()) {
643 return Ok(Mask::AllFalse(len));
644 }
645
646 let mut builder = BitBufferMut::with_capacity(len);
647
648 for mask in masks {
649 match mask {
650 Mask::AllTrue(n) => builder.append_n(true, *n),
651 Mask::AllFalse(n) => builder.append_n(false, *n),
652 Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
653 }
654 }
655
656 Ok(Mask::from_buffer(builder.freeze()))
657 }
658}
659
660impl MaskValues {
661 #[inline]
663 pub fn len(&self) -> usize {
664 self.buffer.len()
665 }
666
667 #[inline]
669 pub fn is_empty(&self) -> bool {
670 self.buffer.is_empty()
671 }
672
673 #[inline]
675 pub fn density(&self) -> f64 {
676 self.density
677 }
678
679 #[inline]
681 pub fn true_count(&self) -> usize {
682 self.true_count
683 }
684
685 #[inline]
687 pub fn bit_buffer(&self) -> &BitBuffer {
688 &self.buffer
689 }
690
691 #[inline]
693 pub fn into_bit_buffer(self) -> BitBuffer {
694 self.buffer
695 }
696
697 #[inline]
699 pub fn value(&self, index: usize) -> bool {
700 self.buffer.value(index)
701 }
702
703 pub fn indices(&self) -> &[usize] {
705 self.indices.get_or_init(|| {
706 if self.true_count == 0 {
707 return vec![];
708 }
709
710 if self.true_count == self.len() {
711 return (0..self.len()).collect();
712 }
713
714 if let Some(slices) = self.slices.get() {
715 let mut indices = Vec::with_capacity(self.true_count);
716 indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
717 debug_assert!(indices.is_sorted());
718 assert_eq!(indices.len(), self.true_count);
719 return indices;
720 }
721
722 let mut indices = Vec::with_capacity(self.true_count);
723 indices.extend(self.buffer.set_indices());
724 debug_assert!(indices.is_sorted());
725 assert_eq!(indices.len(), self.true_count);
726 indices
727 })
728 }
729
730 #[inline]
732 pub fn slices(&self) -> &[(usize, usize)] {
733 self.slices.get_or_init(|| {
734 if self.true_count == self.len() {
735 return vec![(0, self.len())];
736 }
737
738 self.buffer.set_slices().collect()
739 })
740 }
741
742 #[inline]
744 pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
745 if self.density >= threshold {
746 MaskIter::Slices(self.slices())
747 } else {
748 MaskIter::Indices(self.indices())
749 }
750 }
751
752 pub(crate) fn into_buffer(self) -> BitBuffer {
754 self.buffer
755 }
756}
757
/// A borrowed view of the true positions of a mask, in one of two shapes.
pub enum MaskIter<'a> {
    /// Individual positions of the true values.
    Indices(&'a [usize]),
    /// `(start, end)` runs of true values.
    Slices(&'a [(usize, usize)]),
}
765
766impl From<BitBuffer> for Mask {
767 fn from(value: BitBuffer) -> Self {
768 Self::from_buffer(value)
769 }
770}
771
772impl FromIterator<bool> for Mask {
773 #[inline]
774 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
775 Self::from_buffer(BitBuffer::from_iter(iter))
776 }
777}
778
779impl FromIterator<Mask> for Mask {
780 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
781 let masks = iter
782 .into_iter()
783 .filter(|m| !m.is_empty())
784 .collect::<Vec<_>>();
785 let total_length = masks.iter().map(|v| v.len()).sum();
786
787 if masks.iter().all(|v| v.all_true()) {
789 return Self::AllTrue(total_length);
790 }
791 if masks.iter().all(|v| v.all_false()) {
793 return Self::AllFalse(total_length);
794 }
795
796 let mut buffer = BitBufferMut::with_capacity(total_length);
798 for mask in masks {
799 match mask {
800 Mask::AllTrue(count) => buffer.append_n(true, count),
801 Mask::AllFalse(count) => buffer.append_n(false, count),
802 Mask::Values(values) => {
803 buffer.append_buffer(values.bit_buffer());
804 }
805 };
806 }
807 Self::from_buffer(buffer.freeze())
808 }
809}