vortex_mask/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! A mask is a set of sorted unique non-negative integers.
5#![deny(missing_docs)]
6mod bitops;
7mod eq;
8mod intersect_by_rank;
9mod iter_bools;
10
11#[cfg(test)]
12mod tests;
13
14use std::cmp::Ordering;
15use std::fmt::{Debug, Formatter};
16use std::sync::{Arc, OnceLock};
17
18use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
19use itertools::Itertools;
20use vortex_error::{VortexResult, vortex_err};
21
/// A three-way answer: everything is included, nothing is included, or an explicit payload
/// describes which values are included.
pub enum AllOr<T> {
    /// All values are included.
    All,
    /// No values are included.
    None,
    /// Some values are included.
    Some(T),
}

impl<T> AllOr<T> {
    /// Extracts the payload of [`AllOr::Some`], or builds a substitute for the other variants.
    ///
    /// `all_true` is invoked only for [`AllOr::All`] and `all_false` only for [`AllOr::None`];
    /// neither closure runs when a payload is present.
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            AllOr::All => all_true(),
            AllOr::None => all_false(),
            AllOr::Some(inner) => inner,
        }
    }
}

impl<T> AllOr<&T> {
    /// Turns an `AllOr<&T>` into an owned `AllOr<T>` by cloning the referenced payload, if any.
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            AllOr::Some(inner) => AllOr::Some(inner.clone()),
            AllOr::All => AllOr::All,
            AllOr::None => AllOr::None,
        }
    }
}

impl<T: Debug> Debug for AllOr<T> {
    /// Writes the bare variant name, with the payload in tuple form for [`AllOr::Some`].
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            AllOr::Some(inner) => f.debug_tuple("Some").field(inner).finish(),
            AllOr::All => f.write_str("All"),
            AllOr::None => f.write_str("None"),
        }
    }
}

impl<T: PartialEq> PartialEq for AllOr<T> {
    /// Two values are equal when they are the same variant and, for [`AllOr::Some`], their
    /// payloads compare equal.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (AllOr::Some(lhs), AllOr::Some(rhs)) => lhs == rhs,
            (AllOr::All, AllOr::All) | (AllOr::None, AllOr::None) => true,
            _ => false,
        }
    }
}

impl<T: Eq> Eq for AllOr<T> {}
89
90/// Represents a set of sorted unique positive integers.
91///
92/// A [`Mask`] can be constructed from various representations, and converted to various
93/// others. Internally, these are cached.
94#[derive(Clone, Debug)]
95pub enum Mask {
96    /// All values are included.
97    AllTrue(usize),
98    /// No values are included.
99    AllFalse(usize),
100    /// Some values are included, represented as a [`BooleanBuffer`].
101    Values(Arc<MaskValues>),
102}
103
104/// Represents the values of a [`Mask`] that contains some true and some false elements.
105#[derive(Debug)]
106pub struct MaskValues {
107    buffer: BooleanBuffer,
108
109    // We cached the indices and slices representations, since it can be faster than iterating
110    // the bit-mask over and over again.
111    indices: OnceLock<Vec<usize>>,
112    slices: OnceLock<Vec<(usize, usize)>>,
113
114    // Pre-computed values.
115    true_count: usize,
116    // i.e., the fraction of values that are true
117    density: f64,
118}
119
120impl MaskValues {
121    /// Returns the length of the mask.
122    #[inline]
123    pub fn len(&self) -> usize {
124        self.buffer.len()
125    }
126
127    /// Returns true if the mask is empty i.e., it's length is 0.
128    #[inline]
129    pub fn is_empty(&self) -> bool {
130        self.buffer.is_empty()
131    }
132
133    /// Returns the true count of the mask.
134    pub fn true_count(&self) -> usize {
135        self.true_count
136    }
137
138    /// Returns the boolean buffer representation of the mask.
139    pub fn boolean_buffer(&self) -> &BooleanBuffer {
140        &self.buffer
141    }
142
143    /// Returns the boolean value at a given index.
144    pub fn value(&self, index: usize) -> bool {
145        self.buffer.value(index)
146    }
147
148    /// Constructs an indices vector from one of the other representations.
149    pub fn indices(&self) -> &[usize] {
150        self.indices.get_or_init(|| {
151            if self.true_count == 0 {
152                return vec![];
153            }
154
155            if self.true_count == self.len() {
156                return (0..self.len()).collect();
157            }
158
159            if let Some(slices) = self.slices.get() {
160                let mut indices = Vec::with_capacity(self.true_count);
161                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
162                debug_assert!(indices.is_sorted());
163                assert_eq!(indices.len(), self.true_count);
164                return indices;
165            }
166
167            let mut indices = Vec::with_capacity(self.true_count);
168            indices.extend(self.buffer.set_indices());
169            debug_assert!(indices.is_sorted());
170            assert_eq!(indices.len(), self.true_count);
171            indices
172        })
173    }
174
175    /// Constructs a slices vector from one of the other representations.
176    #[allow(clippy::cast_possible_truncation)]
177    pub fn slices(&self) -> &[(usize, usize)] {
178        self.slices.get_or_init(|| {
179            if self.true_count == self.len() {
180                return vec![(0, self.len())];
181            }
182
183            self.buffer.set_slices().collect()
184        })
185    }
186
187    /// Return an iterator over either indices or slices of the mask based on a density threshold.
188    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
189        if self.density >= threshold {
190            MaskIter::Slices(self.slices())
191        } else {
192            MaskIter::Indices(self.indices())
193        }
194    }
195}
196
197impl Mask {
198    /// Create a new Mask where all values are set.
199    pub fn new_true(length: usize) -> Self {
200        Self::AllTrue(length)
201    }
202
203    /// Create a new Mask where no values are set.
204    pub fn new_false(length: usize) -> Self {
205        Self::AllFalse(length)
206    }
207
208    /// Create a new [`Mask`] from a [`BooleanBuffer`].
209    pub fn from_buffer(buffer: BooleanBuffer) -> Self {
210        let len = buffer.len();
211        let true_count = buffer.count_set_bits();
212
213        if true_count == 0 {
214            return Self::AllFalse(len);
215        }
216        if true_count == len {
217            return Self::AllTrue(len);
218        }
219
220        Self::Values(Arc::new(MaskValues {
221            buffer,
222            indices: Default::default(),
223            slices: Default::default(),
224            true_count,
225            density: true_count as f64 / len as f64,
226        }))
227    }
228
229    /// Create a new [`Mask`] from a [`Vec<usize>`].
230    // TODO(ngates): this should take an IntoIterator<usize>.
231    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
232        let true_count = indices.len();
233        assert!(indices.is_sorted(), "Mask indices must be sorted");
234        assert!(
235            indices.last().is_none_or(|&idx| idx < len),
236            "Mask indices must be in bounds (len={len})"
237        );
238
239        if true_count == 0 {
240            return Self::AllFalse(len);
241        }
242        if true_count == len {
243            return Self::AllTrue(len);
244        }
245
246        let mut buf = BooleanBufferBuilder::new(len);
247        // TODO(ngates): for dense indices, we can do better by collecting into u64s.
248        buf.append_n(len, false);
249        indices.iter().for_each(|idx| buf.set_bit(*idx, true));
250        debug_assert_eq!(buf.len(), len);
251
252        Self::Values(Arc::new(MaskValues {
253            buffer: buf.finish(),
254            indices: OnceLock::from(indices),
255            slices: Default::default(),
256            true_count,
257            density: true_count as f64 / len as f64,
258        }))
259    }
260
261    /// Create a new [`Mask`] from an [`IntoIterator<Item = usize>`] of indices to be excluded.
262    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
263        let mut buf = BooleanBufferBuilder::new(len);
264        buf.append_n(len, true);
265
266        let mut false_count: usize = 0;
267        indices.into_iter().for_each(|idx| {
268            buf.set_bit(idx, false);
269            false_count += 1;
270        });
271        debug_assert_eq!(buf.len(), len);
272        let true_count = len - false_count;
273
274        // Return optimized variants when appropriate
275        if false_count == 0 {
276            return Self::AllTrue(len);
277        }
278        if false_count == len {
279            return Self::AllFalse(len);
280        }
281
282        Self::Values(Arc::new(MaskValues {
283            buffer: buf.finish(),
284            indices: Default::default(),
285            slices: Default::default(),
286            true_count,
287            density: true_count as f64 / len as f64,
288        }))
289    }
290
291    /// Create a new [`Mask`] from a [`Vec<(usize, usize)>`] where each range
292    /// represents a contiguous range of true values.
293    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
294        Self::check_slices(len, &vec);
295        Self::from_slices_unchecked(len, vec)
296    }
297
298    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
299        #[cfg(debug_assertions)]
300        Self::check_slices(len, &slices);
301
302        let true_count = slices.iter().map(|(b, e)| e - b).sum();
303        if true_count == 0 {
304            return Self::AllFalse(len);
305        }
306        if true_count == len {
307            return Self::AllTrue(len);
308        }
309
310        let mut buf = BooleanBufferBuilder::new(len);
311        for (start, end) in slices.iter().copied() {
312            buf.append_n(start - buf.len(), false);
313            buf.append_n(end - start, true);
314        }
315        if let Some((_, end)) = slices.last() {
316            buf.append_n(len - end, false);
317        }
318        debug_assert_eq!(buf.len(), len);
319
320        Self::Values(Arc::new(MaskValues {
321            buffer: buf.finish(),
322            indices: Default::default(),
323            slices: OnceLock::from(slices),
324            true_count,
325            density: true_count as f64 / len as f64,
326        }))
327    }
328
329    #[inline(always)]
330    fn check_slices(len: usize, vec: &[(usize, usize)]) {
331        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
332        for (first, second) in vec.iter().tuple_windows() {
333            assert!(
334                first.0 < second.0,
335                "Slices must be sorted, got {first:?} and {second:?}"
336            );
337            assert!(
338                first.1 <= second.0,
339                "Slices must be non-overlapping, got {first:?} and {second:?}"
340            );
341        }
342    }
343
344    /// Create a new [`Mask`] from the intersection of two indices slices.
345    pub fn from_intersection_indices(
346        len: usize,
347        lhs: impl Iterator<Item = usize>,
348        rhs: impl Iterator<Item = usize>,
349    ) -> Self {
350        let mut intersection = Vec::with_capacity(len);
351        let mut lhs = lhs.peekable();
352        let mut rhs = rhs.peekable();
353        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
354            match l.cmp(&r) {
355                Ordering::Less => {
356                    lhs.next();
357                }
358                Ordering::Greater => {
359                    rhs.next();
360                }
361                Ordering::Equal => {
362                    intersection.push(l);
363                    lhs.next();
364                    rhs.next();
365                }
366            }
367        }
368        Self::from_indices(len, intersection)
369    }
370
371    /// Returns the length of the mask (not the number of true values).
372    #[inline]
373    pub fn len(&self) -> usize {
374        match self {
375            Self::AllTrue(len) => *len,
376            Self::AllFalse(len) => *len,
377            Self::Values(values) => values.len(),
378        }
379    }
380
381    /// Returns true if the mask is empty i.e., it's length is 0.
382    #[inline]
383    pub fn is_empty(&self) -> bool {
384        match self {
385            Self::AllTrue(len) => *len == 0,
386            Self::AllFalse(len) => *len == 0,
387            Self::Values(values) => values.is_empty(),
388        }
389    }
390
391    /// Get the true count of the mask.
392    #[inline]
393    pub fn true_count(&self) -> usize {
394        match &self {
395            Self::AllTrue(len) => *len,
396            Self::AllFalse(_) => 0,
397            Self::Values(values) => values.true_count,
398        }
399    }
400
401    /// Get the false count of the mask.
402    #[inline]
403    pub fn false_count(&self) -> usize {
404        match &self {
405            Self::AllTrue(_) => 0,
406            Self::AllFalse(len) => *len,
407            Self::Values(values) => values.buffer.len() - values.true_count,
408        }
409    }
410
411    /// Returns true if all values in the mask are true.
412    #[inline]
413    pub fn all_true(&self) -> bool {
414        match &self {
415            Self::AllTrue(_) => true,
416            Self::AllFalse(0) => true,
417            Self::AllFalse(_) => false,
418            Self::Values(values) => values.buffer.len() == values.true_count,
419        }
420    }
421
422    /// Returns true if all values in the mask are false.
423    #[inline]
424    pub fn all_false(&self) -> bool {
425        self.true_count() == 0
426    }
427
428    /// Return the density of the full mask.
429    #[inline]
430    pub fn density(&self) -> f64 {
431        match &self {
432            Self::AllTrue(_) => 1.0,
433            Self::AllFalse(_) => 0.0,
434            Self::Values(values) => values.density,
435        }
436    }
437
438    /// Returns the boolean value at a given index.
439    ///
440    /// ## Panics
441    ///
442    /// Panics if the index is out of bounds.
443    pub fn value(&self, idx: usize) -> bool {
444        match self {
445            Mask::AllTrue(_) => true,
446            Mask::AllFalse(_) => false,
447            Mask::Values(values) => values.buffer.value(idx),
448        }
449    }
450
451    /// Returns the first true index in the mask.
452    pub fn first(&self) -> Option<usize> {
453        match &self {
454            Self::AllTrue(len) => (*len > 0).then_some(0),
455            Self::AllFalse(_) => None,
456            Self::Values(values) => {
457                if let Some(indices) = values.indices.get() {
458                    return indices.first().copied();
459                }
460                if let Some(slices) = values.slices.get() {
461                    return slices.first().map(|(start, _)| *start);
462                }
463                values.buffer.set_indices().next()
464            }
465        }
466    }
467
468    /// Slice the mask.
469    pub fn slice(&self, offset: usize, length: usize) -> Self {
470        assert!(offset + length <= self.len());
471        match &self {
472            Self::AllTrue(_) => Self::new_true(length),
473            Self::AllFalse(_) => Self::new_false(length),
474            Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)),
475        }
476    }
477
478    /// Return the boolean buffer representation of the mask.
479    pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> {
480        match &self {
481            Self::AllTrue(_) => AllOr::All,
482            Self::AllFalse(_) => AllOr::None,
483            Self::Values(values) => AllOr::Some(&values.buffer),
484        }
485    }
486
487    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
488    /// and all-false variants.
489    pub fn to_boolean_buffer(&self) -> BooleanBuffer {
490        match self {
491            Self::AllTrue(l) => BooleanBuffer::new_set(*l),
492            Self::AllFalse(l) => BooleanBuffer::new_unset(*l),
493            Self::Values(values) => values.boolean_buffer().clone(),
494        }
495    }
496
497    /// Returns an Arrow null buffer representation of the mask.
498    pub fn to_null_buffer(&self) -> Option<NullBuffer> {
499        match self {
500            Mask::AllTrue(_) => None,
501            Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)),
502            Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())),
503        }
504    }
505
506    /// Return the indices representation of the mask.
507    pub fn indices(&self) -> AllOr<&[usize]> {
508        match &self {
509            Self::AllTrue(_) => AllOr::All,
510            Self::AllFalse(_) => AllOr::None,
511            Self::Values(values) => AllOr::Some(values.indices()),
512        }
513    }
514
515    /// Return the slices representation of the mask.
516    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
517        match &self {
518            Self::AllTrue(_) => AllOr::All,
519            Self::AllFalse(_) => AllOr::None,
520            Self::Values(values) => AllOr::Some(values.slices()),
521        }
522    }
523
524    /// Return an iterator over either indices or slices of the mask based on a density threshold.
525    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
526        match &self {
527            Self::AllTrue(_) => AllOr::All,
528            Self::AllFalse(_) => AllOr::None,
529            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
530        }
531    }
532
533    /// Return [`MaskValues`] if the mask is not all true or all false.
534    pub fn values(&self) -> Option<&MaskValues> {
535        match self {
536            Self::Values(values) => Some(values),
537            _ => None,
538        }
539    }
540
541    /// Given monotonically increasing `indices` in [0, n_rows], returns the
542    /// count of valid elements up to each index.
543    ///
544    /// This is O(n_rows).
545    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> VortexResult<Vec<usize>> {
546        Ok(match self {
547            Self::AllTrue(_) => indices.to_vec(),
548            Self::AllFalse(_) => vec![0; indices.len()],
549            Self::Values(values) => {
550                let mut bool_iter = values.boolean_buffer().iter();
551                let mut valid_counts = Vec::with_capacity(indices.len());
552                let mut valid_count = 0;
553                let mut idx = 0;
554                for &next_idx in indices {
555                    while idx < next_idx {
556                        idx += 1;
557                        valid_count += bool_iter
558                            .next()
559                            .ok_or_else(|| vortex_err!("Row indices exceed array length"))?
560                            as usize;
561                    }
562                    valid_counts.push(valid_count);
563                }
564
565                valid_counts
566            }
567        })
568    }
569
570    /// Limit the mask to the first `limit` true values
571    pub fn limit(self, limit: usize) -> Self {
572        // Early return optimization: if we're asking for more true values than the total
573        // length of the mask, then even if all values were true, we couldn't exceed the
574        // limit, so return the original mask unchanged.
575        if self.len() <= limit {
576            return self;
577        }
578
579        match self {
580            Mask::AllTrue(len) => {
581                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
582            }
583            Mask::AllFalse(_) => self,
584            Mask::Values(ref mask_values) => {
585                if limit >= mask_values.true_count() {
586                    return self;
587                }
588
589                let existing_buffer = mask_values.boolean_buffer();
590
591                let mut new_buffer_builder = BooleanBufferBuilder::new(mask_values.len());
592                new_buffer_builder.append_n(mask_values.len(), false);
593
594                for index in existing_buffer.set_indices().take(limit) {
595                    new_buffer_builder.set_bit(index, true);
596                }
597
598                Self::from(new_buffer_builder.finish())
599            }
600        }
601    }
602}
603
/// A borrowed view over one of the two cached representations of a mask: its true indices
/// or its runs of true values.
pub enum MaskIter<'a> {
    /// The cached true indices of a mask.
    Indices(&'a [usize]),
    /// The cached `(start, end)` runs of true values of a mask.
    Slices(&'a [(usize, usize)]),
}
611
612impl From<BooleanBuffer> for Mask {
613    fn from(value: BooleanBuffer) -> Self {
614        Self::from_buffer(value)
615    }
616}
617
618impl FromIterator<bool> for Mask {
619    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
620        Self::from_buffer(BooleanBuffer::from_iter(iter))
621    }
622}
623
624impl FromIterator<Mask> for Mask {
625    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
626        let masks = iter
627            .into_iter()
628            .filter(|m| !m.is_empty())
629            .collect::<Vec<_>>();
630        let total_length = masks.iter().map(|v| v.len()).sum();
631
632        // If they're all valid, then return a single validity.
633        if masks.iter().all(|v| v.all_true()) {
634            return Self::AllTrue(total_length);
635        }
636        // If they're all invalid, then return a single invalidity.
637        if masks.iter().all(|v| v.all_false()) {
638            return Self::AllFalse(total_length);
639        }
640
641        // Else, construct the boolean buffer
642        let mut buffer = BooleanBufferBuilder::new(total_length);
643        for mask in masks {
644            match mask {
645                Mask::AllTrue(count) => buffer.append_n(count, true),
646                Mask::AllFalse(count) => buffer.append_n(count, false),
647                Mask::Values(values) => {
648                    buffer.append_buffer(values.boolean_buffer());
649                }
650            };
651        }
652        Self::from_buffer(buffer.finish())
653    }
654}