vortex_mask/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! A mask is a set of sorted, unique, non-negative integers (positions).
5#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10mod iter_bools;
11mod mask_mut;
12
13#[cfg(feature = "arrow")]
14mod arrow;
15#[cfg(test)]
16mod tests;
17
18use std::cmp::Ordering;
19use std::fmt::Debug;
20use std::fmt::Formatter;
21use std::ops::Bound;
22use std::ops::RangeBounds;
23use std::sync::Arc;
24use std::sync::OnceLock;
25
26use itertools::Itertools;
27pub use mask_mut::*;
28use vortex_buffer::BitBuffer;
29use vortex_buffer::BitBufferMut;
30use vortex_buffer::set_bit_unchecked;
31use vortex_error::VortexResult;
32use vortex_error::vortex_panic;
33
/// A three-state selection: everything is included, nothing is included, or a partial
/// selection described by a payload of type `T`.
pub enum AllOr<T> {
    /// Every value is included.
    All,
    /// No value is included.
    None,
    /// Only some values are included; the payload describes which.
    Some(T),
}
43
44impl<T> AllOr<T> {
45    /// Returns the `Some` variant of the enum, or a default value.
46    #[inline]
47    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
48    where
49        F: FnOnce() -> T,
50        G: FnOnce() -> T,
51    {
52        match self {
53            Self::Some(v) => v,
54            AllOr::All => all_true(),
55            AllOr::None => all_false(),
56        }
57    }
58}
59
60impl<T> AllOr<&T> {
61    /// Clone the inner value.
62    #[inline]
63    pub fn cloned(self) -> AllOr<T>
64    where
65        T: Clone,
66    {
67        match self {
68            Self::All => AllOr::All,
69            Self::None => AllOr::None,
70            Self::Some(v) => AllOr::Some(v.clone()),
71        }
72    }
73}
74
75impl<T> Debug for AllOr<T>
76where
77    T: Debug,
78{
79    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
80        match self {
81            Self::All => f.write_str("All"),
82            Self::None => f.write_str("None"),
83            Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
84        }
85    }
86}
87
88impl<T> PartialEq for AllOr<T>
89where
90    T: PartialEq,
91{
92    fn eq(&self, other: &Self) -> bool {
93        match (self, other) {
94            (Self::All, Self::All) => true,
95            (Self::None, Self::None) => true,
96            (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
97            _ => false,
98        }
99    }
100}
101
102impl<T> Eq for AllOr<T> where T: Eq {}
103
104/// Represents a set of sorted unique positive integers.
105///
106/// A [`Mask`] can be constructed from various representations, and converted to various
107/// others. Internally, these are cached.
108#[derive(Debug, Clone)]
109#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
110pub enum Mask {
111    /// All values are included.
112    AllTrue(usize),
113    /// No values are included.
114    AllFalse(usize),
115    /// Some values are included, represented as a [`BitBuffer`].
116    Values(Arc<MaskValues>),
117}
118
119impl Default for Mask {
120    fn default() -> Self {
121        Self::new_true(0)
122    }
123}
124
125/// Represents the values of a [`Mask`] that contains some true and some false elements.
126#[derive(Debug)]
127#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
128pub struct MaskValues {
129    buffer: BitBuffer,
130
131    // We cached the indices and slices representations, since it can be faster than iterating
132    // the bit-mask over and over again.
133    #[cfg_attr(feature = "serde", serde(skip))]
134    indices: OnceLock<Vec<usize>>,
135    #[cfg_attr(feature = "serde", serde(skip))]
136    slices: OnceLock<Vec<(usize, usize)>>,
137
138    // Pre-computed values.
139    true_count: usize,
140    // i.e., the fraction of values that are true
141    density: f64,
142}
143
144impl Mask {
145    /// Create a new Mask with the given length.
146    pub fn new(length: usize, value: bool) -> Self {
147        if value {
148            Self::AllTrue(length)
149        } else {
150            Self::AllFalse(length)
151        }
152    }
153
154    /// Create a new Mask where all values are set.
155    #[inline]
156    pub fn new_true(length: usize) -> Self {
157        Self::AllTrue(length)
158    }
159
160    /// Create a new Mask where no values are set.
161    #[inline]
162    pub fn new_false(length: usize) -> Self {
163        Self::AllFalse(length)
164    }
165
166    /// Create a new [`Mask`] from a [`BitBuffer`].
167    pub fn from_buffer(buffer: BitBuffer) -> Self {
168        let len = buffer.len();
169        let true_count = buffer.true_count();
170
171        if true_count == 0 {
172            return Self::AllFalse(len);
173        }
174        if true_count == len {
175            return Self::AllTrue(len);
176        }
177
178        Self::Values(Arc::new(MaskValues {
179            buffer,
180            indices: Default::default(),
181            slices: Default::default(),
182            true_count,
183            density: true_count as f64 / len as f64,
184        }))
185    }
186
187    /// Create a new [`Mask`] from a [`Vec<usize>`].
188    // TODO(ngates): this should take an IntoIterator<usize>.
189    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
190        let true_count = indices.len();
191        assert!(indices.is_sorted(), "Mask indices must be sorted");
192        assert!(
193            indices.last().is_none_or(|&idx| idx < len),
194            "Mask indices must be in bounds (len={len})"
195        );
196
197        if true_count == 0 {
198            return Self::AllFalse(len);
199        }
200        if true_count == len {
201            return Self::AllTrue(len);
202        }
203
204        let mut buf = BitBufferMut::new_unset(len);
205        // TODO(ngates): for dense indices, we can do better by collecting into u64s.
206        indices.iter().for_each(|&idx| buf.set(idx));
207        debug_assert_eq!(buf.len(), len);
208
209        Self::Values(Arc::new(MaskValues {
210            buffer: buf.freeze(),
211            indices: OnceLock::from(indices),
212            slices: Default::default(),
213            true_count,
214            density: true_count as f64 / len as f64,
215        }))
216    }
217
218    /// Create a new [`Mask`] from an [`IntoIterator<Item = usize>`] of indices to be excluded.
219    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
220        let mut buf = BitBufferMut::new_set(len);
221
222        let mut false_count: usize = 0;
223        indices.into_iter().for_each(|idx| {
224            buf.unset(idx);
225            false_count += 1;
226        });
227        debug_assert_eq!(buf.len(), len);
228        let true_count = len - false_count;
229
230        // Return optimized variants when appropriate
231        if false_count == 0 {
232            return Self::AllTrue(len);
233        }
234        if false_count == len {
235            return Self::AllFalse(len);
236        }
237
238        Self::Values(Arc::new(MaskValues {
239            buffer: buf.freeze(),
240            indices: Default::default(),
241            slices: Default::default(),
242            true_count,
243            density: true_count as f64 / len as f64,
244        }))
245    }
246
247    /// Create a new [`Mask`] from a [`Vec<(usize, usize)>`] where each range
248    /// represents a contiguous range of true values.
249    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
250        Self::check_slices(len, &vec);
251        Self::from_slices_unchecked(len, vec)
252    }
253
254    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
255        #[cfg(debug_assertions)]
256        Self::check_slices(len, &slices);
257
258        let true_count = slices.iter().map(|(b, e)| e - b).sum();
259        if true_count == 0 {
260            return Self::AllFalse(len);
261        }
262        if true_count == len {
263            return Self::AllTrue(len);
264        }
265
266        let mut buf = BitBufferMut::new_unset(len);
267        for (start, end) in slices.iter().copied() {
268            (start..end).for_each(|idx| buf.set(idx));
269        }
270        debug_assert_eq!(buf.len(), len);
271
272        Self::Values(Arc::new(MaskValues {
273            buffer: buf.freeze(),
274            indices: Default::default(),
275            slices: OnceLock::from(slices),
276            true_count,
277            density: true_count as f64 / len as f64,
278        }))
279    }
280
281    #[inline(always)]
282    fn check_slices(len: usize, vec: &[(usize, usize)]) {
283        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
284        for (first, second) in vec.iter().tuple_windows() {
285            assert!(
286                first.0 < second.0,
287                "Slices must be sorted, got {first:?} and {second:?}"
288            );
289            assert!(
290                first.1 <= second.0,
291                "Slices must be non-overlapping, got {first:?} and {second:?}"
292            );
293        }
294    }
295
296    /// Create a new [`Mask`] from the intersection of two indices slices.
297    pub fn from_intersection_indices(
298        len: usize,
299        lhs: impl Iterator<Item = usize>,
300        rhs: impl Iterator<Item = usize>,
301    ) -> Self {
302        let mut intersection = Vec::with_capacity(len);
303        let mut lhs = lhs.peekable();
304        let mut rhs = rhs.peekable();
305        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
306            match l.cmp(&r) {
307                Ordering::Less => {
308                    lhs.next();
309                }
310                Ordering::Greater => {
311                    rhs.next();
312                }
313                Ordering::Equal => {
314                    intersection.push(l);
315                    lhs.next();
316                    rhs.next();
317                }
318            }
319        }
320        Self::from_indices(len, intersection)
321    }
322
323    /// Clears the mask of all data. Drops any allocated capacity.
324    pub fn clear(&mut self) {
325        *self = Self::new_false(0);
326    }
327
328    /// Returns the length of the mask (not the number of true values).
329    #[inline]
330    pub fn len(&self) -> usize {
331        match self {
332            Self::AllTrue(len) => *len,
333            Self::AllFalse(len) => *len,
334            Self::Values(values) => values.len(),
335        }
336    }
337
338    /// Returns true if the mask is empty i.e., it's length is 0.
339    #[inline]
340    pub fn is_empty(&self) -> bool {
341        match self {
342            Self::AllTrue(len) => *len == 0,
343            Self::AllFalse(len) => *len == 0,
344            Self::Values(values) => values.is_empty(),
345        }
346    }
347
348    /// Get the true count of the mask.
349    #[inline]
350    pub fn true_count(&self) -> usize {
351        match &self {
352            Self::AllTrue(len) => *len,
353            Self::AllFalse(_) => 0,
354            Self::Values(values) => values.true_count,
355        }
356    }
357
358    /// Get the false count of the mask.
359    #[inline]
360    pub fn false_count(&self) -> usize {
361        match &self {
362            Self::AllTrue(_) => 0,
363            Self::AllFalse(len) => *len,
364            Self::Values(values) => values.buffer.len() - values.true_count,
365        }
366    }
367
368    /// Returns true if all values in the mask are true.
369    #[inline]
370    pub fn all_true(&self) -> bool {
371        match &self {
372            Self::AllTrue(_) => true,
373            Self::AllFalse(0) => true,
374            Self::AllFalse(_) => false,
375            Self::Values(values) => values.buffer.len() == values.true_count,
376        }
377    }
378
379    /// Returns true if all values in the mask are false.
380    #[inline]
381    pub fn all_false(&self) -> bool {
382        self.true_count() == 0
383    }
384
385    /// Return the density of the full mask.
386    #[inline]
387    pub fn density(&self) -> f64 {
388        match &self {
389            Self::AllTrue(_) => 1.0,
390            Self::AllFalse(_) => 0.0,
391            Self::Values(values) => values.density,
392        }
393    }
394
395    /// Returns the boolean value at a given index.
396    ///
397    /// ## Panics
398    ///
399    /// Panics if the index is out of bounds.
400    #[inline]
401    pub fn value(&self, idx: usize) -> bool {
402        match self {
403            Mask::AllTrue(_) => true,
404            Mask::AllFalse(_) => false,
405            Mask::Values(values) => values.buffer.value(idx),
406        }
407    }
408
409    /// Returns the first true index in the mask.
410    pub fn first(&self) -> Option<usize> {
411        match &self {
412            Self::AllTrue(len) => (*len > 0).then_some(0),
413            Self::AllFalse(_) => None,
414            Self::Values(values) => {
415                if let Some(indices) = values.indices.get() {
416                    return indices.first().copied();
417                }
418                if let Some(slices) = values.slices.get() {
419                    return slices.first().map(|(start, _)| *start);
420                }
421                values.buffer.set_indices().next()
422            }
423        }
424    }
425
426    /// Returns the position in the mask of the nth true value.
427    pub fn rank(&self, n: usize) -> usize {
428        if n >= self.true_count() {
429            vortex_panic!(
430                "Rank {n} out of bounds for mask with true count {}",
431                self.true_count()
432            );
433        }
434        match &self {
435            Self::AllTrue(_) => n,
436            Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
437            Self::Values(values) => values.indices()[n],
438        }
439    }
440
441    /// Slice the mask.
442    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
443        let start = match range.start_bound() {
444            Bound::Included(&s) => s,
445            Bound::Excluded(&s) => s + 1,
446            Bound::Unbounded => 0,
447        };
448        let end = match range.end_bound() {
449            Bound::Included(&e) => e + 1,
450            Bound::Excluded(&e) => e,
451            Bound::Unbounded => self.len(),
452        };
453
454        assert!(start <= end);
455        assert!(start <= self.len());
456        assert!(end <= self.len());
457        let len = end - start;
458
459        match &self {
460            Self::AllTrue(_) => Self::new_true(len),
461            Self::AllFalse(_) => Self::new_false(len),
462            Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
463        }
464    }
465
466    /// Return the boolean buffer representation of the mask.
467    #[inline]
468    pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
469        match &self {
470            Self::AllTrue(_) => AllOr::All,
471            Self::AllFalse(_) => AllOr::None,
472            Self::Values(values) => AllOr::Some(&values.buffer),
473        }
474    }
475
476    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
477    /// and all-false variants.
478    #[inline]
479    pub fn to_bit_buffer(&self) -> BitBuffer {
480        match self {
481            Self::AllTrue(l) => BitBuffer::new_set(*l),
482            Self::AllFalse(l) => BitBuffer::new_unset(*l),
483            Self::Values(values) => values.bit_buffer().clone(),
484        }
485    }
486
487    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
488    /// and all-false variants.
489    #[inline]
490    pub fn into_bit_buffer(self) -> BitBuffer {
491        match self {
492            Self::AllTrue(l) => BitBuffer::new_set(l),
493            Self::AllFalse(l) => BitBuffer::new_unset(l),
494            Self::Values(values) => Arc::try_unwrap(values)
495                .map(|v| v.into_bit_buffer())
496                .unwrap_or_else(|v| v.bit_buffer().clone()),
497        }
498    }
499
500    /// Return the indices representation of the mask.
501    #[inline]
502    pub fn indices(&self) -> AllOr<&[usize]> {
503        match &self {
504            Self::AllTrue(_) => AllOr::All,
505            Self::AllFalse(_) => AllOr::None,
506            Self::Values(values) => AllOr::Some(values.indices()),
507        }
508    }
509
510    /// Return the slices representation of the mask.
511    #[inline]
512    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
513        match &self {
514            Self::AllTrue(_) => AllOr::All,
515            Self::AllFalse(_) => AllOr::None,
516            Self::Values(values) => AllOr::Some(values.slices()),
517        }
518    }
519
520    /// Return an iterator over either indices or slices of the mask based on a density threshold.
521    #[inline]
522    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
523        match &self {
524            Self::AllTrue(_) => AllOr::All,
525            Self::AllFalse(_) => AllOr::None,
526            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
527        }
528    }
529
530    /// Return [`MaskValues`] if the mask is not all true or all false.
531    #[inline]
532    pub fn values(&self) -> Option<&MaskValues> {
533        if let Self::Values(values) = self {
534            Some(values)
535        } else {
536            None
537        }
538    }
539
540    /// Given monotonically increasing `indices` in [0, n_rows], returns the
541    /// count of valid elements up to each index.
542    ///
543    /// This is O(n_rows).
544    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
545        match self {
546            Self::AllTrue(_) => indices.to_vec(),
547            Self::AllFalse(_) => vec![0; indices.len()],
548            Self::Values(values) => {
549                let mut bool_iter = values.bit_buffer().iter();
550                let mut valid_counts = Vec::with_capacity(indices.len());
551                let mut valid_count = 0;
552                let mut idx = 0;
553                for &next_idx in indices {
554                    while idx < next_idx {
555                        idx += 1;
556                        valid_count += bool_iter
557                            .next()
558                            .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
559                            as usize;
560                    }
561                    valid_counts.push(valid_count);
562                }
563
564                valid_counts
565            }
566        }
567    }
568
569    /// Limit the mask to the first `limit` true values
570    pub fn limit(self, limit: usize) -> Self {
571        // Early return optimization: if we're asking for more true values than the total
572        // length of the mask, then even if all values were true, we couldn't exceed the
573        // limit, so return the original mask unchanged.
574        if self.len() <= limit {
575            return self;
576        }
577
578        match self {
579            Mask::AllTrue(len) => {
580                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
581            }
582            Mask::AllFalse(_) => self,
583            Mask::Values(ref mask_values) => {
584                if limit >= mask_values.true_count() {
585                    return self;
586                }
587
588                let existing_buffer = mask_values.bit_buffer();
589
590                let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
591                debug_assert!(limit < mask_values.len());
592
593                let ptr = new_buffer_builder.as_mut_ptr();
594                for index in existing_buffer.set_indices().take(limit) {
595                    // SAFETY: We checked that `limit` was less than the mask values length,
596                    // therefore `index` must be within the bounds of the bit buffer.
597                    unsafe { set_bit_unchecked(ptr, index) }
598                }
599
600                Self::from(new_buffer_builder.freeze())
601            }
602        }
603    }
604
605    /// Concatenate multiple masks together into a single mask.
606    pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
607        let masks: Vec<_> = masks.collect();
608        let len = masks.iter().map(|t| t.len()).sum();
609
610        if masks.iter().all(|t| t.all_true()) {
611            return Ok(Mask::AllTrue(len));
612        }
613
614        if masks.iter().all(|t| t.all_false()) {
615            return Ok(Mask::AllFalse(len));
616        }
617
618        let mut builder = BitBufferMut::with_capacity(len);
619
620        for mask in masks {
621            match mask {
622                Mask::AllTrue(n) => builder.append_n(true, *n),
623                Mask::AllFalse(n) => builder.append_n(false, *n),
624                Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
625            }
626        }
627
628        Ok(Mask::from_buffer(builder.freeze()))
629    }
630}
631
632impl MaskValues {
633    /// Returns the length of the mask.
634    #[inline]
635    pub fn len(&self) -> usize {
636        self.buffer.len()
637    }
638
639    /// Returns true if the mask is empty i.e., it's length is 0.
640    #[inline]
641    pub fn is_empty(&self) -> bool {
642        self.buffer.is_empty()
643    }
644
645    /// Returns the density of the mask.
646    #[inline]
647    pub fn density(&self) -> f64 {
648        self.density
649    }
650
651    /// Returns the true count of the mask.
652    #[inline]
653    pub fn true_count(&self) -> usize {
654        self.true_count
655    }
656
657    /// Returns the boolean buffer representation of the mask.
658    #[inline]
659    pub fn bit_buffer(&self) -> &BitBuffer {
660        &self.buffer
661    }
662
663    /// Returns the boolean buffer representation of the mask.
664    #[inline]
665    pub fn into_bit_buffer(self) -> BitBuffer {
666        self.buffer
667    }
668
669    /// Returns the boolean value at a given index.
670    #[inline]
671    pub fn value(&self, index: usize) -> bool {
672        self.buffer.value(index)
673    }
674
675    /// Constructs an indices vector from one of the other representations.
676    pub fn indices(&self) -> &[usize] {
677        self.indices.get_or_init(|| {
678            if self.true_count == 0 {
679                return vec![];
680            }
681
682            if self.true_count == self.len() {
683                return (0..self.len()).collect();
684            }
685
686            if let Some(slices) = self.slices.get() {
687                let mut indices = Vec::with_capacity(self.true_count);
688                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
689                debug_assert!(indices.is_sorted());
690                assert_eq!(indices.len(), self.true_count);
691                return indices;
692            }
693
694            let mut indices = Vec::with_capacity(self.true_count);
695            indices.extend(self.buffer.set_indices());
696            debug_assert!(indices.is_sorted());
697            assert_eq!(indices.len(), self.true_count);
698            indices
699        })
700    }
701
702    /// Constructs a slices vector from one of the other representations.
703    #[inline]
704    pub fn slices(&self) -> &[(usize, usize)] {
705        self.slices.get_or_init(|| {
706            if self.true_count == self.len() {
707                return vec![(0, self.len())];
708            }
709
710            self.buffer.set_slices().collect()
711        })
712    }
713
714    /// Return an iterator over either indices or slices of the mask based on a density threshold.
715    #[inline]
716    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
717        if self.density >= threshold {
718            MaskIter::Slices(self.slices())
719        } else {
720            MaskIter::Indices(self.indices())
721        }
722    }
723
724    /// Extracts the internal [`BitBuffer`].
725    pub(crate) fn into_buffer(self) -> BitBuffer {
726        self.buffer
727    }
728}
729
/// A borrowed view of a mask's true values, either as individual indices or as runs.
pub enum MaskIter<'a> {
    /// The cached positions of the true values.
    Indices(&'a [usize]),
    /// The cached `(start, end)` runs of true values.
    Slices(&'a [(usize, usize)]),
}
737
738impl From<BitBuffer> for Mask {
739    fn from(value: BitBuffer) -> Self {
740        Self::from_buffer(value)
741    }
742}
743
744impl FromIterator<bool> for Mask {
745    #[inline]
746    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
747        Self::from_buffer(BitBuffer::from_iter(iter))
748    }
749}
750
751impl FromIterator<Mask> for Mask {
752    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
753        let masks = iter
754            .into_iter()
755            .filter(|m| !m.is_empty())
756            .collect::<Vec<_>>();
757        let total_length = masks.iter().map(|v| v.len()).sum();
758
759        // If they're all valid, then return a single validity.
760        if masks.iter().all(|v| v.all_true()) {
761            return Self::AllTrue(total_length);
762        }
763        // If they're all invalid, then return a single invalidity.
764        if masks.iter().all(|v| v.all_false()) {
765            return Self::AllFalse(total_length);
766        }
767
768        // Else, construct the boolean buffer
769        let mut buffer = BitBufferMut::with_capacity(total_length);
770        for mask in masks {
771            match mask {
772                Mask::AllTrue(count) => buffer.append_n(true, count),
773                Mask::AllFalse(count) => buffer.append_n(false, count),
774                Mask::Values(values) => {
775                    buffer.append_buffer(values.bit_buffer());
776                }
777            };
778        }
779        Self::from_buffer(buffer.freeze())
780    }
781}