#![deny(missing_docs)]

//! Boolean masks over the rows of an array, with dedicated representations for
//! the all-true and all-false cases.

mod bitops;
mod eq;
mod intersect_by_rank;
mod iter_bools;
mod mask_mut;

#[cfg(feature = "arrow")]
mod arrow;
#[cfg(test)]
mod tests;

use std::cmp::Ordering;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::Bound;
use std::ops::RangeBounds;
use std::sync::Arc;
use std::sync::OnceLock;

use itertools::Itertools;
pub use mask_mut::*;
use vortex_buffer::BitBuffer;
use vortex_buffer::BitBufferMut;
use vortex_buffer::set_bit_unchecked;
use vortex_error::VortexResult;
use vortex_error::vortex_panic;

/// A value that either applies to all positions, to no positions, or to an
/// explicitly described subset of positions.
pub enum AllOr<T> {
    /// Applies to every position (e.g. an all-true mask).
    All,
    /// Applies to no position (e.g. an all-false mask).
    None,
    /// Applies to the subset described by `T`.
    Some(T),
}

impl<T> AllOr<T> {
    /// Returns the `Some` value, or computes one from the matching closure:
    /// `all_true` for `All`, `all_false` for `None`.
    #[inline]
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            Self::Some(v) => v,
            AllOr::All => all_true(),
            AllOr::None => all_false(),
        }
    }
}
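
// Illustrative usage (a sketch, not from the original source): resolving an
// `AllOr` into a concrete value with `unwrap_or_else`.
//
//     let v = AllOr::Some(3usize);
//     assert_eq!(v.unwrap_or_else(|| 10, || 0), 3);
//     assert_eq!(AllOr::<usize>::All.unwrap_or_else(|| 10, || 0), 10);
//     assert_eq!(AllOr::<usize>::None.unwrap_or_else(|| 10, || 0), 0);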

impl<T> AllOr<&T> {
    /// Maps `AllOr<&T>` to `AllOr<T>` by cloning the referenced value.
    #[inline]
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            Self::All => AllOr::All,
            Self::None => AllOr::None,
            Self::Some(v) => AllOr::Some(v.clone()),
        }
    }
}

impl<T> Debug for AllOr<T>
where
    T: Debug,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::All => f.write_str("All"),
            Self::None => f.write_str("None"),
            Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
        }
    }
}

impl<T> PartialEq for AllOr<T>
where
    T: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::All, Self::All) => true,
            (Self::None, Self::None) => true,
            (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
            _ => false,
        }
    }
}

impl<T> Eq for AllOr<T> where T: Eq {}

/// A boolean mask over the rows of an array.
///
/// The all-true and all-false cases are stored as just a length, so a backing
/// bit buffer is only materialized when the mask holds mixed values.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// Every position is true; holds the mask length.
    AllTrue(usize),
    /// Every position is false; holds the mask length.
    AllFalse(usize),
    /// A mix of true and false values, backed by [`MaskValues`].
    Values(Arc<MaskValues>),
}

impl Default for Mask {
    fn default() -> Self {
        Self::new_true(0)
    }
}

/// The buffer-backed representation of a mixed [`Mask`], along with lazily
/// computed alternative encodings of the true positions.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    buffer: BitBuffer,

    // Lazily materialized caches of the true positions, as sorted indices and as
    // disjoint `(start, end)` ranges respectively.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,

    // The number of true values, and the fraction of true values over the length.
    true_count: usize,
    density: f64,
}

impl Mask {
    /// Creates a new mask of the given length where every value is `value`.
    pub fn new(length: usize, value: bool) -> Self {
        if value {
            Self::AllTrue(length)
        } else {
            Self::AllFalse(length)
        }
    }

    /// Creates a new all-true mask of the given length.
    #[inline]
    pub fn new_true(length: usize) -> Self {
        Self::AllTrue(length)
    }

    /// Creates a new all-false mask of the given length.
    #[inline]
    pub fn new_false(length: usize) -> Self {
        Self::AllFalse(length)
    }

    /// Creates a new mask from the given bit buffer, collapsing to the
    /// [`Mask::AllTrue`] or [`Mask::AllFalse`] variant where possible.
    pub fn from_buffer(buffer: BitBuffer) -> Self {
        let len = buffer.len();
        let true_count = buffer.true_count();

        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        Self::Values(Arc::new(MaskValues {
            buffer,
            indices: Default::default(),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }
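
    // Illustrative usage (a sketch, not from the original source); `BitBuffer`'s
    // `FromIterator<bool>` impl is assumed here, as used by `Mask::from_iter`
    // further below.
    //
    //     let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true]));
    //     assert_eq!(mask.true_count(), 2);
    //     // An all-set buffer collapses to the AllTrue variant:
    //     assert!(Mask::from_buffer(BitBuffer::from_iter([true; 4])).all_true());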

    /// Creates a new mask of length `len` from a sorted list of the indices that
    /// are true.
    ///
    /// Panics if the indices are unsorted or out of bounds.
    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
        let true_count = indices.len();
        assert!(indices.is_sorted(), "Mask indices must be sorted");
        assert!(
            indices.last().is_none_or(|&idx| idx < len),
            "Mask indices must be in bounds (len={len})"
        );

        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        let mut buf = BitBufferMut::new_unset(len);
        indices.iter().for_each(|&idx| buf.set(idx));
        debug_assert_eq!(buf.len(), len);

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: OnceLock::from(indices),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }
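
    // Illustrative usage (not from the original source):
    //
    //     let mask = Mask::from_indices(8, vec![1, 3, 4]);
    //     assert_eq!(mask.true_count(), 3);
    //     assert!(mask.value(3) && !mask.value(0));
    //     // Indices covering every position collapse to the AllTrue variant:
    //     assert!(Mask::from_indices(2, vec![0, 1]).all_true());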

    /// Creates a new mask of length `len` that is true everywhere except at the
    /// given indices. The excluded indices must be unique and within bounds.
    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
        let mut buf = BitBufferMut::new_set(len);

        let mut false_count: usize = 0;
        indices.into_iter().for_each(|idx| {
            buf.unset(idx);
            false_count += 1;
        });
        debug_assert_eq!(buf.len(), len);
        let true_count = len - false_count;

        if false_count == 0 {
            return Self::AllTrue(len);
        }
        if false_count == len {
            return Self::AllFalse(len);
        }

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: Default::default(),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }
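
    // Illustrative usage (not from the original source):
    //
    //     let mask = Mask::from_excluded_indices(4, [0, 2]);
    //     assert_eq!(mask.true_count(), 2);
    //     assert!(!mask.value(0) && mask.value(1));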

    /// Creates a new mask of length `len` from a sorted, non-overlapping list of
    /// `(start, end)` ranges of true values.
    ///
    /// Panics if the slices are unsorted, overlapping, empty, or out of bounds.
    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
        Self::check_slices(len, &vec);
        Self::from_slices_unchecked(len, vec)
    }

    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
        #[cfg(debug_assertions)]
        Self::check_slices(len, &slices);

        let true_count = slices.iter().map(|(b, e)| e - b).sum();
        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        let mut buf = BitBufferMut::new_unset(len);
        for (start, end) in slices.iter().copied() {
            (start..end).for_each(|idx| buf.set(idx));
        }
        debug_assert_eq!(buf.len(), len);

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: Default::default(),
            slices: OnceLock::from(slices),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    #[inline(always)]
    fn check_slices(len: usize, vec: &[(usize, usize)]) {
        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
        for (first, second) in vec.iter().tuple_windows() {
            assert!(
                first.0 < second.0,
                "Slices must be sorted, got {first:?} and {second:?}"
            );
            assert!(
                first.1 <= second.0,
                "Slices must be non-overlapping, got {first:?} and {second:?}"
            );
        }
    }
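
    // Illustrative usage (not from the original source):
    //
    //     let mask = Mask::from_slices(10, vec![(0, 3), (7, 10)]);
    //     assert_eq!(mask.true_count(), 6);
    //     assert!(mask.value(2) && !mask.value(5));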

    /// Creates a new mask of length `len` whose true positions are the
    /// intersection of two sorted index iterators.
    pub fn from_intersection_indices(
        len: usize,
        lhs: impl Iterator<Item = usize>,
        rhs: impl Iterator<Item = usize>,
    ) -> Self {
        let mut intersection = Vec::with_capacity(len);
        let mut lhs = lhs.peekable();
        let mut rhs = rhs.peekable();
        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
            match l.cmp(&r) {
                Ordering::Less => {
                    lhs.next();
                }
                Ordering::Greater => {
                    rhs.next();
                }
                Ordering::Equal => {
                    intersection.push(l);
                    lhs.next();
                    rhs.next();
                }
            }
        }
        Self::from_indices(len, intersection)
    }
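
    // Illustrative usage (not from the original source): only indices present in
    // both (sorted) inputs end up set.
    //
    //     let mask = Mask::from_intersection_indices(
    //         10,
    //         vec![1, 3, 5].into_iter(),
    //         vec![3, 4, 5].into_iter(),
    //     );
    //     assert_eq!(mask.true_count(), 2);
    //     assert!(mask.value(3) && mask.value(5));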

    /// Resets the mask to an empty (zero-length) all-false mask.
    pub fn clear(&mut self) {
        *self = Self::new_false(0);
    }

    /// Returns the length of the mask.
    #[inline]
    pub fn len(&self) -> usize {
        match self {
            Self::AllTrue(len) => *len,
            Self::AllFalse(len) => *len,
            Self::Values(values) => values.len(),
        }
    }

    /// Returns whether the mask has zero length.
    #[inline]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::AllTrue(len) => *len == 0,
            Self::AllFalse(len) => *len == 0,
            Self::Values(values) => values.is_empty(),
        }
    }

    /// Returns the number of true values in the mask.
    #[inline]
    pub fn true_count(&self) -> usize {
        match &self {
            Self::AllTrue(len) => *len,
            Self::AllFalse(_) => 0,
            Self::Values(values) => values.true_count,
        }
    }

    /// Returns the number of false values in the mask.
    #[inline]
    pub fn false_count(&self) -> usize {
        match &self {
            Self::AllTrue(_) => 0,
            Self::AllFalse(len) => *len,
            Self::Values(values) => values.buffer.len() - values.true_count,
        }
    }

    /// Returns whether every value in the mask is true (trivially true for an
    /// empty mask).
    #[inline]
    pub fn all_true(&self) -> bool {
        match &self {
            Self::AllTrue(_) => true,
            Self::AllFalse(0) => true,
            Self::AllFalse(_) => false,
            Self::Values(values) => values.buffer.len() == values.true_count,
        }
    }

    /// Returns whether every value in the mask is false (trivially true for an
    /// empty mask).
    #[inline]
    pub fn all_false(&self) -> bool {
        self.true_count() == 0
    }

    /// Returns the fraction of values that are true (1.0 for all-true, 0.0 for
    /// all-false).
    #[inline]
    pub fn density(&self) -> f64 {
        match &self {
            Self::AllTrue(_) => 1.0,
            Self::AllFalse(_) => 0.0,
            Self::Values(values) => values.density,
        }
    }

    /// Returns the value of the mask at position `idx`.
    ///
    /// Bounds are not checked for the all-true and all-false variants.
    #[inline]
    pub fn value(&self, idx: usize) -> bool {
        match self {
            Mask::AllTrue(_) => true,
            Mask::AllFalse(_) => false,
            Mask::Values(values) => values.buffer.value(idx),
        }
    }

    /// Returns the index of the first true value, if any.
    pub fn first(&self) -> Option<usize> {
        match &self {
            Self::AllTrue(len) => (*len > 0).then_some(0),
            Self::AllFalse(_) => None,
            Self::Values(values) => {
                // Prefer an already-materialized cache before scanning the buffer.
                if let Some(indices) = values.indices.get() {
                    return indices.first().copied();
                }
                if let Some(slices) = values.slices.get() {
                    return slices.first().map(|(start, _)| *start);
                }
                values.buffer.set_indices().next()
            }
        }
    }

    /// Returns the position of the `n`-th (zero-based) true value.
    ///
    /// Panics if `n` is greater than or equal to the true count.
    pub fn rank(&self, n: usize) -> usize {
        if n >= self.true_count() {
            vortex_panic!(
                "Rank {n} out of bounds for mask with true count {}",
                self.true_count()
            );
        }
        match &self {
            Self::AllTrue(_) => n,
            Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
            Self::Values(values) => values.indices()[n],
        }
    }
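
    // Illustrative usage (not from the original source):
    //
    //     let mask = Mask::from_indices(8, vec![2, 5, 6]);
    //     assert_eq!(mask.rank(0), 2); // position of the first true value
    //     assert_eq!(mask.rank(2), 6); // position of the third true value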

    /// Returns a new mask covering only the given range of positions.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        let start = match range.start_bound() {
            Bound::Included(&s) => s,
            Bound::Excluded(&s) => s + 1,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&e) => e + 1,
            Bound::Excluded(&e) => e,
            Bound::Unbounded => self.len(),
        };

        assert!(start <= end);
        assert!(start <= self.len());
        assert!(end <= self.len());
        let len = end - start;

        match &self {
            Self::AllTrue(_) => Self::new_true(len),
            Self::AllFalse(_) => Self::new_false(len),
            Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
        }
    }
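
    // Illustrative usage (not from the original source): the sliced mask is
    // re-indexed relative to the start of the range.
    //
    //     let mask = Mask::from_indices(8, vec![2, 5, 6]);
    //     let sliced = mask.slice(2..6);
    //     assert_eq!(sliced.len(), 4);
    //     assert_eq!(sliced.true_count(), 2);
    //     assert!(sliced.value(0) && sliced.value(3));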

    /// Returns a reference to the underlying bit buffer, or `AllOr::All` /
    /// `AllOr::None` when no buffer is materialized.
    #[inline]
    pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(&values.buffer),
        }
    }

    /// Returns a bit buffer for the mask, materializing one for the all-true and
    /// all-false variants.
    #[inline]
    pub fn to_bit_buffer(&self) -> BitBuffer {
        match self {
            Self::AllTrue(l) => BitBuffer::new_set(*l),
            Self::AllFalse(l) => BitBuffer::new_unset(*l),
            Self::Values(values) => values.bit_buffer().clone(),
        }
    }

    /// Consumes the mask and returns its bit buffer, avoiding a copy when the
    /// underlying values are not shared.
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        match self {
            Self::AllTrue(l) => BitBuffer::new_set(l),
            Self::AllFalse(l) => BitBuffer::new_unset(l),
            Self::Values(values) => Arc::try_unwrap(values)
                .map(|v| v.into_bit_buffer())
                .unwrap_or_else(|v| v.bit_buffer().clone()),
        }
    }

    /// Returns the sorted indices of the true values, or `AllOr::All` /
    /// `AllOr::None` for the all-true and all-false variants.
    #[inline]
    pub fn indices(&self) -> AllOr<&[usize]> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.indices()),
        }
    }

    /// Returns the true values as sorted, non-overlapping `(start, end)` ranges,
    /// or `AllOr::All` / `AllOr::None` for the all-true and all-false variants.
    #[inline]
    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.slices()),
        }
    }

    /// Returns either an index-based or a slice-based view of the true values,
    /// choosing slices when the mask's density is at least `threshold`.
    #[inline]
    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
        }
    }
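
    // Illustrative usage (not from the original source): callers typically branch
    // on the returned representation.
    //
    //     match mask.threshold_iter(0.8) {
    //         AllOr::All => { /* every position is true */ }
    //         AllOr::None => { /* every position is false */ }
    //         AllOr::Some(MaskIter::Indices(indices)) => { /* sparse path */ }
    //         AllOr::Some(MaskIter::Slices(slices)) => { /* dense path */ }
    //     }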

    /// Returns the underlying [`MaskValues`] if the mask holds mixed values.
    #[inline]
    pub fn values(&self) -> Option<&MaskValues> {
        if let Self::Values(values) = self {
            Some(values)
        } else {
            None
        }
    }

    /// For each position in `indices`, returns the number of true values strictly
    /// before that position.
    ///
    /// The indices must be sorted ascending and no greater than the mask length.
    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
        match self {
            Self::AllTrue(_) => indices.to_vec(),
            Self::AllFalse(_) => vec![0; indices.len()],
            Self::Values(values) => {
                let mut bool_iter = values.bit_buffer().iter();
                let mut valid_counts = Vec::with_capacity(indices.len());
                let mut valid_count = 0;
                let mut idx = 0;
                for &next_idx in indices {
                    while idx < next_idx {
                        idx += 1;
                        valid_count += bool_iter
                            .next()
                            .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
                            as usize;
                    }
                    valid_counts.push(valid_count);
                }

                valid_counts
            }
        }
    }
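
    // Illustrative usage (not from the original source): each output is the count
    // of true values strictly before the queried position.
    //
    //     let mask = Mask::from_indices(6, vec![0, 2, 3, 5]);
    //     assert_eq!(mask.valid_counts_for_indices(&[2, 5]), vec![1, 3]);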

    /// Returns a mask in which only the first `limit` true values are kept; any
    /// later true values are cleared.
    pub fn limit(self, limit: usize) -> Self {
        if self.len() <= limit {
            // A mask no longer than `limit` cannot contain more than `limit` true
            // values, so there is nothing to clear.
            return self;
        }

        match self {
            Mask::AllTrue(len) => {
                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
            }
            Mask::AllFalse(_) => self,
            Mask::Values(ref mask_values) => {
                if limit >= mask_values.true_count() {
                    return self;
                }

                let existing_buffer = mask_values.bit_buffer();

                let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
                debug_assert!(limit < mask_values.len());

                let ptr = new_buffer_builder.as_mut_ptr();
                for index in existing_buffer.set_indices().take(limit) {
                    // SAFETY: `index` is a set position of a buffer with the same
                    // length as the new buffer, so it is within bounds.
                    unsafe { set_bit_unchecked(ptr, index) }
                }

                Self::from(new_buffer_builder.freeze())
            }
        }
    }
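
    // Illustrative usage (not from the original source):
    //
    //     let mask = Mask::from_indices(8, vec![1, 3, 5, 7]);
    //     let limited = mask.limit(2); // keep only the first two true values
    //     assert_eq!(limited.true_count(), 2);
    //     assert!(limited.value(3) && !limited.value(5));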

    /// Concatenates a sequence of masks into a single mask whose length is the
    /// sum of the input lengths.
    pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
        let masks: Vec<_> = masks.collect();
        let len = masks.iter().map(|t| t.len()).sum();

        if masks.iter().all(|t| t.all_true()) {
            return Ok(Mask::AllTrue(len));
        }

        if masks.iter().all(|t| t.all_false()) {
            return Ok(Mask::AllFalse(len));
        }

        let mut builder = BitBufferMut::with_capacity(len);

        for mask in masks {
            match mask {
                Mask::AllTrue(n) => builder.append_n(true, *n),
                Mask::AllFalse(n) => builder.append_n(false, *n),
                Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
            }
        }

        Ok(Mask::from_buffer(builder.freeze()))
    }
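
    // Illustrative usage (not from the original source):
    //
    //     let a = Mask::new_true(2);
    //     let b = Mask::from_indices(3, vec![0]);
    //     let joined = Mask::concat([&a, &b].into_iter()).unwrap();
    //     assert_eq!(joined.len(), 5);
    //     assert_eq!(joined.true_count(), 3);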
}

impl MaskValues {
    /// Returns the length of the mask.
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns whether the mask has zero length.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Returns the fraction of values that are true.
    #[inline]
    pub fn density(&self) -> f64 {
        self.density
    }

    /// Returns the number of true values.
    #[inline]
    pub fn true_count(&self) -> usize {
        self.true_count
    }

    /// Returns a reference to the underlying bit buffer.
    #[inline]
    pub fn bit_buffer(&self) -> &BitBuffer {
        &self.buffer
    }

    /// Consumes the values and returns the underlying bit buffer.
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        self.buffer
    }

    /// Returns the value at the given index.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        self.buffer.value(index)
    }

    /// Returns the sorted indices of the true values, computing and caching them
    /// on first use.
    pub fn indices(&self) -> &[usize] {
        self.indices.get_or_init(|| {
            if self.true_count == 0 {
                return vec![];
            }

            if self.true_count == self.len() {
                return (0..self.len()).collect();
            }

            if let Some(slices) = self.slices.get() {
                let mut indices = Vec::with_capacity(self.true_count);
                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
                debug_assert!(indices.is_sorted());
                assert_eq!(indices.len(), self.true_count);
                return indices;
            }

            let mut indices = Vec::with_capacity(self.true_count);
            indices.extend(self.buffer.set_indices());
            debug_assert!(indices.is_sorted());
            assert_eq!(indices.len(), self.true_count);
            indices
        })
    }

    /// Returns the true values as sorted, non-overlapping `(start, end)` ranges,
    /// computing and caching them on first use.
    #[inline]
    pub fn slices(&self) -> &[(usize, usize)] {
        self.slices.get_or_init(|| {
            if self.true_count == self.len() {
                return vec![(0, self.len())];
            }

            self.buffer.set_slices().collect()
        })
    }

    /// Returns either an index-based or a slice-based view of the true values,
    /// choosing slices when the density is at least `threshold`.
    #[inline]
    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
        if self.density >= threshold {
            MaskIter::Slices(self.slices())
        } else {
            MaskIter::Indices(self.indices())
        }
    }

    pub(crate) fn into_buffer(self) -> BitBuffer {
        self.buffer
    }
}

/// A borrowed view of a mask's true values, as either sorted indices or sorted,
/// non-overlapping `(start, end)` ranges.
pub enum MaskIter<'a> {
    /// The positions of the true values, in ascending order.
    Indices(&'a [usize]),
    /// The true values as disjoint `(start, end)` ranges, in ascending order.
    Slices(&'a [(usize, usize)]),
}

impl From<BitBuffer> for Mask {
    fn from(value: BitBuffer) -> Self {
        Self::from_buffer(value)
    }
}

impl FromIterator<bool> for Mask {
    #[inline]
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        Self::from_buffer(BitBuffer::from_iter(iter))
    }
}

impl FromIterator<Mask> for Mask {
    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
        let masks = iter
            .into_iter()
            .filter(|m| !m.is_empty())
            .collect::<Vec<_>>();
        let total_length = masks.iter().map(|v| v.len()).sum();

        // Both checks hold vacuously for an empty input, producing AllTrue(0).
        if masks.iter().all(|v| v.all_true()) {
            return Self::AllTrue(total_length);
        }
        if masks.iter().all(|v| v.all_false()) {
            return Self::AllFalse(total_length);
        }

        let mut buffer = BitBufferMut::with_capacity(total_length);
        for mask in masks {
            match mask {
                Mask::AllTrue(count) => buffer.append_n(true, count),
                Mask::AllFalse(count) => buffer.append_n(false, count),
                Mask::Values(values) => {
                    buffer.append_buffer(values.bit_buffer());
                }
            };
        }
        Self::from_buffer(buffer.freeze())
    }
}