vortex_mask/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! A mask is a set of sorted, unique, non-negative integers: the indices of its true values.
5#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10mod iter_bools;
11mod mask_mut;
12
13#[cfg(test)]
14mod tests;
15
16use std::cmp::Ordering;
17use std::fmt::{Debug, Formatter};
18use std::ops::{Bound, RangeBounds};
19use std::sync::{Arc, OnceLock};
20
21use itertools::Itertools;
22pub use mask_mut::*;
23use vortex_buffer::{BitBuffer, BitBufferMut, set_bit_unchecked};
24use vortex_error::{VortexResult, vortex_panic};
25
/// A three-way selection: every value, no value, or an explicit subset described by `T`.
pub enum AllOr<T> {
    /// Every value is included.
    All,
    /// No value is included.
    None,
    /// Only some values are included, as described by the payload.
    Some(T),
}
35
36impl<T> AllOr<T> {
37    /// Returns the `Some` variant of the enum, or a default value.
38    #[inline]
39    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
40    where
41        F: FnOnce() -> T,
42        G: FnOnce() -> T,
43    {
44        match self {
45            Self::Some(v) => v,
46            AllOr::All => all_true(),
47            AllOr::None => all_false(),
48        }
49    }
50}
51
52impl<T> AllOr<&T> {
53    /// Clone the inner value.
54    #[inline]
55    pub fn cloned(self) -> AllOr<T>
56    where
57        T: Clone,
58    {
59        match self {
60            Self::All => AllOr::All,
61            Self::None => AllOr::None,
62            Self::Some(v) => AllOr::Some(v.clone()),
63        }
64    }
65}
66
67impl<T> Debug for AllOr<T>
68where
69    T: Debug,
70{
71    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
72        match self {
73            Self::All => f.write_str("All"),
74            Self::None => f.write_str("None"),
75            Self::Some(v) => f.debug_tuple("Some").field(v).finish(),
76        }
77    }
78}
79
80impl<T> PartialEq for AllOr<T>
81where
82    T: PartialEq,
83{
84    fn eq(&self, other: &Self) -> bool {
85        match (self, other) {
86            (Self::All, Self::All) => true,
87            (Self::None, Self::None) => true,
88            (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
89            _ => false,
90        }
91    }
92}
93
94impl<T> Eq for AllOr<T> where T: Eq {}
95
96/// Represents a set of sorted unique positive integers.
97///
98/// A [`Mask`] can be constructed from various representations, and converted to various
99/// others. Internally, these are cached.
100#[derive(Debug, Clone)]
101#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
102pub enum Mask {
103    /// All values are included.
104    AllTrue(usize),
105    /// No values are included.
106    AllFalse(usize),
107    /// Some values are included, represented as a [`BitBuffer`].
108    Values(Arc<MaskValues>),
109}
110
111/// Represents the values of a [`Mask`] that contains some true and some false elements.
112#[derive(Debug)]
113#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
114pub struct MaskValues {
115    buffer: BitBuffer,
116
117    // We cached the indices and slices representations, since it can be faster than iterating
118    // the bit-mask over and over again.
119    #[cfg_attr(feature = "serde", serde(skip))]
120    indices: OnceLock<Vec<usize>>,
121    #[cfg_attr(feature = "serde", serde(skip))]
122    slices: OnceLock<Vec<(usize, usize)>>,
123
124    // Pre-computed values.
125    true_count: usize,
126    // i.e., the fraction of values that are true
127    density: f64,
128}
129
130impl Mask {
131    /// Create a new Mask where all values are set.
132    #[inline]
133    pub fn new_true(length: usize) -> Self {
134        Self::AllTrue(length)
135    }
136
137    /// Create a new Mask where no values are set.
138    #[inline]
139    pub fn new_false(length: usize) -> Self {
140        Self::AllFalse(length)
141    }
142
143    /// Create a new [`Mask`] from a [`BitBuffer`].
144    pub fn from_buffer(buffer: BitBuffer) -> Self {
145        let len = buffer.len();
146        let true_count = buffer.true_count();
147
148        if true_count == 0 {
149            return Self::AllFalse(len);
150        }
151        if true_count == len {
152            return Self::AllTrue(len);
153        }
154
155        Self::Values(Arc::new(MaskValues {
156            buffer,
157            indices: Default::default(),
158            slices: Default::default(),
159            true_count,
160            density: true_count as f64 / len as f64,
161        }))
162    }
163
164    /// Create a new [`Mask`] from a [`Vec<usize>`].
165    // TODO(ngates): this should take an IntoIterator<usize>.
166    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
167        let true_count = indices.len();
168        assert!(indices.is_sorted(), "Mask indices must be sorted");
169        assert!(
170            indices.last().is_none_or(|&idx| idx < len),
171            "Mask indices must be in bounds (len={len})"
172        );
173
174        if true_count == 0 {
175            return Self::AllFalse(len);
176        }
177        if true_count == len {
178            return Self::AllTrue(len);
179        }
180
181        let mut buf = BitBufferMut::new_unset(len);
182        // TODO(ngates): for dense indices, we can do better by collecting into u64s.
183        indices.iter().for_each(|&idx| buf.set(idx));
184        debug_assert_eq!(buf.len(), len);
185
186        Self::Values(Arc::new(MaskValues {
187            buffer: buf.freeze(),
188            indices: OnceLock::from(indices),
189            slices: Default::default(),
190            true_count,
191            density: true_count as f64 / len as f64,
192        }))
193    }
194
195    /// Create a new [`Mask`] from an [`IntoIterator<Item = usize>`] of indices to be excluded.
196    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
197        let mut buf = BitBufferMut::new_set(len);
198
199        let mut false_count: usize = 0;
200        indices.into_iter().for_each(|idx| {
201            buf.unset(idx);
202            false_count += 1;
203        });
204        debug_assert_eq!(buf.len(), len);
205        let true_count = len - false_count;
206
207        // Return optimized variants when appropriate
208        if false_count == 0 {
209            return Self::AllTrue(len);
210        }
211        if false_count == len {
212            return Self::AllFalse(len);
213        }
214
215        Self::Values(Arc::new(MaskValues {
216            buffer: buf.freeze(),
217            indices: Default::default(),
218            slices: Default::default(),
219            true_count,
220            density: true_count as f64 / len as f64,
221        }))
222    }
223
224    /// Create a new [`Mask`] from a [`Vec<(usize, usize)>`] where each range
225    /// represents a contiguous range of true values.
226    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
227        Self::check_slices(len, &vec);
228        Self::from_slices_unchecked(len, vec)
229    }
230
231    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
232        #[cfg(debug_assertions)]
233        Self::check_slices(len, &slices);
234
235        let true_count = slices.iter().map(|(b, e)| e - b).sum();
236        if true_count == 0 {
237            return Self::AllFalse(len);
238        }
239        if true_count == len {
240            return Self::AllTrue(len);
241        }
242
243        let mut buf = BitBufferMut::new_unset(len);
244        for (start, end) in slices.iter().copied() {
245            (start..end).for_each(|idx| buf.set(idx));
246        }
247        debug_assert_eq!(buf.len(), len);
248
249        Self::Values(Arc::new(MaskValues {
250            buffer: buf.freeze(),
251            indices: Default::default(),
252            slices: OnceLock::from(slices),
253            true_count,
254            density: true_count as f64 / len as f64,
255        }))
256    }
257
258    #[inline(always)]
259    fn check_slices(len: usize, vec: &[(usize, usize)]) {
260        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
261        for (first, second) in vec.iter().tuple_windows() {
262            assert!(
263                first.0 < second.0,
264                "Slices must be sorted, got {first:?} and {second:?}"
265            );
266            assert!(
267                first.1 <= second.0,
268                "Slices must be non-overlapping, got {first:?} and {second:?}"
269            );
270        }
271    }
272
273    /// Create a new [`Mask`] from the intersection of two indices slices.
274    pub fn from_intersection_indices(
275        len: usize,
276        lhs: impl Iterator<Item = usize>,
277        rhs: impl Iterator<Item = usize>,
278    ) -> Self {
279        let mut intersection = Vec::with_capacity(len);
280        let mut lhs = lhs.peekable();
281        let mut rhs = rhs.peekable();
282        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
283            match l.cmp(&r) {
284                Ordering::Less => {
285                    lhs.next();
286                }
287                Ordering::Greater => {
288                    rhs.next();
289                }
290                Ordering::Equal => {
291                    intersection.push(l);
292                    lhs.next();
293                    rhs.next();
294                }
295            }
296        }
297        Self::from_indices(len, intersection)
298    }
299
300    /// Returns the length of the mask (not the number of true values).
301    #[inline]
302    pub fn len(&self) -> usize {
303        match self {
304            Self::AllTrue(len) => *len,
305            Self::AllFalse(len) => *len,
306            Self::Values(values) => values.len(),
307        }
308    }
309
310    /// Returns true if the mask is empty i.e., it's length is 0.
311    #[inline]
312    pub fn is_empty(&self) -> bool {
313        match self {
314            Self::AllTrue(len) => *len == 0,
315            Self::AllFalse(len) => *len == 0,
316            Self::Values(values) => values.is_empty(),
317        }
318    }
319
320    /// Get the true count of the mask.
321    #[inline]
322    pub fn true_count(&self) -> usize {
323        match &self {
324            Self::AllTrue(len) => *len,
325            Self::AllFalse(_) => 0,
326            Self::Values(values) => values.true_count,
327        }
328    }
329
330    /// Get the false count of the mask.
331    #[inline]
332    pub fn false_count(&self) -> usize {
333        match &self {
334            Self::AllTrue(_) => 0,
335            Self::AllFalse(len) => *len,
336            Self::Values(values) => values.buffer.len() - values.true_count,
337        }
338    }
339
340    /// Returns true if all values in the mask are true.
341    #[inline]
342    pub fn all_true(&self) -> bool {
343        match &self {
344            Self::AllTrue(_) => true,
345            Self::AllFalse(0) => true,
346            Self::AllFalse(_) => false,
347            Self::Values(values) => values.buffer.len() == values.true_count,
348        }
349    }
350
351    /// Returns true if all values in the mask are false.
352    #[inline]
353    pub fn all_false(&self) -> bool {
354        self.true_count() == 0
355    }
356
357    /// Return the density of the full mask.
358    #[inline]
359    pub fn density(&self) -> f64 {
360        match &self {
361            Self::AllTrue(_) => 1.0,
362            Self::AllFalse(_) => 0.0,
363            Self::Values(values) => values.density,
364        }
365    }
366
367    /// Returns the boolean value at a given index.
368    ///
369    /// ## Panics
370    ///
371    /// Panics if the index is out of bounds.
372    #[inline]
373    pub fn value(&self, idx: usize) -> bool {
374        match self {
375            Mask::AllTrue(_) => true,
376            Mask::AllFalse(_) => false,
377            Mask::Values(values) => values.buffer.value(idx),
378        }
379    }
380
381    /// Returns the first true index in the mask.
382    pub fn first(&self) -> Option<usize> {
383        match &self {
384            Self::AllTrue(len) => (*len > 0).then_some(0),
385            Self::AllFalse(_) => None,
386            Self::Values(values) => {
387                if let Some(indices) = values.indices.get() {
388                    return indices.first().copied();
389                }
390                if let Some(slices) = values.slices.get() {
391                    return slices.first().map(|(start, _)| *start);
392                }
393                values.buffer.set_indices().next()
394            }
395        }
396    }
397
398    /// Slice the mask.
399    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
400        let start = match range.start_bound() {
401            Bound::Included(&s) => s,
402            Bound::Excluded(&s) => s + 1,
403            Bound::Unbounded => 0,
404        };
405        let end = match range.end_bound() {
406            Bound::Included(&e) => e + 1,
407            Bound::Excluded(&e) => e,
408            Bound::Unbounded => self.len(),
409        };
410
411        assert!(start <= end);
412        assert!(start <= self.len());
413        assert!(end <= self.len());
414        let len = end - start;
415
416        match &self {
417            Self::AllTrue(_) => Self::new_true(len),
418            Self::AllFalse(_) => Self::new_false(len),
419            Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
420        }
421    }
422
423    /// Return the boolean buffer representation of the mask.
424    #[inline]
425    pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
426        match &self {
427            Self::AllTrue(_) => AllOr::All,
428            Self::AllFalse(_) => AllOr::None,
429            Self::Values(values) => AllOr::Some(&values.buffer),
430        }
431    }
432
433    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
434    /// and all-false variants.
435    #[inline]
436    pub fn to_bit_buffer(&self) -> BitBuffer {
437        match self {
438            Self::AllTrue(l) => BitBuffer::new_set(*l),
439            Self::AllFalse(l) => BitBuffer::new_unset(*l),
440            Self::Values(values) => values.bit_buffer().clone(),
441        }
442    }
443
444    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
445    /// and all-false variants.
446    #[inline]
447    pub fn into_bit_buffer(self) -> BitBuffer {
448        match self {
449            Self::AllTrue(l) => BitBuffer::new_set(l),
450            Self::AllFalse(l) => BitBuffer::new_unset(l),
451            Self::Values(values) => Arc::try_unwrap(values)
452                .map(|v| v.into_bit_buffer())
453                .unwrap_or_else(|v| v.bit_buffer().clone()),
454        }
455    }
456
457    /// Return the indices representation of the mask.
458    #[inline]
459    pub fn indices(&self) -> AllOr<&[usize]> {
460        match &self {
461            Self::AllTrue(_) => AllOr::All,
462            Self::AllFalse(_) => AllOr::None,
463            Self::Values(values) => AllOr::Some(values.indices()),
464        }
465    }
466
467    /// Return the slices representation of the mask.
468    #[inline]
469    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
470        match &self {
471            Self::AllTrue(_) => AllOr::All,
472            Self::AllFalse(_) => AllOr::None,
473            Self::Values(values) => AllOr::Some(values.slices()),
474        }
475    }
476
477    /// Return an iterator over either indices or slices of the mask based on a density threshold.
478    #[inline]
479    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
480        match &self {
481            Self::AllTrue(_) => AllOr::All,
482            Self::AllFalse(_) => AllOr::None,
483            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
484        }
485    }
486
487    /// Return [`MaskValues`] if the mask is not all true or all false.
488    #[inline]
489    pub fn values(&self) -> Option<&MaskValues> {
490        if let Self::Values(values) = self {
491            Some(values)
492        } else {
493            None
494        }
495    }
496
497    /// Given monotonically increasing `indices` in [0, n_rows], returns the
498    /// count of valid elements up to each index.
499    ///
500    /// This is O(n_rows).
501    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
502        match self {
503            Self::AllTrue(_) => indices.to_vec(),
504            Self::AllFalse(_) => vec![0; indices.len()],
505            Self::Values(values) => {
506                let mut bool_iter = values.bit_buffer().iter();
507                let mut valid_counts = Vec::with_capacity(indices.len());
508                let mut valid_count = 0;
509                let mut idx = 0;
510                for &next_idx in indices {
511                    while idx < next_idx {
512                        idx += 1;
513                        valid_count += bool_iter
514                            .next()
515                            .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
516                            as usize;
517                    }
518                    valid_counts.push(valid_count);
519                }
520
521                valid_counts
522            }
523        }
524    }
525
526    /// Limit the mask to the first `limit` true values
527    pub fn limit(self, limit: usize) -> Self {
528        // Early return optimization: if we're asking for more true values than the total
529        // length of the mask, then even if all values were true, we couldn't exceed the
530        // limit, so return the original mask unchanged.
531        if self.len() <= limit {
532            return self;
533        }
534
535        match self {
536            Mask::AllTrue(len) => {
537                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
538            }
539            Mask::AllFalse(_) => self,
540            Mask::Values(ref mask_values) => {
541                if limit >= mask_values.true_count() {
542                    return self;
543                }
544
545                let existing_buffer = mask_values.bit_buffer();
546
547                let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
548                debug_assert!(limit < mask_values.len());
549
550                let ptr = new_buffer_builder.as_mut_ptr();
551                for index in existing_buffer.set_indices().take(limit) {
552                    // SAFETY: We checked that `limit` was less than the mask values length,
553                    // therefore `index` must be within the bounds of the bit buffer.
554                    unsafe { set_bit_unchecked(ptr, index) }
555                }
556
557                Self::from(new_buffer_builder.freeze())
558            }
559        }
560    }
561
562    /// Concatenate multiple masks together into a single mask.
563    pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
564        let masks: Vec<_> = masks.collect();
565        let len = masks.iter().map(|t| t.len()).sum();
566
567        if masks.iter().all(|t| t.all_true()) {
568            return Ok(Mask::AllTrue(len));
569        }
570
571        if masks.iter().all(|t| t.all_false()) {
572            return Ok(Mask::AllFalse(len));
573        }
574
575        let mut builder = BitBufferMut::with_capacity(len);
576
577        for mask in masks {
578            match mask {
579                Mask::AllTrue(n) => builder.append_n(true, *n),
580                Mask::AllFalse(n) => builder.append_n(false, *n),
581                Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
582            }
583        }
584
585        Ok(Mask::from_buffer(builder.freeze()))
586    }
587}
588
589impl MaskValues {
590    /// Returns the length of the mask.
591    #[inline]
592    pub fn len(&self) -> usize {
593        self.buffer.len()
594    }
595
596    /// Returns true if the mask is empty i.e., it's length is 0.
597    #[inline]
598    pub fn is_empty(&self) -> bool {
599        self.buffer.is_empty()
600    }
601
602    /// Returns the true count of the mask.
603    #[inline]
604    pub fn true_count(&self) -> usize {
605        self.true_count
606    }
607
608    /// Returns the boolean buffer representation of the mask.
609    #[inline]
610    pub fn bit_buffer(&self) -> &BitBuffer {
611        &self.buffer
612    }
613
614    /// Returns the boolean buffer representation of the mask.
615    #[inline]
616    pub fn into_bit_buffer(self) -> BitBuffer {
617        self.buffer
618    }
619
620    /// Returns the boolean value at a given index.
621    #[inline]
622    pub fn value(&self, index: usize) -> bool {
623        self.buffer.value(index)
624    }
625
626    /// Constructs an indices vector from one of the other representations.
627    pub fn indices(&self) -> &[usize] {
628        self.indices.get_or_init(|| {
629            if self.true_count == 0 {
630                return vec![];
631            }
632
633            if self.true_count == self.len() {
634                return (0..self.len()).collect();
635            }
636
637            if let Some(slices) = self.slices.get() {
638                let mut indices = Vec::with_capacity(self.true_count);
639                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
640                debug_assert!(indices.is_sorted());
641                assert_eq!(indices.len(), self.true_count);
642                return indices;
643            }
644
645            let mut indices = Vec::with_capacity(self.true_count);
646            indices.extend(self.buffer.set_indices());
647            debug_assert!(indices.is_sorted());
648            assert_eq!(indices.len(), self.true_count);
649            indices
650        })
651    }
652
653    /// Constructs a slices vector from one of the other representations.
654    #[allow(clippy::cast_possible_truncation)]
655    #[inline]
656    pub fn slices(&self) -> &[(usize, usize)] {
657        self.slices.get_or_init(|| {
658            if self.true_count == self.len() {
659                return vec![(0, self.len())];
660            }
661
662            self.buffer.set_slices().collect()
663        })
664    }
665
666    /// Return an iterator over either indices or slices of the mask based on a density threshold.
667    #[inline]
668    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
669        if self.density >= threshold {
670            MaskIter::Slices(self.slices())
671        } else {
672            MaskIter::Indices(self.indices())
673        }
674    }
675
676    /// Extracts the internal [`BitBuffer`].
677    pub(crate) fn into_buffer(self) -> BitBuffer {
678        self.buffer
679    }
680}
681
/// The two borrowed shapes in which the true positions of a mask can be walked.
pub enum MaskIter<'a> {
    /// The cached indices of the true positions.
    Indices(&'a [usize]),
    /// The cached `(start, end)` runs of true positions.
    Slices(&'a [(usize, usize)]),
}
689
690impl From<BitBuffer> for Mask {
691    fn from(value: BitBuffer) -> Self {
692        Self::from_buffer(value)
693    }
694}
695
696impl FromIterator<bool> for Mask {
697    #[inline]
698    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
699        Self::from_buffer(BitBuffer::from_iter(iter))
700    }
701}
702
703impl FromIterator<Mask> for Mask {
704    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
705        let masks = iter
706            .into_iter()
707            .filter(|m| !m.is_empty())
708            .collect::<Vec<_>>();
709        let total_length = masks.iter().map(|v| v.len()).sum();
710
711        // If they're all valid, then return a single validity.
712        if masks.iter().all(|v| v.all_true()) {
713            return Self::AllTrue(total_length);
714        }
715        // If they're all invalid, then return a single invalidity.
716        if masks.iter().all(|v| v.all_false()) {
717            return Self::AllFalse(total_length);
718        }
719
720        // Else, construct the boolean buffer
721        let mut buffer = BitBufferMut::with_capacity(total_length);
722        for mask in masks {
723            match mask {
724                Mask::AllTrue(count) => buffer.append_n(true, count),
725                Mask::AllFalse(count) => buffer.append_n(false, count),
726                Mask::Values(values) => {
727                    buffer.append_buffer(values.bit_buffer());
728                }
729            };
730        }
731        Self::from_buffer(buffer.freeze())
732    }
733}