vortex_mask/
lib.rs

//! A mask is a set of sorted, unique, non-negative integers.
2#![deny(missing_docs)]
3mod bitops;
4mod eq;
5mod intersect_by_rank;
6mod iter_bools;
7
8use std::cmp::Ordering;
9use std::fmt::{Debug, Formatter};
10use std::sync::{Arc, OnceLock};
11
12use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
13use itertools::Itertools;
14use vortex_error::{VortexResult, vortex_err};
15
/// Represents a set of values that are all included, all excluded, or some mixture of both.
///
/// Only the mixed case carries a payload of type `T`; the two uniform cases (`All`, `None`)
/// carry nothing, so callers can handle them without materializing a representation.
pub enum AllOr<T> {
    /// All values are included.
    All,
    /// No values are included.
    None,
    /// Some values are included, described by the wrapped `T`.
    Some(T),
}
25
26impl<T> AllOr<T> {
27    /// Returns the `Some` variant of the enum, or a default value.
28    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
29    where
30        F: FnOnce() -> T,
31        G: FnOnce() -> T,
32    {
33        match self {
34            Self::Some(v) => v,
35            AllOr::All => all_true(),
36            AllOr::None => all_false(),
37        }
38    }
39}
40
41impl<T> AllOr<&T> {
42    /// Clone the inner value.
43    pub fn cloned(self) -> AllOr<T>
44    where
45        T: Clone,
46    {
47        match self {
48            Self::All => AllOr::All,
49            Self::None => AllOr::None,
50            Self::Some(v) => AllOr::Some(v.clone()),
51        }
52    }
53}
54
55impl<T> Debug for AllOr<T>
56where
57    T: Debug,
58{
59    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60        match self {
61            Self::All => f.write_str("All"),
62            Self::None => f.write_str("None"),
63            Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
64        }
65    }
66}
67
68impl<T> PartialEq for AllOr<T>
69where
70    T: PartialEq,
71{
72    fn eq(&self, other: &Self) -> bool {
73        match (self, other) {
74            (Self::All, Self::All) => true,
75            (Self::None, Self::None) => true,
76            (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
77            _ => false,
78        }
79    }
80}
81
82impl<T> Eq for AllOr<T> where T: Eq {}
83
/// Represents a set of sorted, unique, non-negative integers within `0..len`.
///
/// A [`Mask`] can be constructed from various representations, and converted to various
/// others. Internally, these are cached.
#[derive(Clone, Debug)]
pub enum Mask {
    /// All values are included; the payload is the mask length.
    AllTrue(usize),
    /// No values are included; the payload is the mask length.
    AllFalse(usize),
    /// Some values are included, stored as shared [`MaskValues`] backed by a [`BooleanBuffer`].
    Values(Arc<MaskValues>),
}
97
/// Represents the values of a [`Mask`] that contains some true and some false elements.
///
/// The bit-buffer is always present. The indices and slices views live in `OnceLock`s: a
/// constructor that already has one of those representations stores it directly, otherwise it
/// is derived from the buffer on first access.
#[derive(Debug)]
pub struct MaskValues {
    buffer: BooleanBuffer,

    // We cache the indices and slices representations, since reusing them can be faster than
    // iterating the bit-mask over and over again.
    indices: OnceLock<Vec<usize>>,
    slices: OnceLock<Vec<(usize, usize)>>,

    // Pre-computed values.
    // Number of set bits in `buffer`.
    true_count: usize,
    // i.e., the fraction of values that are true (`true_count / len`).
    density: f64,
}
113
114impl MaskValues {
115    /// Returns the length of the mask.
116    #[inline]
117    #[allow(clippy::len_without_is_empty)]
118    pub fn len(&self) -> usize {
119        self.buffer.len()
120    }
121
122    /// Returns the true count of the mask.
123    pub fn true_count(&self) -> usize {
124        self.true_count
125    }
126
127    /// Returns the boolean buffer representation of the mask.
128    pub fn boolean_buffer(&self) -> &BooleanBuffer {
129        &self.buffer
130    }
131
132    /// Returns the boolean value at a given index.
133    pub fn value(&self, index: usize) -> bool {
134        self.buffer.value(index)
135    }
136
137    /// Constructs an indices vector from one of the other representations.
138    pub fn indices(&self) -> &[usize] {
139        self.indices.get_or_init(|| {
140            if self.true_count == 0 {
141                return vec![];
142            }
143
144            if self.true_count == self.len() {
145                return (0..self.len()).collect();
146            }
147
148            if let Some(slices) = self.slices.get() {
149                let mut indices = Vec::with_capacity(self.true_count);
150                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
151                debug_assert!(indices.is_sorted());
152                assert_eq!(indices.len(), self.true_count);
153                return indices;
154            }
155
156            let mut indices = Vec::with_capacity(self.true_count);
157            indices.extend(self.buffer.set_indices());
158            debug_assert!(indices.is_sorted());
159            assert_eq!(indices.len(), self.true_count);
160            indices
161        })
162    }
163
164    /// Constructs a slices vector from one of the other representations.
165    #[allow(clippy::cast_possible_truncation)]
166    pub fn slices(&self) -> &[(usize, usize)] {
167        self.slices.get_or_init(|| {
168            if self.true_count == self.len() {
169                return vec![(0, self.len())];
170            }
171
172            self.buffer.set_slices().collect()
173        })
174    }
175
176    /// Return an iterator over either indices or slices of the mask based on a density threshold.
177    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
178        if self.density >= threshold {
179            MaskIter::Slices(self.slices())
180        } else {
181            MaskIter::Indices(self.indices())
182        }
183    }
184}
185
186impl Mask {
187    /// Create a new Mask where all values are set.
188    pub fn new_true(length: usize) -> Self {
189        Self::AllTrue(length)
190    }
191
192    /// Create a new Mask where no values are set.
193    pub fn new_false(length: usize) -> Self {
194        Self::AllFalse(length)
195    }
196
197    /// Create a new [`Mask`] from a [`BooleanBuffer`].
198    pub fn from_buffer(buffer: BooleanBuffer) -> Self {
199        let len = buffer.len();
200        let true_count = buffer.count_set_bits();
201
202        if true_count == 0 {
203            return Self::AllFalse(len);
204        }
205        if true_count == len {
206            return Self::AllTrue(len);
207        }
208
209        Self::Values(Arc::new(MaskValues {
210            buffer,
211            indices: Default::default(),
212            slices: Default::default(),
213            true_count,
214            density: true_count as f64 / len as f64,
215        }))
216    }
217
218    /// Create a new [`Mask`] from a [`Vec<usize>`].
219    // TODO(ngates): this should take an IntoIterator<usize>.
220    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
221        let true_count = indices.len();
222        assert!(indices.is_sorted(), "Mask indices must be sorted");
223        assert!(
224            indices.last().is_none_or(|&idx| idx < len),
225            "Mask indices must be in bounds (len={len})"
226        );
227
228        if true_count == 0 {
229            return Self::AllFalse(len);
230        }
231        if true_count == len {
232            return Self::AllTrue(len);
233        }
234
235        let mut buf = BooleanBufferBuilder::new(len);
236        // TODO(ngates): for dense indices, we can do better by collecting into u64s.
237        buf.append_n(len, false);
238        indices.iter().for_each(|idx| buf.set_bit(*idx, true));
239        debug_assert_eq!(buf.len(), len);
240
241        Self::Values(Arc::new(MaskValues {
242            buffer: buf.finish(),
243            indices: OnceLock::from(indices),
244            slices: Default::default(),
245            true_count,
246            density: true_count as f64 / len as f64,
247        }))
248    }
249
250    /// Create a new [`Mask`] from an [`IntoIterator<Item = usize>`] of indices to be excluded.
251    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
252        let mut buf = BooleanBufferBuilder::new(len);
253        buf.append_n(len, true);
254
255        let mut false_count: usize = 0;
256        indices.into_iter().for_each(|idx| {
257            buf.set_bit(idx, false);
258            false_count += 1;
259        });
260        debug_assert_eq!(buf.len(), len);
261        let true_count = len - false_count;
262
263        Self::Values(Arc::new(MaskValues {
264            buffer: buf.finish(),
265            indices: Default::default(),
266            slices: Default::default(),
267            true_count,
268            density: true_count as f64 / len as f64,
269        }))
270    }
271
272    /// Create a new [`Mask`] from a [`Vec<(usize, usize)>`] where each range
273    /// represents a contiguous range of true values.
274    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
275        Self::check_slices(len, &vec);
276        Self::from_slices_unchecked(len, vec)
277    }
278
279    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
280        #[cfg(debug_assertions)]
281        Self::check_slices(len, &slices);
282
283        let true_count = slices.iter().map(|(b, e)| e - b).sum();
284        if true_count == 0 {
285            return Self::AllFalse(len);
286        }
287        if true_count == len {
288            return Self::AllTrue(len);
289        }
290
291        let mut buf = BooleanBufferBuilder::new(len);
292        for (start, end) in slices.iter().copied() {
293            buf.append_n(start - buf.len(), false);
294            buf.append_n(end - start, true);
295        }
296        if let Some((_, end)) = slices.last() {
297            buf.append_n(len - end, false);
298        }
299        debug_assert_eq!(buf.len(), len);
300
301        Self::Values(Arc::new(MaskValues {
302            buffer: buf.finish(),
303            indices: Default::default(),
304            slices: OnceLock::from(slices),
305            true_count,
306            density: true_count as f64 / len as f64,
307        }))
308    }
309
310    #[inline(always)]
311    fn check_slices(len: usize, vec: &[(usize, usize)]) {
312        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
313        for (first, second) in vec.iter().tuple_windows() {
314            assert!(
315                first.0 < second.0,
316                "Slices must be sorted, got {first:?} and {second:?}"
317            );
318            assert!(
319                first.1 <= second.0,
320                "Slices must be non-overlapping, got {first:?} and {second:?}"
321            );
322        }
323    }
324
325    /// Create a new [`Mask`] from the intersection of two indices slices.
326    pub fn from_intersection_indices(
327        len: usize,
328        lhs: impl Iterator<Item = usize>,
329        rhs: impl Iterator<Item = usize>,
330    ) -> Self {
331        let mut intersection = Vec::with_capacity(len);
332        let mut lhs = lhs.peekable();
333        let mut rhs = rhs.peekable();
334        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
335            match l.cmp(&r) {
336                Ordering::Less => {
337                    lhs.next();
338                }
339                Ordering::Greater => {
340                    rhs.next();
341                }
342                Ordering::Equal => {
343                    intersection.push(l);
344                    lhs.next();
345                    rhs.next();
346                }
347            }
348        }
349        Self::from_indices(len, intersection)
350    }
351
352    /// Returns the length of the mask (not the number of true values).
353    #[inline]
354    // It's confusing to provide is_empty, does it mean len == 0 or true_count == 0?
355    #[allow(clippy::len_without_is_empty)]
356    pub fn len(&self) -> usize {
357        match &self {
358            Self::AllTrue(len) => *len,
359            Self::AllFalse(len) => *len,
360            Self::Values(values) => values.len(),
361        }
362    }
363
364    /// Get the true count of the mask.
365    #[inline]
366    pub fn true_count(&self) -> usize {
367        match &self {
368            Self::AllTrue(len) => *len,
369            Self::AllFalse(_) => 0,
370            Self::Values(values) => values.true_count,
371        }
372    }
373
374    /// Get the false count of the mask.
375    #[inline]
376    pub fn false_count(&self) -> usize {
377        match &self {
378            Self::AllTrue(_) => 0,
379            Self::AllFalse(len) => *len,
380            Self::Values(values) => values.buffer.len() - values.true_count,
381        }
382    }
383
384    /// Returns true if all values in the mask are true.
385    #[inline]
386    pub fn all_true(&self) -> bool {
387        match &self {
388            Self::AllTrue(_) => true,
389            Self::AllFalse(_) => false,
390            Self::Values(values) => values.buffer.len() == values.true_count,
391        }
392    }
393
394    /// Returns true if all values in the mask are false.
395    #[inline]
396    pub fn all_false(&self) -> bool {
397        self.true_count() == 0
398    }
399
400    /// Return the density of the full mask.
401    #[inline]
402    pub fn density(&self) -> f64 {
403        match &self {
404            Self::AllTrue(_) => 1.0,
405            Self::AllFalse(_) => 0.0,
406            Self::Values(values) => values.density,
407        }
408    }
409
410    /// Returns the boolean value at a given index.
411    ///
412    /// ## Panics
413    ///
414    /// Panics if the index is out of bounds.
415    pub fn value(&self, idx: usize) -> bool {
416        match self {
417            Mask::AllTrue(_) => true,
418            Mask::AllFalse(_) => false,
419            Mask::Values(values) => values.buffer.value(idx),
420        }
421    }
422
423    /// Returns the first true index in the mask.
424    pub fn first(&self) -> Option<usize> {
425        match &self {
426            Self::AllTrue(len) => (*len > 0).then_some(0),
427            Self::AllFalse(_) => None,
428            Self::Values(values) => {
429                if let Some(indices) = values.indices.get() {
430                    return indices.first().copied();
431                }
432                if let Some(slices) = values.slices.get() {
433                    return slices.first().map(|(start, _)| *start);
434                }
435                values.buffer.set_indices().next()
436            }
437        }
438    }
439
440    /// Slice the mask.
441    pub fn slice(&self, offset: usize, length: usize) -> Self {
442        assert!(offset + length <= self.len());
443        match &self {
444            Self::AllTrue(_) => Self::new_true(length),
445            Self::AllFalse(_) => Self::new_false(length),
446            Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)),
447        }
448    }
449
450    /// Return the boolean buffer representation of the mask.
451    pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
452        match &self {
453            Self::AllTrue(_) => AllOr::All,
454            Self::AllFalse(_) => AllOr::None,
455            Self::Values(values) => AllOr::Some(&values.buffer),
456        }
457    }
458
459    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
460    /// and all-false variants.
461    pub fn to_boolean_buffer(&self) -> BooleanBuffer {
462        match self {
463            Self::AllTrue(l) => BooleanBuffer::new_set(*l),
464            Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
465            Self::Values(values) => values.boolean_buffer().clone(),
466        }
467    }
468
469    /// Returns an Arrow null buffer representation of the mask.
470    pub fn to_null_buffer(&self) -> Option<NullBuffer> {
471        match self {
472            Mask::AllTrue(_) => None,
473            Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
474            Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
475        }
476    }
477
478    /// Return the indices representation of the mask.
479    pub fn indices(&self) -> AllOr<&[usize]> {
480        match &self {
481            Self::AllTrue(_) => AllOr::All,
482            Self::AllFalse(_) => AllOr::None,
483            Self::Values(values) => AllOr::Some(values.indices()),
484        }
485    }
486
487    /// Return the slices representation of the mask.
488    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
489        match &self {
490            Self::AllTrue(_) => AllOr::All,
491            Self::AllFalse(_) => AllOr::None,
492            Self::Values(values) => AllOr::Some(values.slices()),
493        }
494    }
495
496    /// Return an iterator over either indices or slices of the mask based on a density threshold.
497    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
498        match &self {
499            Self::AllTrue(_) => AllOr::All,
500            Self::AllFalse(_) => AllOr::None,
501            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
502        }
503    }
504
505    /// Return [`MaskValues`] if the mask is not all true or all false.
506    pub fn values(&self) -> Option<&MaskValues> {
507        match self {
508            Self::Values(values) => Some(values),
509            _ => None,
510        }
511    }
512
513    /// Given monotonically increasing `indices` in [0, n_rows], returns the
514    /// count of valid elements up to each index.
515    ///
516    /// This is O(n_rows).
517    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> VortexResult<Vec<usize>> {
518        Ok(match self {
519            Self::AllTrue(_) => indices.to_vec(),
520            Self::AllFalse(_) => vec![0; indices.len()],
521            Self::Values(values) => {
522                let mut bool_iter = values.boolean_buffer().iter();
523                let mut valid_counts = Vec::with_capacity(indices.len());
524                let mut valid_count = 0;
525                let mut idx = 0;
526                for &next_idx in indices {
527                    while idx < next_idx {
528                        idx += 1;
529                        valid_count += bool_iter
530                            .next()
531                            .ok_or_else(|| vortex_err!("Row indices exceed array length"))?
532                            as usize;
533                    }
534                    valid_counts.push(valid_count);
535                }
536
537                valid_counts
538            }
539        })
540    }
541
542    /// Limit the mask to the first `limit` true values
543    pub fn limit(self, limit: usize) -> Self {
544        if self.len() <= limit {
545            return self;
546        }
547
548        match self {
549            Mask::AllTrue(len) => {
550                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
551            }
552            Mask::AllFalse(_) => self,
553            Mask::Values(ref mask_values) => {
554                if limit >= mask_values.true_count() {
555                    return self;
556                }
557
558                let existing_buffer = mask_values.boolean_buffer();
559
560                let mut new_buffer_builder = BooleanBufferBuilder::new(mask_values.len());
561                new_buffer_builder.append_n(mask_values.len(), false);
562
563                for index in existing_buffer.set_indices().take(limit) {
564                    new_buffer_builder.set_bit(index, true);
565                }
566
567                Self::from(new_buffer_builder.finish())
568            }
569        }
570    }
571}
572
/// Iterator over the indices or slices of a mask.
///
/// Returned by [`MaskValues::threshold_iter`], which selects the variant by comparing the mask's
/// density against a caller-supplied threshold.
pub enum MaskIter<'a> {
    /// Slice of pre-cached indices of a mask.
    Indices(&'a [usize]),
    /// Slice of pre-cached half-open `(start, end)` slices of a mask.
    Slices(&'a [(usize, usize)]),
}
580
581impl From<BooleanBuffer> for Mask {
582    fn from(value: BooleanBuffer) -> Self {
583        Self::from_buffer(value)
584    }
585}
586
587impl FromIterator<bool> for Mask {
588    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
589        Self::from_buffer(BooleanBuffer::from_iter(iter))
590    }
591}
592
593impl FromIterator<Mask> for Mask {
594    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
595        let masks = iter.into_iter().collect::<Vec<_>>();
596        let total_length = masks.iter().map(|v| v.len()).sum();
597
598        // If they're all valid, then return a single validity.
599        if masks.iter().all(|v| v.all_true()) {
600            return Self::AllTrue(total_length);
601        }
602        // If they're all invalid, then return a single invalidity.
603        if masks.iter().all(|v| v.all_false()) {
604            return Self::AllFalse(total_length);
605        }
606
607        // Else, construct the boolean buffer
608        let mut buffer = BooleanBufferBuilder::new(total_length);
609        for mask in masks {
610            match mask {
611                Mask::AllTrue(count) => buffer.append_n(count, true),
612                Mask::AllFalse(count) => buffer.append_n(count, false),
613                Mask::Values(values) => {
614                    buffer.append_buffer(values.boolean_buffer());
615                }
616            };
617        }
618        Self::from_buffer(buffer.finish())
619    }
620}
621
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn mask_all_true() {
        let mask = Mask::new_true(5);
        assert_eq!(mask.len(), 5);
        assert_eq!(mask.true_count(), 5);
        assert_eq!(mask.density(), 1.0);
        assert_eq!(mask.indices(), AllOr::All);
        assert_eq!(mask.slices(), AllOr::All);
        assert_eq!(mask.boolean_buffer(), AllOr::All);
    }

    #[test]
    fn mask_all_false() {
        let mask = Mask::new_false(5);
        assert_eq!(mask.len(), 5);
        assert_eq!(mask.true_count(), 0);
        assert_eq!(mask.density(), 0.0);
        assert_eq!(mask.indices(), AllOr::None);
        assert_eq!(mask.slices(), AllOr::None);
        assert_eq!(mask.boolean_buffer(), AllOr::None);
    }

    #[test]
    fn mask_from() {
        let masks = [
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_slices(5, vec![(0, 1), (2, 4)]),
            Mask::from_buffer(BooleanBuffer::from_iter([true, false, true, true, false])),
        ];

        for mask in &masks {
            assert_eq!(mask.len(), 5);
            assert_eq!(mask.true_count(), 3);
            assert_eq!(mask.density(), 0.6);
            assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..]));
            assert_eq!(mask.slices(), AllOr::Some(&[(0, 1), (2, 4)][..]));
            assert_eq!(
                mask.boolean_buffer(),
                AllOr::Some(&BooleanBuffer::from_iter([true, false, true, true, false]))
            );
        }
    }

    // Excluding a mixed subset yields the complementary set of positions.
    #[test]
    fn mask_from_excluded_indices() {
        let mask = Mask::from_excluded_indices(5, [1, 4]);
        assert_eq!(mask.len(), 5);
        assert_eq!(mask.true_count(), 3);
        assert_eq!(mask.false_count(), 2);
        assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..]));
        assert_eq!(mask.slices(), AllOr::Some(&[(0, 1), (2, 4)][..]));
    }

    #[test]
    fn mask_first_and_value() {
        let mask = Mask::from_indices(6, vec![2, 5]);
        assert_eq!(mask.first(), Some(2));
        assert!(mask.value(2));
        assert!(!mask.value(3));
        assert!(mask.value(5));

        assert_eq!(Mask::new_true(3).first(), Some(0));
        assert_eq!(Mask::new_true(0).first(), None);
        assert_eq!(Mask::new_false(3).first(), None);
    }

    #[test]
    fn mask_slice() {
        let mask = Mask::from_iter([true, false, true, true, false]);

        let sliced = mask.slice(1, 3);
        assert_eq!(sliced.len(), 3);
        assert_eq!(sliced.true_count(), 2);
        assert_eq!(sliced.indices(), AllOr::Some(&[1, 2][..]));

        // A window covering only set bits collapses to the all-true variant.
        assert!(mask.slice(2, 2).all_true());
    }

    #[test]
    fn mask_valid_counts_for_indices() {
        let mask = Mask::from_iter([true, false, true, true, false]);
        assert_eq!(
            mask.valid_counts_for_indices(&[0, 1, 3, 5]).unwrap(),
            vec![0, 1, 2, 3]
        );
        assert!(mask.valid_counts_for_indices(&[6]).is_err());

        assert_eq!(
            Mask::new_true(4).valid_counts_for_indices(&[0, 2, 4]).unwrap(),
            vec![0, 2, 4]
        );
        assert_eq!(
            Mask::new_false(4).valid_counts_for_indices(&[1, 3]).unwrap(),
            vec![0, 0]
        );
    }

    #[test]
    fn limit_all_true_mask() {
        let all_true = Mask::new_true(4);
        let limited_mask = all_true.clone().limit(2);
        assert_eq!(all_true.len(), limited_mask.len());
        assert_eq!(limited_mask.true_count(), 2);
        assert_eq!(
            limited_mask.boolean_buffer(),
            AllOr::Some(&BooleanBuffer::from_iter([true, true, false, false]))
        );

        let limited_mask = all_true.clone().limit(5);
        assert_eq!(limited_mask, all_true);
    }

    #[test]
    fn limit_mask_values() {
        let original_mask = Mask::from_iter([true, true, false, true, false, true]);
        let limited_mask = original_mask.clone().limit(2);

        assert_eq!(
            limited_mask.boolean_buffer(),
            AllOr::Some(&BooleanBuffer::from_iter([
                true, true, false, false, false, false
            ]))
        );
        assert_eq!(limited_mask.true_count(), 2);

        let limited_mask = original_mask.limit(3);

        assert_eq!(
            limited_mask.boolean_buffer(),
            AllOr::Some(&BooleanBuffer::from_iter([
                true, true, false, true, false, false
            ]))
        );
        assert_eq!(limited_mask.true_count(), 3);

        let original_mask = Mask::from_iter([true, true, false, true, false, true]);
        let limited_mask = original_mask.clone().limit(100);

        assert_eq!(original_mask, limited_mask);
    }
}