vortex_mask/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! A mask is a set of sorted, unique, non-negative integers (the selected positions).
5#![deny(missing_docs)]
6mod bitops;
7mod eq;
8mod intersect_by_rank;
9mod iter_bools;
10
11#[cfg(test)]
12mod tests;
13
14use std::cmp::Ordering;
15use std::fmt::{Debug, Formatter};
16use std::ops::Range;
17use std::sync::{Arc, OnceLock};
18
19use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
20use itertools::Itertools;
21use vortex_error::{VortexResult, vortex_panic};
22
/// A three-state selection: every value, no value, or an explicit subset described by `T`.
pub enum AllOr<T> {
    /// Every value is included.
    All,
    /// No value is included.
    None,
    /// Only some values are included; the payload describes which ones.
    Some(T),
}
32
33impl<T> AllOr<T> {
34    /// Returns the `Some` variant of the enum, or a default value.
35    #[inline]
36    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
37    where
38        F: FnOnce() -> T,
39        G: FnOnce() -> T,
40    {
41        match self {
42            Self::Some(v) => v,
43            AllOr::All => all_true(),
44            AllOr::None => all_false(),
45        }
46    }
47}
48
49impl<T> AllOr<&T> {
50    /// Clone the inner value.
51    #[inline]
52    pub fn cloned(self) -> AllOr<T>
53    where
54        T: Clone,
55    {
56        match self {
57            Self::All => AllOr::All,
58            Self::None => AllOr::None,
59            Self::Some(v) => AllOr::Some(v.clone()),
60        }
61    }
62}
63
64impl<T> Debug for AllOr<T>
65where
66    T: Debug,
67{
68    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
69        match self {
70            Self::All => f.write_str("All"),
71            Self::None => f.write_str("None"),
72            Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
73        }
74    }
75}
76
77impl<T> PartialEq for AllOr<T>
78where
79    T: PartialEq,
80{
81    fn eq(&self, other: &Self) -> bool {
82        match (self, other) {
83            (Self::All, Self::All) => true,
84            (Self::None, Self::None) => true,
85            (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
86            _ => false,
87        }
88    }
89}
90
91impl<T> Eq for AllOr<T> where T: Eq {}
92
93/// Represents a set of sorted unique positive integers.
94///
95/// A [`Mask`] can be constructed from various representations, and converted to various
96/// others. Internally, these are cached.
97#[derive(Clone, Debug)]
98pub enum Mask {
99    /// All values are included.
100    AllTrue(usize),
101    /// No values are included.
102    AllFalse(usize),
103    /// Some values are included, represented as a [`BooleanBuffer`].
104    Values(Arc<MaskValues>),
105}
106
107/// Represents the values of a [`Mask`] that contains some true and some false elements.
108#[derive(Debug)]
109pub struct MaskValues {
110    buffer: BooleanBuffer,
111
112    // We cached the indices and slices representations, since it can be faster than iterating
113    // the bit-mask over and over again.
114    indices: OnceLock<Vec<usize>>,
115    slices: OnceLock<Vec<(usize, usize)>>,
116
117    // Pre-computed values.
118    true_count: usize,
119    // i.e., the fraction of values that are true
120    density: f64,
121}
122
123impl MaskValues {
124    /// Returns the length of the mask.
125    #[inline]
126    pub fn len(&self) -> usize {
127        self.buffer.len()
128    }
129
130    /// Returns true if the mask is empty i.e., it's length is 0.
131    #[inline]
132    pub fn is_empty(&self) -> bool {
133        self.buffer.is_empty()
134    }
135
136    /// Returns the true count of the mask.
137    #[inline]
138    pub fn true_count(&self) -> usize {
139        self.true_count
140    }
141
142    /// Returns the boolean buffer representation of the mask.
143    #[inline]
144    pub fn boolean_buffer(&self) -> &BooleanBuffer {
145        &self.buffer
146    }
147
148    /// Returns the boolean value at a given index.
149    #[inline]
150    pub fn value(&self, index: usize) -> bool {
151        self.buffer.value(index)
152    }
153
154    /// Constructs an indices vector from one of the other representations.
155    pub fn indices(&self) -> &[usize] {
156        self.indices.get_or_init(|| {
157            if self.true_count == 0 {
158                return vec![];
159            }
160
161            if self.true_count == self.len() {
162                return (0..self.len()).collect();
163            }
164
165            if let Some(slices) = self.slices.get() {
166                let mut indices = Vec::with_capacity(self.true_count);
167                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
168                debug_assert!(indices.is_sorted());
169                assert_eq!(indices.len(), self.true_count);
170                return indices;
171            }
172
173            let mut indices = Vec::with_capacity(self.true_count);
174            indices.extend(self.buffer.set_indices());
175            debug_assert!(indices.is_sorted());
176            assert_eq!(indices.len(), self.true_count);
177            indices
178        })
179    }
180
181    /// Constructs a slices vector from one of the other representations.
182    #[allow(clippy::cast_possible_truncation)]
183    #[inline]
184    pub fn slices(&self) -> &[(usize, usize)] {
185        self.slices.get_or_init(|| {
186            if self.true_count == self.len() {
187                return vec![(0, self.len())];
188            }
189
190            self.buffer.set_slices().collect()
191        })
192    }
193
194    /// Return an iterator over either indices or slices of the mask based on a density threshold.
195    #[inline]
196    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
197        if self.density >= threshold {
198            MaskIter::Slices(self.slices())
199        } else {
200            MaskIter::Indices(self.indices())
201        }
202    }
203}
204
205impl Mask {
206    /// Create a new Mask where all values are set.
207    #[inline]
208    pub fn new_true(length: usize) -> Self {
209        Self::AllTrue(length)
210    }
211
212    /// Create a new Mask where no values are set.
213    #[inline]
214    pub fn new_false(length: usize) -> Self {
215        Self::AllFalse(length)
216    }
217
218    /// Create a new [`Mask`] from a [`BooleanBuffer`].
219    pub fn from_buffer(buffer: BooleanBuffer) -> Self {
220        let len = buffer.len();
221        let true_count = buffer.count_set_bits();
222
223        if true_count == 0 {
224            return Self::AllFalse(len);
225        }
226        if true_count == len {
227            return Self::AllTrue(len);
228        }
229
230        Self::Values(Arc::new(MaskValues {
231            buffer,
232            indices: Default::default(),
233            slices: Default::default(),
234            true_count,
235            density: true_count as f64 / len as f64,
236        }))
237    }
238
239    /// Create a new [`Mask`] from a [`Vec<usize>`].
240    // TODO(ngates): this should take an IntoIterator<usize>.
241    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
242        let true_count = indices.len();
243        assert!(indices.is_sorted(), "Mask indices must be sorted");
244        assert!(
245            indices.last().is_none_or(|&idx| idx < len),
246            "Mask indices must be in bounds (len={len})"
247        );
248
249        if true_count == 0 {
250            return Self::AllFalse(len);
251        }
252        if true_count == len {
253            return Self::AllTrue(len);
254        }
255
256        let mut buf = BooleanBufferBuilder::new(len);
257        // TODO(ngates): for dense indices, we can do better by collecting into u64s.
258        buf.append_n(len, false);
259        indices.iter().for_each(|idx| buf.set_bit(*idx, true));
260        debug_assert_eq!(buf.len(), len);
261
262        Self::Values(Arc::new(MaskValues {
263            buffer: buf.finish(),
264            indices: OnceLock::from(indices),
265            slices: Default::default(),
266            true_count,
267            density: true_count as f64 / len as f64,
268        }))
269    }
270
271    /// Create a new [`Mask`] from an [`IntoIterator<Item = usize>`] of indices to be excluded.
272    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
273        let mut buf = BooleanBufferBuilder::new(len);
274        buf.append_n(len, true);
275
276        let mut false_count: usize = 0;
277        indices.into_iter().for_each(|idx| {
278            buf.set_bit(idx, false);
279            false_count += 1;
280        });
281        debug_assert_eq!(buf.len(), len);
282        let true_count = len - false_count;
283
284        // Return optimized variants when appropriate
285        if false_count == 0 {
286            return Self::AllTrue(len);
287        }
288        if false_count == len {
289            return Self::AllFalse(len);
290        }
291
292        Self::Values(Arc::new(MaskValues {
293            buffer: buf.finish(),
294            indices: Default::default(),
295            slices: Default::default(),
296            true_count,
297            density: true_count as f64 / len as f64,
298        }))
299    }
300
301    /// Create a new [`Mask`] from a [`Vec<(usize, usize)>`] where each range
302    /// represents a contiguous range of true values.
303    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
304        Self::check_slices(len, &vec);
305        Self::from_slices_unchecked(len, vec)
306    }
307
308    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
309        #[cfg(debug_assertions)]
310        Self::check_slices(len, &slices);
311
312        let true_count = slices.iter().map(|(b, e)| e - b).sum();
313        if true_count == 0 {
314            return Self::AllFalse(len);
315        }
316        if true_count == len {
317            return Self::AllTrue(len);
318        }
319
320        let mut buf = BooleanBufferBuilder::new(len);
321        for (start, end) in slices.iter().copied() {
322            buf.append_n(start - buf.len(), false);
323            buf.append_n(end - start, true);
324        }
325        if let Some((_, end)) = slices.last() {
326            buf.append_n(len - end, false);
327        }
328        debug_assert_eq!(buf.len(), len);
329
330        Self::Values(Arc::new(MaskValues {
331            buffer: buf.finish(),
332            indices: Default::default(),
333            slices: OnceLock::from(slices),
334            true_count,
335            density: true_count as f64 / len as f64,
336        }))
337    }
338
339    #[inline(always)]
340    fn check_slices(len: usize, vec: &[(usize, usize)]) {
341        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
342        for (first, second) in vec.iter().tuple_windows() {
343            assert!(
344                first.0 < second.0,
345                "Slices must be sorted, got {first:?} and {second:?}"
346            );
347            assert!(
348                first.1 <= second.0,
349                "Slices must be non-overlapping, got {first:?} and {second:?}"
350            );
351        }
352    }
353
354    /// Create a new [`Mask`] from the intersection of two indices slices.
355    pub fn from_intersection_indices(
356        len: usize,
357        lhs: impl Iterator<Item = usize>,
358        rhs: impl Iterator<Item = usize>,
359    ) -> Self {
360        let mut intersection = Vec::with_capacity(len);
361        let mut lhs = lhs.peekable();
362        let mut rhs = rhs.peekable();
363        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
364            match l.cmp(&r) {
365                Ordering::Less => {
366                    lhs.next();
367                }
368                Ordering::Greater => {
369                    rhs.next();
370                }
371                Ordering::Equal => {
372                    intersection.push(l);
373                    lhs.next();
374                    rhs.next();
375                }
376            }
377        }
378        Self::from_indices(len, intersection)
379    }
380
381    /// Returns the length of the mask (not the number of true values).
382    #[inline]
383    pub fn len(&self) -> usize {
384        match self {
385            Self::AllTrue(len) => *len,
386            Self::AllFalse(len) => *len,
387            Self::Values(values) => values.len(),
388        }
389    }
390
391    /// Returns true if the mask is empty i.e., it's length is 0.
392    #[inline]
393    pub fn is_empty(&self) -> bool {
394        match self {
395            Self::AllTrue(len) => *len == 0,
396            Self::AllFalse(len) => *len == 0,
397            Self::Values(values) => values.is_empty(),
398        }
399    }
400
401    /// Get the true count of the mask.
402    #[inline]
403    pub fn true_count(&self) -> usize {
404        match &self {
405            Self::AllTrue(len) => *len,
406            Self::AllFalse(_) => 0,
407            Self::Values(values) => values.true_count,
408        }
409    }
410
411    /// Get the false count of the mask.
412    #[inline]
413    pub fn false_count(&self) -> usize {
414        match &self {
415            Self::AllTrue(_) => 0,
416            Self::AllFalse(len) => *len,
417            Self::Values(values) => values.buffer.len() - values.true_count,
418        }
419    }
420
421    /// Returns true if all values in the mask are true.
422    #[inline]
423    pub fn all_true(&self) -> bool {
424        match &self {
425            Self::AllTrue(_) => true,
426            Self::AllFalse(0) => true,
427            Self::AllFalse(_) => false,
428            Self::Values(values) => values.buffer.len() == values.true_count,
429        }
430    }
431
432    /// Returns true if all values in the mask are false.
433    #[inline]
434    pub fn all_false(&self) -> bool {
435        self.true_count() == 0
436    }
437
438    /// Return the density of the full mask.
439    #[inline]
440    pub fn density(&self) -> f64 {
441        match &self {
442            Self::AllTrue(_) => 1.0,
443            Self::AllFalse(_) => 0.0,
444            Self::Values(values) => values.density,
445        }
446    }
447
448    /// Returns the boolean value at a given index.
449    ///
450    /// ## Panics
451    ///
452    /// Panics if the index is out of bounds.
453    #[inline]
454    pub fn value(&self, idx: usize) -> bool {
455        match self {
456            Mask::AllTrue(_) => true,
457            Mask::AllFalse(_) => false,
458            Mask::Values(values) => values.buffer.value(idx),
459        }
460    }
461
462    /// Returns the first true index in the mask.
463    pub fn first(&self) -> Option<usize> {
464        match &self {
465            Self::AllTrue(len) => (*len > 0).then_some(0),
466            Self::AllFalse(_) => None,
467            Self::Values(values) => {
468                if let Some(indices) = values.indices.get() {
469                    return indices.first().copied();
470                }
471                if let Some(slices) = values.slices.get() {
472                    return slices.first().map(|(start, _)| *start);
473                }
474                values.buffer.set_indices().next()
475            }
476        }
477    }
478
479    /// Slice the mask.
480    #[inline]
481    pub fn slice(&self, range: Range<usize>) -> Self {
482        assert!(range.end <= self.len());
483        match &self {
484            Self::AllTrue(_) => Self::new_true(range.len()),
485            Self::AllFalse(_) => Self::new_false(range.len()),
486            Self::Values(values) => {
487                Self::from_buffer(values.buffer.slice(range.start, range.len()))
488            }
489        }
490    }
491
492    /// Return the boolean buffer representation of the mask.
493    #[inline]
494    pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
495        match &self {
496            Self::AllTrue(_) => AllOr::All,
497            Self::AllFalse(_) => AllOr::None,
498            Self::Values(values) => AllOr::Some(&values.buffer),
499        }
500    }
501
502    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
503    /// and all-false variants.
504    #[inline]
505    pub fn to_boolean_buffer(&self) -> BooleanBuffer {
506        match self {
507            Self::AllTrue(l) => BooleanBuffer::new_set(*l),
508            Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
509            Self::Values(values) => values.boolean_buffer().clone(),
510        }
511    }
512
513    /// Returns an Arrow null buffer representation of the mask.
514    #[inline]
515    pub fn to_null_buffer(&self) -> Option<NullBuffer> {
516        match self {
517            Mask::AllTrue(_) => None,
518            Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
519            Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
520        }
521    }
522
523    /// Return the indices representation of the mask.
524    #[inline]
525    pub fn indices(&self) -> AllOr<&[usize]> {
526        match &self {
527            Self::AllTrue(_) => AllOr::All,
528            Self::AllFalse(_) => AllOr::None,
529            Self::Values(values) => AllOr::Some(values.indices()),
530        }
531    }
532
533    /// Return the slices representation of the mask.
534    #[inline]
535    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
536        match &self {
537            Self::AllTrue(_) => AllOr::All,
538            Self::AllFalse(_) => AllOr::None,
539            Self::Values(values) => AllOr::Some(values.slices()),
540        }
541    }
542
543    /// Return an iterator over either indices or slices of the mask based on a density threshold.
544    #[inline]
545    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
546        match &self {
547            Self::AllTrue(_) => AllOr::All,
548            Self::AllFalse(_) => AllOr::None,
549            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
550        }
551    }
552
553    /// Return [`MaskValues`] if the mask is not all true or all false.
554    #[inline]
555    pub fn values(&self) -> Option<&MaskValues> {
556        if let Self::Values(values) = self {
557            Some(values)
558        } else {
559            None
560        }
561    }
562
563    /// Given monotonically increasing `indices` in [0, n_rows], returns the
564    /// count of valid elements up to each index.
565    ///
566    /// This is O(n_rows).
567    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
568        match self {
569            Self::AllTrue(_) => indices.to_vec(),
570            Self::AllFalse(_) => vec![0; indices.len()],
571            Self::Values(values) => {
572                let mut bool_iter = values.boolean_buffer().iter();
573                let mut valid_counts = Vec::with_capacity(indices.len());
574                let mut valid_count = 0;
575                let mut idx = 0;
576                for &next_idx in indices {
577                    while idx < next_idx {
578                        idx += 1;
579                        valid_count += bool_iter
580                            .next()
581                            .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
582                            as usize;
583                    }
584                    valid_counts.push(valid_count);
585                }
586
587                valid_counts
588            }
589        }
590    }
591
592    /// Limit the mask to the first `limit` true values
593    pub fn limit(self, limit: usize) -> Self {
594        // Early return optimization: if we're asking for more true values than the total
595        // length of the mask, then even if all values were true, we couldn't exceed the
596        // limit, so return the original mask unchanged.
597        if self.len() <= limit {
598            return self;
599        }
600
601        match self {
602            Mask::AllTrue(len) => {
603                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
604            }
605            Mask::AllFalse(_) => self,
606            Mask::Values(ref mask_values) => {
607                if limit >= mask_values.true_count() {
608                    return self;
609                }
610
611                let existing_buffer = mask_values.boolean_buffer();
612
613                let mut new_buffer_builder = BooleanBufferBuilder::new(mask_values.len());
614                new_buffer_builder.append_n(mask_values.len(), false);
615
616                for index in existing_buffer.set_indices().take(limit) {
617                    new_buffer_builder.set_bit(index, true);
618                }
619
620                Self::from(new_buffer_builder.finish())
621            }
622        }
623    }
624
625    /// Concatenate multiple masks together into a single mask.
626    pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
627        let masks: Vec<_> = masks.collect();
628        let len = masks.iter().map(|t| t.len()).sum();
629
630        if masks.iter().all(|t| t.all_true()) {
631            return Ok(Mask::AllTrue(len));
632        }
633
634        if masks.iter().all(|t| t.all_false()) {
635            return Ok(Mask::AllFalse(len));
636        }
637
638        let mut builder = BooleanBufferBuilder::new(len);
639
640        for mask in masks {
641            match mask {
642                Mask::AllTrue(n) => builder.append_n(*n, true),
643                Mask::AllFalse(n) => builder.append_n(*n, false),
644                Mask::Values(v) => builder.append_buffer(v.boolean_buffer()),
645            }
646        }
647
648        Ok(Mask::from_buffer(builder.finish()))
649    }
650}
651
/// A borrowed view of a mask's true values, either as individual positions or as runs.
pub enum MaskIter<'a> {
    /// The cached positions of the true values.
    Indices(&'a [usize]),
    /// The cached `(start, end)` runs of true values.
    Slices(&'a [(usize, usize)]),
}
659
660impl From<BooleanBuffer> for Mask {
661    #[inline]
662    fn from(value: BooleanBuffer) -> Self {
663        Self::from_buffer(value)
664    }
665}
666
667impl FromIterator<bool> for Mask {
668    #[inline]
669    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
670        Self::from_buffer(BooleanBuffer::from_iter(iter))
671    }
672}
673
674impl FromIterator<Mask> for Mask {
675    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
676        let masks = iter
677            .into_iter()
678            .filter(|m| !m.is_empty())
679            .collect::<Vec<_>>();
680        let total_length = masks.iter().map(|v| v.len()).sum();
681
682        // If they're all valid, then return a single validity.
683        if masks.iter().all(|v| v.all_true()) {
684            return Self::AllTrue(total_length);
685        }
686        // If they're all invalid, then return a single invalidity.
687        if masks.iter().all(|v| v.all_false()) {
688            return Self::AllFalse(total_length);
689        }
690
691        // Else, construct the boolean buffer
692        let mut buffer = BooleanBufferBuilder::new(total_length);
693        for mask in masks {
694            match mask {
695                Mask::AllTrue(count) => buffer.append_n(count, true),
696                Mask::AllFalse(count) => buffer.append_n(count, false),
697                Mask::Values(values) => {
698                    buffer.append_buffer(values.boolean_buffer());
699                }
700            };
701        }
702        Self::from_buffer(buffer.finish())
703    }
704}