vortex_mask/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! A mask is a set of sorted unique non-negative integers.
5#![deny(missing_docs)]
6mod bitops;
7mod eq;
8mod intersect_by_rank;
9mod iter_bools;
10
11use std::cmp::Ordering;
12use std::fmt::{Debug, Formatter};
13use std::sync::{Arc, OnceLock};
14
15use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
16use itertools::Itertools;
17use vortex_error::{VortexResult, vortex_err};
18
/// A three-way selection: everything is included, nothing is included, or an explicit
/// payload of type `T` describes which values are included.
pub enum AllOr<T> {
    /// All values are included.
    All,
    /// No values are included.
    None,
    /// Some values are included.
    Some(T),
}

impl<T> AllOr<T> {
    /// Unwraps the payload of [`AllOr::Some`].
    ///
    /// For the other two variants a substitute is produced lazily: `all_true` is invoked for
    /// [`AllOr::All`] and `all_false` is invoked for [`AllOr::None`].
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            AllOr::All => all_true(),
            AllOr::None => all_false(),
            AllOr::Some(inner) => inner,
        }
    }
}

impl<T> AllOr<&T> {
    /// Turns an `AllOr<&T>` into an owned `AllOr<T>` by cloning the referenced payload.
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            AllOr::Some(inner) => AllOr::Some(T::clone(inner)),
            AllOr::All => AllOr::All,
            AllOr::None => AllOr::None,
        }
    }
}

impl<T: Debug> Debug for AllOr<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            AllOr::Some(inner) => f.debug_tuple("Some").field(inner).finish(),
            AllOr::All => f.write_str("All"),
            AllOr::None => f.write_str("None"),
        }
    }
}

impl<T: PartialEq> PartialEq for AllOr<T> {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            // Payloads are compared; the two unit variants only equal themselves.
            (AllOr::Some(a), AllOr::Some(b)) => a == b,
            (AllOr::All, AllOr::All) | (AllOr::None, AllOr::None) => true,
            _ => false,
        }
    }
}

impl<T: Eq> Eq for AllOr<T> {}
86
/// Represents a set of sorted unique non-negative integers (the positions that are "true")
/// drawn from a fixed-length range.
///
/// A [`Mask`] can be constructed from various representations (boolean buffer, indices,
/// slices), and converted to various others. For the mixed case the index and slice
/// representations are cached inside [`MaskValues`].
#[derive(Clone, Debug)]
pub enum Mask {
    /// All values are included. The payload is the length of the mask.
    AllTrue(usize),
    /// No values are included. The payload is the length of the mask.
    AllFalse(usize),
    /// Some values are included, represented as a [`BooleanBuffer`] held by a shared
    /// [`MaskValues`].
    Values(Arc<MaskValues>),
}
100
101/// Represents the values of a [`Mask`] that contains some true and some false elements.
102#[derive(Debug)]
103pub struct MaskValues {
104    buffer: BooleanBuffer,
105
106    // We cached the indices and slices representations, since it can be faster than iterating
107    // the bit-mask over and over again.
108    indices: OnceLock<Vec<usize>>,
109    slices: OnceLock<Vec<(usize, usize)>>,
110
111    // Pre-computed values.
112    true_count: usize,
113    // i.e., the fraction of values that are true
114    density: f64,
115}
116
117impl MaskValues {
118    /// Returns the length of the mask.
119    #[inline]
120    pub fn len(&self) -> usize {
121        self.buffer.len()
122    }
123
124    /// Returns true if the mask is empty i.e., it's length is 0.
125    #[inline]
126    pub fn is_empty(&self) -> bool {
127        self.buffer.is_empty()
128    }
129
130    /// Returns the true count of the mask.
131    pub fn true_count(&self) -> usize {
132        self.true_count
133    }
134
135    /// Returns the boolean buffer representation of the mask.
136    pub fn boolean_buffer(&self) -> &BooleanBuffer {
137        &self.buffer
138    }
139
140    /// Returns the boolean value at a given index.
141    pub fn value(&self, index: usize) -> bool {
142        self.buffer.value(index)
143    }
144
145    /// Constructs an indices vector from one of the other representations.
146    pub fn indices(&self) -> &[usize] {
147        self.indices.get_or_init(|| {
148            if self.true_count == 0 {
149                return vec![];
150            }
151
152            if self.true_count == self.len() {
153                return (0..self.len()).collect();
154            }
155
156            if let Some(slices) = self.slices.get() {
157                let mut indices = Vec::with_capacity(self.true_count);
158                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
159                debug_assert!(indices.is_sorted());
160                assert_eq!(indices.len(), self.true_count);
161                return indices;
162            }
163
164            let mut indices = Vec::with_capacity(self.true_count);
165            indices.extend(self.buffer.set_indices());
166            debug_assert!(indices.is_sorted());
167            assert_eq!(indices.len(), self.true_count);
168            indices
169        })
170    }
171
172    /// Constructs a slices vector from one of the other representations.
173    #[allow(clippy::cast_possible_truncation)]
174    pub fn slices(&self) -> &[(usize, usize)] {
175        self.slices.get_or_init(|| {
176            if self.true_count == self.len() {
177                return vec![(0, self.len())];
178            }
179
180            self.buffer.set_slices().collect()
181        })
182    }
183
184    /// Return an iterator over either indices or slices of the mask based on a density threshold.
185    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
186        if self.density >= threshold {
187            MaskIter::Slices(self.slices())
188        } else {
189            MaskIter::Indices(self.indices())
190        }
191    }
192}
193
194impl Mask {
195    /// Create a new Mask where all values are set.
196    pub fn new_true(length: usize) -> Self {
197        Self::AllTrue(length)
198    }
199
200    /// Create a new Mask where no values are set.
201    pub fn new_false(length: usize) -> Self {
202        Self::AllFalse(length)
203    }
204
205    /// Create a new [`Mask`] from a [`BooleanBuffer`].
206    pub fn from_buffer(buffer: BooleanBuffer) -> Self {
207        let len = buffer.len();
208        let true_count = buffer.count_set_bits();
209
210        if true_count == 0 {
211            return Self::AllFalse(len);
212        }
213        if true_count == len {
214            return Self::AllTrue(len);
215        }
216
217        Self::Values(Arc::new(MaskValues {
218            buffer,
219            indices: Default::default(),
220            slices: Default::default(),
221            true_count,
222            density: true_count as f64 / len as f64,
223        }))
224    }
225
226    /// Create a new [`Mask`] from a [`Vec<usize>`].
227    // TODO(ngates): this should take an IntoIterator<usize>.
228    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
229        let true_count = indices.len();
230        assert!(indices.is_sorted(), "Mask indices must be sorted");
231        assert!(
232            indices.last().is_none_or(|&idx| idx < len),
233            "Mask indices must be in bounds (len={len})"
234        );
235
236        if true_count == 0 {
237            return Self::AllFalse(len);
238        }
239        if true_count == len {
240            return Self::AllTrue(len);
241        }
242
243        let mut buf = BooleanBufferBuilder::new(len);
244        // TODO(ngates): for dense indices, we can do better by collecting into u64s.
245        buf.append_n(len, false);
246        indices.iter().for_each(|idx| buf.set_bit(*idx, true));
247        debug_assert_eq!(buf.len(), len);
248
249        Self::Values(Arc::new(MaskValues {
250            buffer: buf.finish(),
251            indices: OnceLock::from(indices),
252            slices: Default::default(),
253            true_count,
254            density: true_count as f64 / len as f64,
255        }))
256    }
257
258    /// Create a new [`Mask`] from an [`IntoIterator<Item = usize>`] of indices to be excluded.
259    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
260        let mut buf = BooleanBufferBuilder::new(len);
261        buf.append_n(len, true);
262
263        let mut false_count: usize = 0;
264        indices.into_iter().for_each(|idx| {
265            buf.set_bit(idx, false);
266            false_count += 1;
267        });
268        debug_assert_eq!(buf.len(), len);
269        let true_count = len - false_count;
270
271        Self::Values(Arc::new(MaskValues {
272            buffer: buf.finish(),
273            indices: Default::default(),
274            slices: Default::default(),
275            true_count,
276            density: true_count as f64 / len as f64,
277        }))
278    }
279
280    /// Create a new [`Mask`] from a [`Vec<(usize, usize)>`] where each range
281    /// represents a contiguous range of true values.
282    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
283        Self::check_slices(len, &vec);
284        Self::from_slices_unchecked(len, vec)
285    }
286
287    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
288        #[cfg(debug_assertions)]
289        Self::check_slices(len, &slices);
290
291        let true_count = slices.iter().map(|(b, e)| e - b).sum();
292        if true_count == 0 {
293            return Self::AllFalse(len);
294        }
295        if true_count == len {
296            return Self::AllTrue(len);
297        }
298
299        let mut buf = BooleanBufferBuilder::new(len);
300        for (start, end) in slices.iter().copied() {
301            buf.append_n(start - buf.len(), false);
302            buf.append_n(end - start, true);
303        }
304        if let Some((_, end)) = slices.last() {
305            buf.append_n(len - end, false);
306        }
307        debug_assert_eq!(buf.len(), len);
308
309        Self::Values(Arc::new(MaskValues {
310            buffer: buf.finish(),
311            indices: Default::default(),
312            slices: OnceLock::from(slices),
313            true_count,
314            density: true_count as f64 / len as f64,
315        }))
316    }
317
    // Validates a list of half-open `(begin, end)` slices against a mask of length `len`.
    //
    // Panics unless every slice is non-empty (`b < e`) and ends within the mask
    // (`e <= len`), and each slice starts after the previous one starts and no earlier
    // than the previous one ends.
    #[inline(always)]
    fn check_slices(len: usize, vec: &[(usize, usize)]) {
        // Per-slice check: non-empty and within bounds.
        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
        // Pairwise check over adjacent slices: ordering and overlap.
        for (first, second) in vec.iter().tuple_windows() {
            assert!(
                first.0 < second.0,
                "Slices must be sorted, got {first:?} and {second:?}"
            );
            assert!(
                first.1 <= second.0,
                "Slices must be non-overlapping, got {first:?} and {second:?}"
            );
        }
    }
332
333    /// Create a new [`Mask`] from the intersection of two indices slices.
334    pub fn from_intersection_indices(
335        len: usize,
336        lhs: impl Iterator<Item = usize>,
337        rhs: impl Iterator<Item = usize>,
338    ) -> Self {
339        let mut intersection = Vec::with_capacity(len);
340        let mut lhs = lhs.peekable();
341        let mut rhs = rhs.peekable();
342        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
343            match l.cmp(&r) {
344                Ordering::Less => {
345                    lhs.next();
346                }
347                Ordering::Greater => {
348                    rhs.next();
349                }
350                Ordering::Equal => {
351                    intersection.push(l);
352                    lhs.next();
353                    rhs.next();
354                }
355            }
356        }
357        Self::from_indices(len, intersection)
358    }
359
360    /// Returns the length of the mask (not the number of true values).
361    #[inline]
362    pub fn len(&self) -> usize {
363        match self {
364            Self::AllTrue(len) => *len,
365            Self::AllFalse(len) => *len,
366            Self::Values(values) => values.len(),
367        }
368    }
369
370    /// Returns true if the mask is empty i.e., it's length is 0.
371    #[inline]
372    pub fn is_empty(&self) -> bool {
373        match self {
374            Self::AllTrue(len) => *len == 0,
375            Self::AllFalse(len) => *len == 0,
376            Self::Values(values) => values.is_empty(),
377        }
378    }
379
380    /// Get the true count of the mask.
381    #[inline]
382    pub fn true_count(&self) -> usize {
383        match &self {
384            Self::AllTrue(len) => *len,
385            Self::AllFalse(_) => 0,
386            Self::Values(values) => values.true_count,
387        }
388    }
389
390    /// Get the false count of the mask.
391    #[inline]
392    pub fn false_count(&self) -> usize {
393        match &self {
394            Self::AllTrue(_) => 0,
395            Self::AllFalse(len) => *len,
396            Self::Values(values) => values.buffer.len() - values.true_count,
397        }
398    }
399
400    /// Returns true if all values in the mask are true.
401    #[inline]
402    pub fn all_true(&self) -> bool {
403        match &self {
404            Self::AllTrue(_) => true,
405            Self::AllFalse(0) => true,
406            Self::AllFalse(_) => false,
407            Self::Values(values) => values.buffer.len() == values.true_count,
408        }
409    }
410
411    /// Returns true if all values in the mask are false.
412    #[inline]
413    pub fn all_false(&self) -> bool {
414        self.true_count() == 0
415    }
416
417    /// Return the density of the full mask.
418    #[inline]
419    pub fn density(&self) -> f64 {
420        match &self {
421            Self::AllTrue(_) => 1.0,
422            Self::AllFalse(_) => 0.0,
423            Self::Values(values) => values.density,
424        }
425    }
426
427    /// Returns the boolean value at a given index.
428    ///
429    /// ## Panics
430    ///
431    /// Panics if the index is out of bounds.
432    pub fn value(&self, idx: usize) -> bool {
433        match self {
434            Mask::AllTrue(_) => true,
435            Mask::AllFalse(_) => false,
436            Mask::Values(values) => values.buffer.value(idx),
437        }
438    }
439
440    /// Returns the first true index in the mask.
441    pub fn first(&self) -> Option<usize> {
442        match &self {
443            Self::AllTrue(len) => (*len > 0).then_some(0),
444            Self::AllFalse(_) => None,
445            Self::Values(values) => {
446                if let Some(indices) = values.indices.get() {
447                    return indices.first().copied();
448                }
449                if let Some(slices) = values.slices.get() {
450                    return slices.first().map(|(start, _)| *start);
451                }
452                values.buffer.set_indices().next()
453            }
454        }
455    }
456
457    /// Slice the mask.
458    pub fn slice(&self, offset: usize, length: usize) -> Self {
459        assert!(offset + length <= self.len());
460        match &self {
461            Self::AllTrue(_) => Self::new_true(length),
462            Self::AllFalse(_) => Self::new_false(length),
463            Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)),
464        }
465    }
466
467    /// Return the boolean buffer representation of the mask.
468    pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
469        match &self {
470            Self::AllTrue(_) => AllOr::All,
471            Self::AllFalse(_) => AllOr::None,
472            Self::Values(values) => AllOr::Some(&values.buffer),
473        }
474    }
475
476    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
477    /// and all-false variants.
478    pub fn to_boolean_buffer(&self) -> BooleanBuffer {
479        match self {
480            Self::AllTrue(l) => BooleanBuffer::new_set(*l),
481            Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
482            Self::Values(values) => values.boolean_buffer().clone(),
483        }
484    }
485
486    /// Returns an Arrow null buffer representation of the mask.
487    pub fn to_null_buffer(&self) -> Option<NullBuffer> {
488        match self {
489            Mask::AllTrue(_) => None,
490            Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
491            Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
492        }
493    }
494
495    /// Return the indices representation of the mask.
496    pub fn indices(&self) -> AllOr<&[usize]> {
497        match &self {
498            Self::AllTrue(_) => AllOr::All,
499            Self::AllFalse(_) => AllOr::None,
500            Self::Values(values) => AllOr::Some(values.indices()),
501        }
502    }
503
504    /// Return the slices representation of the mask.
505    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
506        match &self {
507            Self::AllTrue(_) => AllOr::All,
508            Self::AllFalse(_) => AllOr::None,
509            Self::Values(values) => AllOr::Some(values.slices()),
510        }
511    }
512
513    /// Return an iterator over either indices or slices of the mask based on a density threshold.
514    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
515        match &self {
516            Self::AllTrue(_) => AllOr::All,
517            Self::AllFalse(_) => AllOr::None,
518            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
519        }
520    }
521
522    /// Return [`MaskValues`] if the mask is not all true or all false.
523    pub fn values(&self) -> Option<&MaskValues> {
524        match self {
525            Self::Values(values) => Some(values),
526            _ => None,
527        }
528    }
529
530    /// Given monotonically increasing `indices` in [0, n_rows], returns the
531    /// count of valid elements up to each index.
532    ///
533    /// This is O(n_rows).
534    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> VortexResult<Vec<usize>> {
535        Ok(match self {
536            Self::AllTrue(_) => indices.to_vec(),
537            Self::AllFalse(_) => vec![0; indices.len()],
538            Self::Values(values) => {
539                let mut bool_iter = values.boolean_buffer().iter();
540                let mut valid_counts = Vec::with_capacity(indices.len());
541                let mut valid_count = 0;
542                let mut idx = 0;
543                for &next_idx in indices {
544                    while idx < next_idx {
545                        idx += 1;
546                        valid_count += bool_iter
547                            .next()
548                            .ok_or_else(|| vortex_err!("Row indices exceed array length"))?
549                            as usize;
550                    }
551                    valid_counts.push(valid_count);
552                }
553
554                valid_counts
555            }
556        })
557    }
558
559    /// Limit the mask to the first `limit` true values
560    pub fn limit(self, limit: usize) -> Self {
561        if self.len() <= limit {
562            return self;
563        }
564
565        match self {
566            Mask::AllTrue(len) => {
567                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
568            }
569            Mask::AllFalse(_) => self,
570            Mask::Values(ref mask_values) => {
571                if limit >= mask_values.true_count() {
572                    return self;
573                }
574
575                let existing_buffer = mask_values.boolean_buffer();
576
577                let mut new_buffer_builder = BooleanBufferBuilder::new(mask_values.len());
578                new_buffer_builder.append_n(mask_values.len(), false);
579
580                for index in existing_buffer.set_indices().take(limit) {
581                    new_buffer_builder.set_bit(index, true);
582                }
583
584                Self::from(new_buffer_builder.finish())
585            }
586        }
587    }
588}
589
/// A borrowed view over a mask's true values, either as individual indices or as slices.
///
/// Produced by `threshold_iter`, which selects the variant by comparing the mask's density
/// against a threshold.
pub enum MaskIter<'a> {
    /// Slice of pre-cached indices of a mask.
    Indices(&'a [usize]),
    /// Slice of pre-cached `(start, end)` slices of a mask.
    Slices(&'a [(usize, usize)]),
}
597
598impl From<BooleanBuffer> for Mask {
599    fn from(value: BooleanBuffer) -> Self {
600        Self::from_buffer(value)
601    }
602}
603
604impl FromIterator<bool> for Mask {
605    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
606        Self::from_buffer(BooleanBuffer::from_iter(iter))
607    }
608}
609
610impl FromIterator<Mask> for Mask {
611    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
612        let masks = iter
613            .into_iter()
614            .filter(|m| !m.is_empty())
615            .collect::<Vec<_>>();
616        let total_length = masks.iter().map(|v| v.len()).sum();
617
618        // If they're all valid, then return a single validity.
619        if masks.iter().all(|v| v.all_true()) {
620            return Self::AllTrue(total_length);
621        }
622        // If they're all invalid, then return a single invalidity.
623        if masks.iter().all(|v| v.all_false()) {
624            return Self::AllFalse(total_length);
625        }
626
627        // Else, construct the boolean buffer
628        let mut buffer = BooleanBufferBuilder::new(total_length);
629        for mask in masks {
630            match mask {
631                Mask::AllTrue(count) => buffer.append_n(count, true),
632                Mask::AllFalse(count) => buffer.append_n(count, false),
633                Mask::Values(values) => {
634                    buffer.append_buffer(values.boolean_buffer());
635                }
636            };
637        }
638        Self::from_buffer(buffer.finish())
639    }
640}
641
// Unit tests for the constructors, accessors and `limit` of `Mask`.
#[cfg(test)]
mod test {
    use super::*;

    // An all-true mask reports `AllOr::All` for every representation.
    #[test]
    fn mask_all_true() {
        let mask = Mask::new_true(5);
        assert_eq!(mask.len(), 5);
        assert_eq!(mask.true_count(), 5);
        assert_eq!(mask.density(), 1.0);
        assert_eq!(mask.indices(), AllOr::All);
        assert_eq!(mask.slices(), AllOr::All);
        assert_eq!(mask.boolean_buffer(), AllOr::All,);
    }

    // An all-false mask reports `AllOr::None` for every representation.
    #[test]
    fn mask_all_false() {
        let mask = Mask::new_false(5);
        assert_eq!(mask.len(), 5);
        assert_eq!(mask.true_count(), 0);
        assert_eq!(mask.density(), 0.0);
        assert_eq!(mask.indices(), AllOr::None);
        assert_eq!(mask.slices(), AllOr::None);
        assert_eq!(mask.boolean_buffer(), AllOr::None,);
    }

    // The same logical mask built from indices, slices and a buffer must expose identical
    // indices, slices and boolean-buffer views.
    #[test]
    fn mask_from() {
        let masks = [
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_slices(5, vec![(0, 1), (2, 4)]),
            Mask::from_buffer(BooleanBuffer::from_iter([true, false, true, true, false])),
        ];

        for mask in &masks {
            assert_eq!(mask.len(), 5);
            assert_eq!(mask.true_count(), 3);
            assert_eq!(mask.density(), 0.6);
            assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..]));
            assert_eq!(mask.slices(), AllOr::Some(&[(0, 1), (2, 4)][..]));
            assert_eq!(
                mask.boolean_buffer(),
                AllOr::Some(&BooleanBuffer::from_iter([true, false, true, true, false]))
            );
        }
    }

    // Limiting an all-true mask keeps the length and only the leading `limit` true values;
    // a limit above the length leaves the mask unchanged.
    #[test]
    fn limit_all_true_mask() {
        let all_true = Mask::new_true(4);
        let limited_mask = all_true.clone().limit(2);
        assert_eq!(all_true.len(), limited_mask.len());
        assert_eq!(limited_mask.true_count(), 2);
        assert_eq!(
            limited_mask.boolean_buffer(),
            AllOr::Some(&BooleanBuffer::from_iter([true, true, false, false]))
        );

        let limited_mask = all_true.clone().limit(5);
        assert_eq!(limited_mask, all_true);
    }

    // Limiting a mixed mask keeps the first `limit` true values in place and clears the
    // later ones; a limit above the true count leaves the mask unchanged.
    #[test]
    fn limit_mask_values() {
        let original_mask = Mask::from_iter([true, true, false, true, false, true]);
        let limited_mask = original_mask.clone().limit(2);

        assert_eq!(
            limited_mask.boolean_buffer(),
            AllOr::Some(&BooleanBuffer::from_iter([
                true, true, false, false, false, false
            ]))
        );
        assert_eq!(limited_mask.true_count(), 2);

        let limited_mask = original_mask.limit(3);

        assert_eq!(
            limited_mask.boolean_buffer(),
            AllOr::Some(&BooleanBuffer::from_iter([
                true, true, false, true, false, false
            ]))
        );
        assert_eq!(limited_mask.true_count(), 3);

        let original_mask = Mask::from_iter([true, true, false, true, false, true]);
        let limited_mask = original_mask.clone().limit(100);

        assert_eq!(original_mask, limited_mask);
    }

    // A zero-length mask is both all-true and all-false, however it was constructed.
    #[test]
    fn length_zero_masks() {
        let all_false = Mask::new_false(0);
        let all_true = Mask::new_true(0);
        let buffer_set = Mask::from_buffer(BooleanBuffer::new_set(0));
        let buffer_unset = Mask::from_buffer(BooleanBuffer::new_unset(0));

        assert!(all_false.all_false());
        assert!(all_false.all_true());
        assert!(all_true.all_false());
        assert!(all_true.all_true());
        assert!(buffer_set.all_false());
        assert!(buffer_set.all_true());
        assert!(buffer_unset.all_false());
        assert!(buffer_unset.all_true());
    }
}
749}