#![deny(missing_docs)]
//! Boolean masks over a fixed number of rows, with cheap all-true and
//! all-false representations and a bit-buffer-backed mixed representation.

mod bitops;
mod eq;
mod intersect_by_rank;
mod iter_bools;
mod mask_mut;

#[cfg(feature = "arrow")]
mod arrow;
#[cfg(test)]
mod tests;

use std::cmp::Ordering;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::Bound;
use std::ops::RangeBounds;
use std::sync::Arc;
use std::sync::OnceLock;

use itertools::Itertools;
pub use mask_mut::*;
use vortex_buffer::BitBuffer;
use vortex_buffer::BitBufferMut;
use vortex_buffer::set_bit_unchecked;
use vortex_error::VortexResult;
use vortex_error::vortex_panic;

/// The result of a [`Mask`] query: either every row is true, every row is
/// false, or a per-row answer of type `T`.
pub enum AllOr<T> {
    /// Every value is true.
    All,
    /// Every value is false.
    None,
    /// A mix of true and false values, described by `T`.
    Some(T),
}

impl<T> AllOr<T> {
    /// Returns the inner value for the mixed case, otherwise invokes `all_true`
    /// or `all_false` to produce a value for the uniform cases.
    #[inline]
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            Self::Some(v) => v,
            AllOr::All => all_true(),
            AllOr::None => all_false(),
        }
    }
}

impl<T> AllOr<&T> {
    /// Maps `AllOr<&T>` to `AllOr<T>` by cloning the contents.
    #[inline]
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            Self::All => AllOr::All,
            Self::None => AllOr::None,
            Self::Some(v) => AllOr::Some(v.clone()),
        }
    }
}

impl<T> Debug for AllOr<T>
where
    T: Debug,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::All => f.write_str("All"),
            Self::None => f.write_str("None"),
            Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
        }
    }
}

impl<T> PartialEq for AllOr<T>
where
    T: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::All, Self::All) => true,
            (Self::None, Self::None) => true,
            (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
            _ => false,
        }
    }
}

impl<T> Eq for AllOr<T> where T: Eq {}

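// Usage sketch added for illustration (not part of the original module):
// `AllOr` collapses the three mask shapes into one type, and `unwrap_or_else`
// supplies fallbacks for the uniform cases.
#[cfg(test)]
#[test]
fn allor_unwrap_or_else_sketch() {
    // Mixed values pass through untouched.
    assert_eq!(AllOr::Some(3usize).unwrap_or_else(|| 10, || 0), 3);
    // The uniform cases defer to the supplied closures.
    assert_eq!(AllOr::<usize>::All.unwrap_or_else(|| 10, || 0), 10);
    assert_eq!(AllOr::<usize>::None.unwrap_or_else(|| 10, || 0), 0);
}
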
/// A boolean mask over a fixed number of rows, with cheap representations for
/// the all-true and all-false cases.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// Every row is true.
    AllTrue(usize),
    /// Every row is false.
    AllFalse(usize),
    /// A mix of true and false rows, backed by a [`MaskValues`].
    Values(Arc<MaskValues>),
}

impl Default for Mask {
    fn default() -> Self {
        Self::new_true(0)
    }
}

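// Usage sketch added for illustration (not part of the original module): the
// constructors normalize degenerate inputs, so an all-set or all-unset buffer
// never produces the `Values` variant.
#[cfg(test)]
#[test]
fn mask_constructor_normalization_sketch() {
    assert!(matches!(
        Mask::from_buffer(BitBuffer::new_set(4)),
        Mask::AllTrue(4)
    ));
    assert!(matches!(
        Mask::from_buffer(BitBuffer::new_unset(4)),
        Mask::AllFalse(4)
    ));

    // A genuinely mixed buffer keeps per-row values.
    let mixed = Mask::from_iter([true, false, true, true]);
    assert!(matches!(mixed, Mask::Values(_)));
    assert_eq!(mixed.true_count(), 3);
    assert_eq!(mixed.false_count(), 1);
}
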
/// The per-row representation of a mixed [`Mask`], backed by a [`BitBuffer`]
/// with lazily computed index and slice views.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    buffer: BitBuffer,

    /// Lazily computed indices of the true values.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    /// Lazily computed maximal runs of true values, as half-open `(start, end)` ranges.
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,

    true_count: usize,
    /// The fraction of values that are true, i.e. `true_count / len`.
    density: f64,
}

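// Usage sketch added for illustration (not part of the original module): a mask
// built from sorted indices or from half-open slices exposes both views, with
// the missing one derived lazily by `MaskValues`.
#[cfg(test)]
#[test]
fn mask_indices_and_slices_sketch() {
    // Built from sorted indices: the index view is available immediately.
    let by_indices = Mask::from_indices(8, vec![1, 2, 5]);
    assert_eq!(by_indices.true_count(), 3);
    assert_eq!(by_indices.indices(), AllOr::Some(&[1usize, 2, 5][..]));

    // Built from half-open slices: the equivalent index view is derived lazily.
    let by_slices = Mask::from_slices(8, vec![(1, 3), (5, 6)]);
    assert_eq!(by_slices.slices(), AllOr::Some(&[(1usize, 3), (5, 6)][..]));
    assert_eq!(by_slices.indices(), AllOr::Some(&[1usize, 2, 5][..]));
}
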
impl Mask {
    /// Create a new [`Mask`] of the given length with every value set to `value`.
    pub fn new(length: usize, value: bool) -> Self {
        if value {
            Self::AllTrue(length)
        } else {
            Self::AllFalse(length)
        }
    }

    /// Create a new all-true [`Mask`] of the given length.
    #[inline]
    pub fn new_true(length: usize) -> Self {
        Self::AllTrue(length)
    }

    /// Create a new all-false [`Mask`] of the given length.
    #[inline]
    pub fn new_false(length: usize) -> Self {
        Self::AllFalse(length)
    }

    /// Create a new [`Mask`] from a [`BitBuffer`], collapsing to the all-true or
    /// all-false variant when possible.
    pub fn from_buffer(buffer: BitBuffer) -> Self {
        let len = buffer.len();
        let true_count = buffer.true_count();

        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        Self::Values(Arc::new(MaskValues {
            buffer,
            indices: Default::default(),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    /// Create a new [`Mask`] from a sorted list of indices at which the mask is
    /// true.
    ///
    /// Panics if the indices are unsorted or out of bounds for `len`.
    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
        let true_count = indices.len();
        assert!(indices.is_sorted(), "Mask indices must be sorted");
        assert!(
            indices.last().is_none_or(|&idx| idx < len),
            "Mask indices must be in bounds (len={len})"
        );

        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        let mut buf = BitBufferMut::new_unset(len);
        indices.iter().for_each(|&idx| buf.set(idx));
        debug_assert_eq!(buf.len(), len);

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: OnceLock::from(indices),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    /// Create a new [`Mask`] from an iterator of indices at which the mask is
    /// false; every other position is true.
    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
        let mut buf = BitBufferMut::new_set(len);

        let mut false_count: usize = 0;
        indices.into_iter().for_each(|idx| {
            buf.unset(idx);
            false_count += 1;
        });
        debug_assert_eq!(buf.len(), len);
        let true_count = len - false_count;

        if false_count == 0 {
            return Self::AllTrue(len);
        }
        if false_count == len {
            return Self::AllFalse(len);
        }

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: Default::default(),
            slices: Default::default(),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    /// Create a new [`Mask`] from a sorted, non-overlapping list of half-open
    /// `(start, end)` slices at which the mask is true.
    ///
    /// Panics if any slice is empty or out of bounds, or if the slices are
    /// unsorted or overlapping.
    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
        Self::check_slices(len, &vec);
        Self::from_slices_unchecked(len, vec)
    }

    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
        #[cfg(debug_assertions)]
        Self::check_slices(len, &slices);

        let true_count = slices.iter().map(|(b, e)| e - b).sum();
        if true_count == 0 {
            return Self::AllFalse(len);
        }
        if true_count == len {
            return Self::AllTrue(len);
        }

        let mut buf = BitBufferMut::new_unset(len);
        for (start, end) in slices.iter().copied() {
            (start..end).for_each(|idx| buf.set(idx));
        }
        debug_assert_eq!(buf.len(), len);

        Self::Values(Arc::new(MaskValues {
            buffer: buf.freeze(),
            indices: Default::default(),
            slices: OnceLock::from(slices),
            true_count,
            density: true_count as f64 / len as f64,
        }))
    }

    #[inline(always)]
    fn check_slices(len: usize, vec: &[(usize, usize)]) {
        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
        for (first, second) in vec.iter().tuple_windows() {
            assert!(
                first.0 < second.0,
                "Slices must be sorted, got {first:?} and {second:?}"
            );
            assert!(
                first.1 <= second.0,
                "Slices must be non-overlapping, got {first:?} and {second:?}"
            );
        }
    }

    /// Create a new [`Mask`] containing the intersection of two sorted index
    /// iterators, i.e. the indices present in both `lhs` and `rhs`.
    pub fn from_intersection_indices(
        len: usize,
        lhs: impl Iterator<Item = usize>,
        rhs: impl Iterator<Item = usize>,
    ) -> Self {
        let mut intersection = Vec::with_capacity(len);
        let mut lhs = lhs.peekable();
        let mut rhs = rhs.peekable();
        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
            match l.cmp(&r) {
                Ordering::Less => {
                    lhs.next();
                }
                Ordering::Greater => {
                    rhs.next();
                }
                Ordering::Equal => {
                    intersection.push(l);
                    lhs.next();
                    rhs.next();
                }
            }
        }
        Self::from_indices(len, intersection)
    }

    /// Reset the mask to an empty (zero-length, all-false) mask.
    pub fn clear(&mut self) {
        *self = Self::new_false(0);
    }

    /// Returns the length of the mask in rows.
    #[inline]
    pub fn len(&self) -> usize {
        match self {
            Self::AllTrue(len) => *len,
            Self::AllFalse(len) => *len,
            Self::Values(values) => values.len(),
        }
    }

    /// Returns true if the mask has zero length.
    #[inline]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::AllTrue(len) => *len == 0,
            Self::AllFalse(len) => *len == 0,
            Self::Values(values) => values.is_empty(),
        }
    }

    /// Returns the number of true values in the mask.
    #[inline]
    pub fn true_count(&self) -> usize {
        match &self {
            Self::AllTrue(len) => *len,
            Self::AllFalse(_) => 0,
            Self::Values(values) => values.true_count,
        }
    }

    /// Returns the number of false values in the mask.
    #[inline]
    pub fn false_count(&self) -> usize {
        match &self {
            Self::AllTrue(_) => 0,
            Self::AllFalse(len) => *len,
            Self::Values(values) => values.buffer.len() - values.true_count,
        }
    }

    /// Returns true if every value in the mask is true (vacuously true for an
    /// empty mask).
    #[inline]
    pub fn all_true(&self) -> bool {
        match &self {
            Self::AllTrue(_) => true,
            Self::AllFalse(0) => true,
            Self::AllFalse(_) => false,
            Self::Values(values) => values.buffer.len() == values.true_count,
        }
    }

    /// Returns true if every value in the mask is false (vacuously true for an
    /// empty mask).
    #[inline]
    pub fn all_false(&self) -> bool {
        self.true_count() == 0
    }

    /// Returns the fraction of values that are true.
    #[inline]
    pub fn density(&self) -> f64 {
        match &self {
            Self::AllTrue(_) => 1.0,
            Self::AllFalse(_) => 0.0,
            Self::Values(values) => values.density,
        }
    }

    /// Returns the value of the mask at the given index.
    ///
    /// Note: the all-true and all-false variants do not bounds-check `idx`.
    #[inline]
    pub fn value(&self, idx: usize) -> bool {
        match self {
            Mask::AllTrue(_) => true,
            Mask::AllFalse(_) => false,
            Mask::Values(values) => values.buffer.value(idx),
        }
    }

    /// Returns the index of the first true value, if any.
    pub fn first(&self) -> Option<usize> {
        match &self {
            Self::AllTrue(len) => (*len > 0).then_some(0),
            Self::AllFalse(_) => None,
            Self::Values(values) => {
                if let Some(indices) = values.indices.get() {
                    return indices.first().copied();
                }
                if let Some(slices) = values.slices.get() {
                    return slices.first().map(|(start, _)| *start);
                }
                values.buffer.set_indices().next()
            }
        }
    }

    /// Returns the index of the `n`-th (zero-based) true value in the mask.
    ///
    /// Panics if `n` is greater than or equal to the true count. See the usage
    /// sketch after this `impl` block.
    pub fn rank(&self, n: usize) -> usize {
        if n >= self.true_count() {
            vortex_panic!(
                "Rank {n} out of bounds for mask with true count {}",
                self.true_count()
            );
        }
        match &self {
            Self::AllTrue(_) => n,
            Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
            Self::Values(values) => values.indices()[n],
        }
    }

    /// Returns a new [`Mask`] restricted to the given row range.
    ///
    /// Panics if the range is out of bounds.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        let start = match range.start_bound() {
            Bound::Included(&s) => s,
            Bound::Excluded(&s) => s + 1,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&e) => e + 1,
            Bound::Excluded(&e) => e,
            Bound::Unbounded => self.len(),
        };

        assert!(start <= end);
        assert!(start <= self.len());
        assert!(end <= self.len());
        let len = end - start;

        match &self {
            Self::AllTrue(_) => Self::new_true(len),
            Self::AllFalse(_) => Self::new_false(len),
            Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
        }
    }

    /// Returns the mask as a [`BitBuffer`], or a uniform answer for the
    /// all-true and all-false variants.
    #[inline]
    pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(&values.buffer),
        }
    }

    /// Returns a [`BitBuffer`] with one bit per row, materializing the all-true
    /// and all-false variants.
    #[inline]
    pub fn to_bit_buffer(&self) -> BitBuffer {
        match self {
            Self::AllTrue(l) => BitBuffer::new_set(*l),
            Self::AllFalse(l) => BitBuffer::new_unset(*l),
            Self::Values(values) => values.bit_buffer().clone(),
        }
    }

    /// Converts the mask into a [`BitBuffer`], avoiding a copy when this is the
    /// only reference to the underlying values.
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        match self {
            Self::AllTrue(l) => BitBuffer::new_set(l),
            Self::AllFalse(l) => BitBuffer::new_unset(l),
            Self::Values(values) => Arc::try_unwrap(values)
                .map(|v| v.into_bit_buffer())
                .unwrap_or_else(|v| v.bit_buffer().clone()),
        }
    }

    /// Returns the indices of the true values, or a uniform answer for the
    /// all-true and all-false variants.
    #[inline]
    pub fn indices(&self) -> AllOr<&[usize]> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.indices()),
        }
    }

    /// Returns the true values as sorted, non-overlapping half-open
    /// `(start, end)` slices, or a uniform answer for the all-true and
    /// all-false variants.
    #[inline]
    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.slices()),
        }
    }

    /// Returns an iterator view over the true values, choosing between indices
    /// and slices based on the mask's density relative to `threshold`.
    #[inline]
    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
        match &self {
            Self::AllTrue(_) => AllOr::All,
            Self::AllFalse(_) => AllOr::None,
            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
        }
    }

    /// Returns the underlying [`MaskValues`] if the mask is the mixed variant.
    #[inline]
    pub fn values(&self) -> Option<&MaskValues> {
        if let Self::Values(values) = self {
            Some(values)
        } else {
            None
        }
    }

    /// For each index in `indices`, returns the number of true values strictly
    /// before that index. Assumes `indices` is sorted in ascending order. See
    /// the usage sketch after this `impl` block.
    ///
    /// Panics if a row index exceeds the mask length.
    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
        match self {
            Self::AllTrue(_) => indices.to_vec(),
            Self::AllFalse(_) => vec![0; indices.len()],
            Self::Values(values) => {
                let mut bool_iter = values.bit_buffer().iter();
                let mut valid_counts = Vec::with_capacity(indices.len());
                let mut valid_count = 0;
                let mut idx = 0;
                for &next_idx in indices {
                    while idx < next_idx {
                        idx += 1;
                        valid_count += bool_iter
                            .next()
                            .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
                            as usize;
                    }
                    valid_counts.push(valid_count);
                }

                valid_counts
            }
        }
    }

    /// Limits the mask to at most `limit` true values, unsetting any true
    /// values after the first `limit`. The mask length is unchanged. See the
    /// usage sketch after this `impl` block.
    pub fn limit(self, limit: usize) -> Self {
        // A mask no longer than `limit` cannot contain more than `limit` true values.
        if self.len() <= limit {
            return self;
        }

        match self {
            Mask::AllTrue(len) => {
                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
            }
            Mask::AllFalse(_) => self,
            Mask::Values(ref mask_values) => {
                if limit >= mask_values.true_count() {
                    return self;
                }

                let existing_buffer = mask_values.bit_buffer();

                let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
                debug_assert!(limit < mask_values.len());

                let ptr = new_buffer_builder.as_mut_ptr();
                for index in existing_buffer.set_indices().take(limit) {
                    // SAFETY: `index` comes from the set bits of a buffer with the
                    // same length as `new_buffer_builder`, so it is in bounds.
                    unsafe { set_bit_unchecked(ptr, index) }
                }

                Self::from(new_buffer_builder.freeze())
            }
        }
    }

    /// Concatenates a sequence of masks into a single mask whose length is the
    /// sum of the input lengths.
    pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
        let masks: Vec<_> = masks.collect();
        let len = masks.iter().map(|t| t.len()).sum();

        if masks.iter().all(|t| t.all_true()) {
            return Ok(Mask::AllTrue(len));
        }

        if masks.iter().all(|t| t.all_false()) {
            return Ok(Mask::AllFalse(len));
        }

        let mut builder = BitBufferMut::with_capacity(len);

        for mask in masks {
            match mask {
                Mask::AllTrue(n) => builder.append_n(true, *n),
                Mask::AllFalse(n) => builder.append_n(false, *n),
                Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
            }
        }

        Ok(Mask::from_buffer(builder.freeze()))
    }
}
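
// Usage sketch added for illustration (not part of the original module):
// `rank` maps the n-th set position back to a row index, and `slice` keeps the
// normalized representation, collapsing when no set rows remain. It assumes
// `BitBuffer::slice` and `set_indices` behave as this module already relies on.
#[cfg(test)]
#[test]
fn mask_rank_and_slice_sketch() {
    let mask = Mask::from_indices(10, vec![2, 5, 9]);
    assert_eq!(mask.rank(0), 2);
    assert_eq!(mask.rank(2), 9);

    // Rows 4..10 keep set bits at relative positions 1 and 5.
    let tail = mask.slice(4..10);
    assert_eq!(tail.len(), 6);
    assert_eq!(tail.indices(), AllOr::Some(&[1usize, 5][..]));

    // A slice covering only unset rows collapses back to `AllFalse`.
    assert!(matches!(mask.slice(3..5), Mask::AllFalse(2)));
}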
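// Usage sketch added for illustration (not part of the original module):
// `valid_counts_for_indices` counts the true values strictly before each
// (sorted) row index, and `limit` keeps only the first `limit` true rows
// without changing the mask length.
#[cfg(test)]
#[test]
fn mask_limit_and_valid_counts_sketch() {
    let mask = Mask::from_indices(6, vec![0, 2, 3, 5]);

    // True values strictly before rows 0, 3, and 6.
    assert_eq!(mask.valid_counts_for_indices(&[0, 3, 6]), vec![0, 2, 4]);

    // Keep only the first two true rows; the length stays the same.
    let limited = mask.limit(2);
    assert_eq!(limited.len(), 6);
    assert_eq!(limited.indices(), AllOr::Some(&[0usize, 2][..]));
}
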
impl MaskValues {
    /// Returns the length of the mask in rows.
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns true if the mask has zero length.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Returns the fraction of values that are true.
    #[inline]
    pub fn density(&self) -> f64 {
        self.density
    }

    /// Returns the number of true values.
    #[inline]
    pub fn true_count(&self) -> usize {
        self.true_count
    }

    /// Returns the underlying [`BitBuffer`].
    #[inline]
    pub fn bit_buffer(&self) -> &BitBuffer {
        &self.buffer
    }

    /// Converts into the underlying [`BitBuffer`].
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        self.buffer
    }

    /// Returns the value at the given index.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        self.buffer.value(index)
    }

    /// Returns the sorted indices of the true values, computing and caching
    /// them on first use.
    pub fn indices(&self) -> &[usize] {
        self.indices.get_or_init(|| {
            if self.true_count == 0 {
                return vec![];
            }

            if self.true_count == self.len() {
                return (0..self.len()).collect();
            }

            if let Some(slices) = self.slices.get() {
                let mut indices = Vec::with_capacity(self.true_count);
                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
                debug_assert!(indices.is_sorted());
                assert_eq!(indices.len(), self.true_count);
                return indices;
            }

            let mut indices = Vec::with_capacity(self.true_count);
            indices.extend(self.buffer.set_indices());
            debug_assert!(indices.is_sorted());
            assert_eq!(indices.len(), self.true_count);
            indices
        })
    }

    /// Returns the true values as sorted half-open `(start, end)` slices,
    /// computing and caching them on first use.
    #[inline]
    pub fn slices(&self) -> &[(usize, usize)] {
        self.slices.get_or_init(|| {
            if self.true_count == self.len() {
                return vec![(0, self.len())];
            }

            self.buffer.set_slices().collect()
        })
    }

    /// Returns an iterator view over the true values: slices if the density is
    /// at least `threshold`, otherwise indices. See the sketch after
    /// [`MaskIter`].
    #[inline]
    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
        if self.density >= threshold {
            MaskIter::Slices(self.slices())
        } else {
            MaskIter::Indices(self.indices())
        }
    }

    pub(crate) fn into_buffer(self) -> BitBuffer {
        self.buffer
    }
}

/// A view over the true values of a [`Mask`], as either indices or slices.
pub enum MaskIter<'a> {
    /// The sorted indices of the true values.
    Indices(&'a [usize]),
    /// The true values as sorted half-open `(start, end)` slices.
    Slices(&'a [(usize, usize)]),
}

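// Usage sketch added for illustration (not part of the original module):
// `threshold_iter` hands back the slice view at or above the density threshold
// and the index view below it.
#[cfg(test)]
#[test]
fn mask_threshold_iter_sketch() {
    // Two of eight rows are set, so the density is 0.25.
    let mask = Mask::from_slices(8, vec![(3, 5)]);

    match mask.threshold_iter(0.5) {
        AllOr::Some(MaskIter::Indices(indices)) => assert_eq!(indices, &[3, 4]),
        _ => panic!("expected the sparse (index) view below the threshold"),
    }
    match mask.threshold_iter(0.1) {
        AllOr::Some(MaskIter::Slices(slices)) => assert_eq!(slices, &[(3, 5)]),
        _ => panic!("expected the dense (slice) view at or above the threshold"),
    }
}
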
impl From<BitBuffer> for Mask {
    fn from(value: BitBuffer) -> Self {
        Self::from_buffer(value)
    }
}

impl FromIterator<bool> for Mask {
    #[inline]
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        Self::from_buffer(BitBuffer::from_iter(iter))
    }
}

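// Usage sketch added for illustration (not part of the original module):
// collecting booleans builds a mask with the same normalization as
// `Mask::from_buffer`.
#[cfg(test)]
#[test]
fn mask_collect_bools_sketch() {
    let mask: Mask = (0..5).map(|i| i % 2 == 0).collect();
    assert_eq!(mask.len(), 5);
    assert_eq!(mask.true_count(), 3);

    let all_true: Mask = std::iter::repeat(true).take(4).collect();
    assert!(matches!(all_true, Mask::AllTrue(4)));
}
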
impl FromIterator<Mask> for Mask {
    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
        let masks = iter
            .into_iter()
            .filter(|m| !m.is_empty())
            .collect::<Vec<_>>();
        let total_length = masks.iter().map(|v| v.len()).sum();

        if masks.iter().all(|v| v.all_true()) {
            return Self::AllTrue(total_length);
        }
        if masks.iter().all(|v| v.all_false()) {
            return Self::AllFalse(total_length);
        }

        let mut buffer = BitBufferMut::with_capacity(total_length);
        for mask in masks {
            match mask {
                Mask::AllTrue(count) => buffer.append_n(true, count),
                Mask::AllFalse(count) => buffer.append_n(false, count),
                Mask::Values(values) => {
                    buffer.append_buffer(values.bit_buffer());
                }
            };
        }
        Self::from_buffer(buffer.freeze())
    }
}
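
// Usage sketch added for illustration (not part of the original module):
// collecting an iterator of masks concatenates them, mixing the uniform and
// per-row representations. It assumes `BitBufferMut::append_n` and
// `append_buffer` append bits in order, as the impls above already rely on.
#[cfg(test)]
#[test]
fn mask_concat_by_collect_sketch() {
    let combined: Mask = [
        Mask::new_true(2),
        Mask::from_indices(3, vec![1]),
        Mask::new_false(2),
    ]
    .into_iter()
    .collect();

    assert_eq!(combined.len(), 7);
    assert_eq!(combined.indices(), AllOr::Some(&[0usize, 1, 3][..]));
}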