Skip to main content

vortex_mask/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! A mask is a set of sorted unique non-negative integers.
5#![deny(missing_docs)]
6
7mod bitops;
8mod eq;
9mod intersect_by_rank;
10
11#[cfg(test)]
12mod tests;
13
14use std::cmp::Ordering;
15use std::fmt::Debug;
16use std::fmt::Formatter;
17use std::ops::Bound;
18use std::ops::RangeBounds;
19use std::sync::Arc;
20use std::sync::OnceLock;
21
22use itertools::Itertools;
23use vortex_buffer::BitBuffer;
24use vortex_buffer::BitBufferMut;
25use vortex_buffer::BitIterator;
26use vortex_error::VortexResult;
27use vortex_error::vortex_panic;
28
/// A three-way selection: everything, nothing, or an explicit payload describing a mixture.
pub enum AllOr<T> {
    /// Every value is included.
    All,
    /// No value is included.
    None,
    /// Only some values are included, as described by the payload.
    Some(T),
}

impl<T> AllOr<T> {
    /// Extracts the `Some` payload, or computes a substitute for the other two variants.
    ///
    /// `all_true` is invoked only for [`AllOr::All`] and `all_false` only for [`AllOr::None`].
    #[inline]
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            Self::All => all_true(),
            Self::None => all_false(),
            Self::Some(inner) => inner,
        }
    }
}

impl<T> AllOr<&T> {
    /// Turns an `AllOr<&T>` into an owned `AllOr<T>` by cloning the payload, if there is one.
    #[inline]
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            Self::Some(inner) => AllOr::Some(inner.clone()),
            Self::All => AllOr::All,
            Self::None => AllOr::None,
        }
    }
}

impl<T: Debug> Debug for AllOr<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Some(inner) => f.debug_tuple("Some").field(inner).finish(),
            Self::All => f.write_str("All"),
            Self::None => f.write_str("None"),
        }
    }
}

impl<T: PartialEq> PartialEq for AllOr<T> {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            // Payloads are compared structurally.
            (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs,
            // The two payload-free variants are only equal to themselves.
            (Self::All, Self::All) | (Self::None, Self::None) => true,
            _ => false,
        }
    }
}

impl<T: Eq> Eq for AllOr<T> {}
98
/// Represents a set of sorted, unique, non-negative integers (positions) within a fixed length.
/// If a position is included in a Mask, it's valid.
///
/// A [`Mask`] can be constructed from various representations, and converted to various
/// others. Internally, these are cached.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// All values are included. The payload is the length of the mask.
    AllTrue(usize),
    /// No values are included. The payload is the length of the mask.
    AllFalse(usize),
    /// Some values are included, represented as a [`BitBuffer`] held by [`MaskValues`].
    Values(Arc<MaskValues>),
}
114
115impl Debug for Mask {
116    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
117        match self {
118            Self::AllTrue(len) => write!(f, "All true({len})"),
119            Self::AllFalse(len) => write!(f, "All false({len})"),
120            Self::Values(mask) => write!(f, "{mask:?}"),
121        }
122    }
123}
124
125impl Default for Mask {
126    fn default() -> Self {
127        Self::new_true(0)
128    }
129}
130
/// Represents the values of a [`Mask`] that contains some true and some false elements.
///
/// Alongside the bit buffer it stores the pre-computed true count and density, and lazily
/// caches the indices and slices representations.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    // One bit per position of the mask.
    buffer: BitBuffer,

    // We cache the indices and slices representations, since that can be faster than iterating
    // the bit-mask over and over again. Both caches are skipped by serde when the `serde`
    // feature is enabled. Slices are half-open `(start, end)` ranges of true positions.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,

    // Pre-computed values.
    true_count: usize,
    // i.e., the fraction of values that are true
    density: f64,
}
148
149impl Debug for MaskValues {
150    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
151        write!(f, "true_count={}, ", self.true_count)?;
152        write!(f, "density={}, ", self.density)?;
153        if let Some(v) = self.indices.get() {
154            write!(f, "indices={v:?}, ")?;
155        }
156        if let Some(v) = self.slices.get() {
157            write!(f, "slices={v:?}, ")?;
158        }
159        if f.alternate() {
160            f.write_str("\n")?;
161        }
162        write!(f, "{}", self.buffer)
163    }
164}
165
166impl Mask {
167    /// Create a new Mask with the given length.
168    pub fn new(length: usize, value: bool) -> Self {
169        if value {
170            Self::AllTrue(length)
171        } else {
172            Self::AllFalse(length)
173        }
174    }
175
176    /// Create a new Mask where all values are set.
177    #[inline]
178    pub fn new_true(length: usize) -> Self {
179        Self::AllTrue(length)
180    }
181
182    /// Create a new Mask where no values are set.
183    #[inline]
184    pub fn new_false(length: usize) -> Self {
185        Self::AllFalse(length)
186    }
187
188    /// Create a new [`Mask`] from a [`BitBuffer`].
189    pub fn from_buffer(buffer: BitBuffer) -> Self {
190        let len = buffer.len();
191        let true_count = buffer.true_count();
192
193        if true_count == 0 {
194            return Self::AllFalse(len);
195        }
196        if true_count == len {
197            return Self::AllTrue(len);
198        }
199
200        Self::Values(Arc::new(MaskValues {
201            buffer,
202            indices: Default::default(),
203            slices: Default::default(),
204            true_count,
205            density: true_count as f64 / len as f64,
206        }))
207    }
208
209    /// Create a new [`Mask`] from sorted, unique indices.
210    pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
211        let indices = indices.into_iter().collect::<Vec<_>>();
212        assert!(indices.is_sorted(), "Mask indices must be sorted");
213        assert!(
214            indices.windows(2).all(|w| w[0] != w[1]),
215            "Mask indices must be unique"
216        );
217        let buffer = BitBuffer::from_indices(len, indices.iter().copied());
218        debug_assert_eq!(buffer.len(), len);
219        let true_count = buffer.true_count();
220
221        if true_count == 0 {
222            return Self::AllFalse(len);
223        }
224        if true_count == len {
225            return Self::AllTrue(len);
226        }
227
228        Self::Values(Arc::new(MaskValues {
229            buffer,
230            indices: OnceLock::from(indices),
231            slices: Default::default(),
232            true_count,
233            density: true_count as f64 / len as f64,
234        }))
235    }
236
237    /// Create a new [`Mask`] from an [`IntoIterator<Item = usize>`] of indices to be excluded.
238    pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
239        let mut buf = BitBufferMut::new_set(len);
240
241        let mut false_count: usize = 0;
242        indices.into_iter().for_each(|idx| {
243            buf.unset(idx);
244            false_count += 1;
245        });
246        debug_assert_eq!(buf.len(), len);
247        let true_count = len - false_count;
248
249        // Return optimized variants when appropriate
250        if false_count == 0 {
251            return Self::AllTrue(len);
252        }
253        if false_count == len {
254            return Self::AllFalse(len);
255        }
256
257        Self::Values(Arc::new(MaskValues {
258            buffer: buf.freeze(),
259            indices: Default::default(),
260            slices: Default::default(),
261            true_count,
262            density: true_count as f64 / len as f64,
263        }))
264    }
265
266    /// Create a new [`Mask`] from a [`Vec<(usize, usize)>`] where each range
267    /// represents a contiguous range of true values.
268    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
269        Self::check_slices(len, &vec);
270        Self::from_slices_unchecked(len, vec)
271    }
272
273    fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
274        #[cfg(debug_assertions)]
275        Self::check_slices(len, &slices);
276
277        let true_count = slices.iter().map(|(b, e)| e - b).sum();
278        if true_count == 0 {
279            return Self::AllFalse(len);
280        }
281        if true_count == len {
282            return Self::AllTrue(len);
283        }
284
285        let mut buf = BitBufferMut::with_capacity(len);
286        let mut cursor = 0;
287        for (start, end) in slices.iter().copied() {
288            buf.append_n(false, start - cursor);
289            buf.append_n(true, end - start);
290            cursor = end;
291        }
292        buf.append_n(false, len - cursor);
293        debug_assert_eq!(buf.len(), len);
294
295        Self::Values(Arc::new(MaskValues {
296            buffer: buf.freeze(),
297            indices: Default::default(),
298            slices: OnceLock::from(slices),
299            true_count,
300            density: true_count as f64 / len as f64,
301        }))
302    }
303
    /// Validates that `vec` describes non-empty, in-bounds, sorted and non-overlapping
    /// half-open `(start, end)` ranges for a mask of length `len`.
    ///
    /// ## Panics
    ///
    /// Panics if any range has `start >= end` or `end > len`, or if two consecutive ranges
    /// are not strictly ordered by start or overlap.
    #[inline(always)]
    fn check_slices(len: usize, vec: &[(usize, usize)]) {
        // Every range must be non-empty and end within the mask.
        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
        // Consecutive ranges must start in strictly increasing order and must not overlap
        // (a range may begin exactly where the previous one ended).
        for (first, second) in vec.iter().tuple_windows() {
            assert!(
                first.0 < second.0,
                "Slices must be sorted, got {first:?} and {second:?}"
            );
            assert!(
                first.1 <= second.0,
                "Slices must be non-overlapping, got {first:?} and {second:?}"
            );
        }
    }
318
319    /// Create a new [`Mask`] from the intersection of two indices slices.
320    pub fn from_intersection_indices(
321        len: usize,
322        lhs: impl Iterator<Item = usize>,
323        rhs: impl Iterator<Item = usize>,
324    ) -> Self {
325        let mut intersection = Vec::with_capacity(len);
326        let mut lhs = lhs.peekable();
327        let mut rhs = rhs.peekable();
328        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
329            match l.cmp(&r) {
330                Ordering::Less => {
331                    lhs.next();
332                }
333                Ordering::Greater => {
334                    rhs.next();
335                }
336                Ordering::Equal => {
337                    intersection.push(l);
338                    lhs.next();
339                    rhs.next();
340                }
341            }
342        }
343        Self::from_indices(len, intersection)
344    }
345
346    /// Clears the mask of all data. Drops any allocated capacity.
347    pub fn clear(&mut self) {
348        *self = Self::new_false(0);
349    }
350
351    /// Returns the length of the mask (not the number of true values).
352    #[inline]
353    pub fn len(&self) -> usize {
354        match self {
355            Self::AllTrue(len) => *len,
356            Self::AllFalse(len) => *len,
357            Self::Values(values) => values.len(),
358        }
359    }
360
361    /// Returns true if the mask is empty i.e., it's length is 0.
362    #[inline]
363    pub fn is_empty(&self) -> bool {
364        match self {
365            Self::AllTrue(len) => *len == 0,
366            Self::AllFalse(len) => *len == 0,
367            Self::Values(values) => values.is_empty(),
368        }
369    }
370
371    /// Get the true count of the mask.
372    #[inline]
373    pub fn true_count(&self) -> usize {
374        match &self {
375            Self::AllTrue(len) => *len,
376            Self::AllFalse(_) => 0,
377            Self::Values(values) => values.true_count,
378        }
379    }
380
381    /// Get the false count of the mask.
382    #[inline]
383    pub fn false_count(&self) -> usize {
384        match &self {
385            Self::AllTrue(_) => 0,
386            Self::AllFalse(len) => *len,
387            Self::Values(values) => values.buffer.len() - values.true_count,
388        }
389    }
390
391    /// Returns true if all values in the mask are true.
392    #[inline]
393    pub fn all_true(&self) -> bool {
394        match &self {
395            Self::AllTrue(_) => true,
396            Self::AllFalse(0) => true,
397            Self::AllFalse(_) => false,
398            Self::Values(values) => values.buffer.len() == values.true_count,
399        }
400    }
401
402    /// Returns true if all values in the mask are false.
403    #[inline]
404    pub fn all_false(&self) -> bool {
405        self.true_count() == 0
406    }
407
408    /// Return the density of the full mask.
409    #[inline]
410    pub fn density(&self) -> f64 {
411        match &self {
412            Self::AllTrue(_) => 1.0,
413            Self::AllFalse(_) => 0.0,
414            Self::Values(values) => values.density,
415        }
416    }
417
418    /// Returns the boolean value at a given index.
419    ///
420    /// ## Panics
421    ///
422    /// Panics if the index is out of bounds.
423    #[inline]
424    pub fn value(&self, idx: usize) -> bool {
425        match self {
426            Mask::AllTrue(_) => true,
427            Mask::AllFalse(_) => false,
428            Mask::Values(values) => values.buffer.value(idx),
429        }
430    }
431
432    /// Iterate the mask as one `bool` per element, in order.
433    ///
434    /// Unlike repeatedly calling [`Mask::value`], this advances a single cursor rather than
435    /// recomputing the byte/bit offset for every element, and it does not allocate for the
436    /// all-true / all-false variants. Prefer this for sequential per-element scans.
437    #[inline]
438    pub fn iter(&self) -> MaskBoolIter<'_> {
439        match self {
440            Mask::AllTrue(len) => MaskBoolIter::Repeat {
441                value: true,
442                remaining: *len,
443            },
444            Mask::AllFalse(len) => MaskBoolIter::Repeat {
445                value: false,
446                remaining: *len,
447            },
448            Mask::Values(values) => MaskBoolIter::Bits(values.bit_buffer().iter()),
449        }
450    }
451
452    /// Returns the first true index in the mask.
453    pub fn first(&self) -> Option<usize> {
454        match &self {
455            Self::AllTrue(len) => (*len > 0).then_some(0),
456            Self::AllFalse(_) => None,
457            Self::Values(values) => {
458                if let Some(indices) = values.indices.get() {
459                    return indices.first().copied();
460                }
461                if let Some(slices) = values.slices.get() {
462                    return slices.first().map(|(start, _)| *start);
463                }
464                values.buffer.set_indices().next()
465            }
466        }
467    }
468
469    /// Returns the last true index in the mask.
470    pub fn last(&self) -> Option<usize> {
471        match &self {
472            Self::AllTrue(len) => (*len > 0).then_some(*len - 1),
473            Self::AllFalse(_) => None,
474            Self::Values(values) => {
475                if let Some(indices) = values.indices.get() {
476                    return indices.last().copied();
477                }
478                if let Some(slices) = values.slices.get() {
479                    return slices.last().map(|(_, end)| end - 1);
480                }
481
482                if values.true_count == 0 {
483                    return None;
484                }
485
486                Some(
487                    values
488                        .buffer
489                        .select(values.true_count - 1)
490                        .unwrap_or_else(|| {
491                            vortex_panic!(
492                                "Rank {} out of bounds for mask with true count {}",
493                                values.true_count - 1,
494                                values.true_count
495                            )
496                        }),
497                )
498            }
499        }
500    }
501
502    /// Returns the position in the mask of the nth true value.
503    pub fn rank(&self, n: usize) -> usize {
504        if n >= self.true_count() {
505            vortex_panic!(
506                "Rank {n} out of bounds for mask with true count {}",
507                self.true_count()
508            );
509        }
510        match &self {
511            Self::AllTrue(_) => n,
512            Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
513            Self::Values(values) => {
514                if let Some(indices) = values.indices.get() {
515                    return indices[n];
516                }
517
518                values.buffer.select(n).unwrap_or_else(|| {
519                    vortex_panic!(
520                        "Rank {} out of bounds for mask with true count {}",
521                        values.true_count - 1,
522                        values.true_count
523                    )
524                })
525            }
526        }
527    }
528
529    /// Slice the mask.
530    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
531        let start = match range.start_bound() {
532            Bound::Included(&s) => s,
533            Bound::Excluded(&s) => s + 1,
534            Bound::Unbounded => 0,
535        };
536        let end = match range.end_bound() {
537            Bound::Included(&e) => e + 1,
538            Bound::Excluded(&e) => e,
539            Bound::Unbounded => self.len(),
540        };
541
542        assert!(start <= end);
543        assert!(start <= self.len());
544        assert!(end <= self.len());
545        let len = end - start;
546
547        match &self {
548            Self::AllTrue(_) => Self::new_true(len),
549            Self::AllFalse(_) => Self::new_false(len),
550            Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
551        }
552    }
553
554    /// Return the boolean buffer representation of the mask.
555    #[inline]
556    pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
557        match &self {
558            Self::AllTrue(_) => AllOr::All,
559            Self::AllFalse(_) => AllOr::None,
560            Self::Values(values) => AllOr::Some(&values.buffer),
561        }
562    }
563
564    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
565    /// and all-false variants.
566    #[inline]
567    pub fn to_bit_buffer(&self) -> BitBuffer {
568        match self {
569            Self::AllTrue(l) => BitBuffer::new_set(*l),
570            Self::AllFalse(l) => BitBuffer::new_unset(*l),
571            Self::Values(values) => values.bit_buffer().clone(),
572        }
573    }
574
575    /// Return a boolean buffer representation of the mask, allocating new buffers for all-true
576    /// and all-false variants.
577    #[inline]
578    pub fn into_bit_buffer(self) -> BitBuffer {
579        match self {
580            Self::AllTrue(l) => BitBuffer::new_set(l),
581            Self::AllFalse(l) => BitBuffer::new_unset(l),
582            Self::Values(values) => Arc::try_unwrap(values)
583                .map(|v| v.into_bit_buffer())
584                .unwrap_or_else(|v| v.bit_buffer().clone()),
585        }
586    }
587
588    /// Return the indices representation of the mask.
589    #[inline]
590    pub fn indices(&self) -> AllOr<&[usize]> {
591        match &self {
592            Self::AllTrue(_) => AllOr::All,
593            Self::AllFalse(_) => AllOr::None,
594            Self::Values(values) => AllOr::Some(values.indices()),
595        }
596    }
597
598    /// Return the slices representation of the mask.
599    #[inline]
600    pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
601        match &self {
602            Self::AllTrue(_) => AllOr::All,
603            Self::AllFalse(_) => AllOr::None,
604            Self::Values(values) => AllOr::Some(values.slices()),
605        }
606    }
607
608    /// Return an iterator over either indices or slices of the mask based on a density threshold.
609    #[inline]
610    pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
611        match &self {
612            Self::AllTrue(_) => AllOr::All,
613            Self::AllFalse(_) => AllOr::None,
614            Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
615        }
616    }
617
618    /// Return [`MaskValues`] if the mask is not all true or all false.
619    #[inline]
620    pub fn values(&self) -> Option<&MaskValues> {
621        if let Self::Values(values) = self {
622            Some(values)
623        } else {
624            None
625        }
626    }
627
628    /// Given monotonically increasing `indices` in [0, n_rows], returns the
629    /// count of valid elements up to each index.
630    ///
631    /// This is O(n_rows).
632    pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
633        match self {
634            Self::AllTrue(_) => indices.to_vec(),
635            Self::AllFalse(_) => vec![0; indices.len()],
636            Self::Values(values) => {
637                let mut bool_iter = values.bit_buffer().iter();
638                let mut valid_counts = Vec::with_capacity(indices.len());
639                let mut valid_count = 0;
640                let mut idx = 0;
641                for &next_idx in indices {
642                    while idx < next_idx {
643                        idx += 1;
644                        valid_count += bool_iter
645                            .next()
646                            .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"))
647                            as usize;
648                    }
649                    valid_counts.push(valid_count);
650                }
651
652                valid_counts
653            }
654        }
655    }
656
657    /// Limit the mask to the first `limit` true values
658    pub fn limit(self, limit: usize) -> Self {
659        // Early return optimization: if we're asking for more true values than the total
660        // length of the mask, then even if all values were true, we couldn't exceed the
661        // limit, so return the original mask unchanged.
662        if self.len() <= limit {
663            return self;
664        }
665
666        match &self {
667            Mask::AllTrue(len) => {
668                Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
669            }
670            Mask::AllFalse(_) => self,
671            Mask::Values(mask_values) => {
672                if limit >= mask_values.true_count() {
673                    return self;
674                }
675
676                let existing_buffer = mask_values.bit_buffer();
677
678                let mut new_buffer_builder = BitBufferMut::new_unset(mask_values.len());
679                debug_assert!(limit < mask_values.len());
680
681                for index in existing_buffer.set_indices().take(limit) {
682                    // SAFETY: We checked that `limit` was less than the mask values length,
683                    // therefore `index` must be within the bounds of the bit buffer.
684                    unsafe { new_buffer_builder.set_unchecked(index) }
685                }
686
687                Self::from(new_buffer_builder.freeze())
688            }
689        }
690    }
691
692    /// Concatenate multiple masks together into a single mask.
693    pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
694        let masks: Vec<_> = masks.collect();
695        let len = masks.iter().map(|t| t.len()).sum();
696
697        if masks.iter().all(|t| t.all_true()) {
698            return Ok(Mask::AllTrue(len));
699        }
700
701        if masks.iter().all(|t| t.all_false()) {
702            return Ok(Mask::AllFalse(len));
703        }
704
705        let mut builder = BitBufferMut::with_capacity(len);
706
707        for mask in masks {
708            match mask {
709                Mask::AllTrue(n) => builder.append_n(true, *n),
710                Mask::AllFalse(n) => builder.append_n(false, *n),
711                Mask::Values(v) => builder.append_buffer(v.bit_buffer()),
712            }
713        }
714
715        Ok(Mask::from_buffer(builder.freeze()))
716    }
717}
718
719impl MaskValues {
    /// Returns the length of the mask, i.e. the length of the underlying bit buffer (not the
    /// number of true values).
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns true if the mask is empty, i.e., its length is 0.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Returns the pre-computed density of the mask, i.e. the fraction of values that are
    /// true.
    #[inline]
    pub fn density(&self) -> f64 {
        self.density
    }

    /// Returns the pre-computed true count of the mask.
    #[inline]
    pub fn true_count(&self) -> usize {
        self.true_count
    }
743
    /// Returns a reference to the boolean buffer representation of the mask.
    #[inline]
    pub fn bit_buffer(&self) -> &BitBuffer {
        &self.buffer
    }

    /// Consumes the values and returns the boolean buffer representation of the mask,
    /// discarding any cached indices or slices.
    #[inline]
    pub fn into_bit_buffer(self) -> BitBuffer {
        self.buffer
    }

    /// Returns the boolean value at a given index, as reported by the underlying bit buffer.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        self.buffer.value(index)
    }
761
762    /// Constructs an indices vector from one of the other representations.
763    pub fn indices(&self) -> &[usize] {
764        self.indices.get_or_init(|| {
765            if self.true_count == 0 {
766                return vec![];
767            }
768
769            if self.true_count == self.len() {
770                return (0..self.len()).collect();
771            }
772
773            if let Some(slices) = self.slices.get() {
774                let mut indices = Vec::with_capacity(self.true_count);
775                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
776                debug_assert!(indices.is_sorted());
777                assert_eq!(indices.len(), self.true_count);
778                return indices;
779            }
780
781            let mut indices = Vec::with_capacity(self.true_count);
782            indices.extend(self.buffer.set_indices());
783            debug_assert!(indices.is_sorted());
784            assert_eq!(indices.len(), self.true_count);
785            indices
786        })
787    }
788
789    /// Returns cached index positions when this mask already has them materialized.
790    ///
791    /// Unlike [`Self::indices`], this does not build the index vector from another
792    /// representation.
793    #[inline]
794    pub fn cached_indices(&self) -> Option<&[usize]> {
795        self.indices.get().map(Vec::as_slice)
796    }
797
798    /// Constructs a slices vector from one of the other representations.
799    #[inline]
800    pub fn slices(&self) -> &[(usize, usize)] {
801        self.slices.get_or_init(|| {
802            if self.true_count == self.len() {
803                return vec![(0, self.len())];
804            }
805
806            self.buffer.set_slices().collect()
807        })
808    }
809
810    /// Returns cached true-value ranges when this mask already has them materialized.
811    ///
812    /// Unlike [`Self::slices`], this does not build the slice vector from another
813    /// representation.
814    #[inline]
815    pub fn cached_slices(&self) -> Option<&[(usize, usize)]> {
816        self.slices.get().map(Vec::as_slice)
817    }
818
819    /// Return an iterator over either indices or slices of the mask based on a density threshold.
820    #[inline]
821    pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
822        if self.density >= threshold {
823            MaskIter::Slices(self.slices())
824        } else {
825            MaskIter::Indices(self.indices())
826        }
827    }
828}
829
/// The true positions of a mask, borrowed either as individual indices or as ranges.
///
/// Produced by [`MaskValues::threshold_iter`], which picks the variant from the mask's
/// density.
pub enum MaskIter<'a> {
    /// Slice of pre-cached indices of a mask.
    Indices(&'a [usize]),
    /// Slice of pre-cached `(start, end)` slices of a mask.
    Slices(&'a [(usize, usize)]),
}
837
838/// Iterator yielding one `bool` per element of a [`Mask`], in order.
839///
840/// Created by [`Mask::iter`].
841pub enum MaskBoolIter<'a> {
842    /// An all-true or all-false run.
843    Repeat {
844        /// The constant value yielded by every element of the run.
845        value: bool,
846        /// The number of elements still to yield.
847        remaining: usize,
848    },
849    /// Per-element bits of a [`Mask::Values`] mask.
850    Bits(BitIterator<'a>),
851}
852
853impl Iterator for MaskBoolIter<'_> {
854    type Item = bool;
855
856    #[inline]
857    fn next(&mut self) -> Option<Self::Item> {
858        match self {
859            Self::Repeat { remaining: 0, .. } => None,
860            Self::Repeat { value, remaining } => {
861                *remaining -= 1;
862                Some(*value)
863            }
864            Self::Bits(bits) => bits.next(),
865        }
866    }
867
868    #[inline]
869    fn size_hint(&self) -> (usize, Option<usize>) {
870        let remaining = match self {
871            Self::Repeat { remaining, .. } => *remaining,
872            Self::Bits(bits) => bits.len(),
873        };
874        (remaining, Some(remaining))
875    }
876}
877
878impl ExactSizeIterator for MaskBoolIter<'_> {}
879
880impl From<BitBuffer> for Mask {
881    fn from(value: BitBuffer) -> Self {
882        Self::from_buffer(value)
883    }
884}
885
886impl FromIterator<bool> for Mask {
887    #[inline]
888    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
889        Self::from_buffer(BitBuffer::from_iter(iter))
890    }
891}
892
893impl FromIterator<Mask> for Mask {
894    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
895        let masks = iter
896            .into_iter()
897            .filter(|m| !m.is_empty())
898            .collect::<Vec<_>>();
899        let total_length = masks.iter().map(|v| v.len()).sum();
900
901        // If they're all valid, then return a single validity.
902        if masks.iter().all(|v| v.all_true()) {
903            return Self::AllTrue(total_length);
904        }
905        // If they're all invalid, then return a single invalidity.
906        if masks.iter().all(|v| v.all_false()) {
907            return Self::AllFalse(total_length);
908        }
909
910        // Else, construct the boolean buffer
911        let mut buffer = BitBufferMut::with_capacity(total_length);
912        for mask in masks {
913            match mask {
914                Mask::AllTrue(count) => buffer.append_n(true, count),
915                Mask::AllFalse(count) => buffer.append_n(false, count),
916                Mask::Values(values) => {
917                    buffer.append_buffer(values.bit_buffer());
918                }
919            };
920        }
921        Self::from_buffer(buffer.freeze())
922    }
923}