vortex_layout/
row_mask.rs

1use std::cmp::{max, min};
2use std::fmt::{Display, Formatter};
3use std::ops::{Range, RangeBounds};
4
5use vortex_array::compute::{filter, slice, try_cast};
6use vortex_array::{Array, ArrayRef, ToCanonical};
7use vortex_buffer::Buffer;
8use vortex_dtype::Nullability::NonNullable;
9use vortex_dtype::{DType, PType};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
11use vortex_mask::Mask;
12
/// A RowMask captures a set of selected rows within a half-open range `[begin, end)`.
///
/// The range itself is expressed in [`u64`] row offsets, but the length of the range must fit
/// into a [`usize`]. This allows us to use a `usize` filter mask within a much larger file.
#[derive(Debug, Clone)]
pub struct RowMask {
    // Selection whose positions are relative to `begin`; its length is `end - begin`.
    mask: Mask,
    // First row covered by `mask` (inclusive).
    begin: u64,
    // One past the last row covered by `mask` (exclusive).
    end: u64,
}
23
// Deliberately shallow equality: two RowMasks are equal only when they cover the same
// `[begin, end)` range with equal `Mask`s. This is not full logical equality (two differently
// shaped masks selecting the same rows compare unequal), but it is all the tests need.
#[cfg(test)]
impl PartialEq for RowMask {
    fn eq(&self, other: &Self) -> bool {
        let same_range = (self.begin, self.end) == (other.begin, other.end);
        same_range && self.mask == other.mask
    }
}
31
32impl Display for RowMask {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        write!(f, "RowSelector [{}..{}]", self.begin, self.end)
35    }
36}
37
38impl RowMask {
39    /// Define a new [`RowMask`] with the given mask and offset into the file.
40    pub fn new(mask: Mask, begin: u64) -> Self {
41        let end = begin + (mask.len() as u64);
42        Self { mask, begin, end }
43    }
44
45    /// Construct a RowMask which is valid in the given range.
46    ///
47    /// ## Panics
48    ///
49    /// If the size of the range is too large to fit into a usize.
50    pub fn new_valid_between(begin: u64, end: u64) -> Self {
51        let length =
52            usize::try_from(end - begin).vortex_expect("Range length does not fit into a usize");
53        RowMask::new(Mask::new_true(length), begin)
54    }
55
56    /// Construct a RowMask which is invalid everywhere in the given range.
57    pub fn new_invalid_between(begin: u64, end: u64) -> Self {
58        let length =
59            usize::try_from(end - begin).vortex_expect("Range length does not fit into a usize");
60        RowMask::new(Mask::new_false(length), begin)
61    }
62
63    /// Creates a RowMask from an array, only supported boolean and integer types.
64    pub fn from_array(array: &dyn Array, begin: u64, end: u64) -> VortexResult<Self> {
65        if array.dtype().is_int() {
66            Self::from_index_array(array, begin, end)
67        } else if array.dtype().is_boolean() {
68            Self::from_mask_array(array, begin)
69        } else {
70            vortex_bail!(
71                "RowMask can only be created from integer or boolean arrays, got {} instead.",
72                array.dtype()
73            );
74        }
75    }
76
77    /// Construct a RowMask from a Boolean typed array.
78    ///
79    /// True-valued positions are kept by the returned mask.
80    fn from_mask_array(array: &dyn Array, begin: u64) -> VortexResult<Self> {
81        Ok(Self::new(array.validity_mask()?, begin))
82    }
83
84    /// Construct a RowMask from an integral array.
85    ///
86    /// The array values are interpreted as indices and those indices are kept by the returned mask.
87    #[allow(clippy::cast_possible_truncation)]
88    fn from_index_array(array: &dyn Array, begin: u64, end: u64) -> VortexResult<Self> {
89        let length = usize::try_from(end - begin)
90            .map_err(|_| vortex_err!("Range length does not fit into a usize"))?;
91
92        let indices =
93            try_cast(array, &DType::Primitive(PType::U64, NonNullable))?.to_primitive()?;
94
95        let mask = Mask::from_indices(
96            length,
97            indices
98                .as_slice::<u64>()
99                .iter()
100                .map(|i| *i as usize)
101                .collect(),
102        );
103
104        Ok(RowMask::new(mask, begin))
105    }
106
107    /// Whether the mask is disjoint with the given range.
108    ///
109    /// This function may return false negatives, but never false positives.
110    ///
111    /// TODO(ngates): improve this function to take into account the [`Mask`].
112    pub fn is_disjoint(&self, range: impl RangeBounds<u64>) -> bool {
113        use std::ops::Bound;
114
115        // Get the start bound of the input range
116        let start = match range.start_bound() {
117            Bound::Included(&n) => n,
118            Bound::Excluded(&n) => n + 1,
119            Bound::Unbounded => 0,
120        };
121
122        // Get the end bound of the input range
123        let end = match range.end_bound() {
124            Bound::Included(&n) => n + 1,
125            Bound::Excluded(&n) => n,
126            Bound::Unbounded => u64::MAX,
127        };
128
129        // Two ranges are disjoint if one ends before the other begins
130        self.end <= start || end <= self.begin
131    }
132
133    /// The beginning of the masked range.
134    #[inline]
135    pub fn begin(&self) -> u64 {
136        self.begin
137    }
138
139    /// The end of the masked range.
140    #[inline]
141    pub fn end(&self) -> u64 {
142        self.end
143    }
144
145    /// The length of the mask is the number of possible rows between the `begin` and `end`,
146    /// regardless of how many appear in the mask. For the number of masked rows, see `true_count`.
147    #[inline]
148    // There is good definition of is_empty, does it mean len == 0 or true_count == 0?
149    #[allow(clippy::len_without_is_empty)]
150    pub fn len(&self) -> usize {
151        self.mask.len()
152    }
153
154    /// Returns the [`Mask`] whose true values are relative to the range of this `RowMask`.
155    pub fn filter_mask(&self) -> &Mask {
156        &self.mask
157    }
158
159    /// Limit mask to `[begin..end)` range
160    pub fn slice(&self, begin: u64, end: u64) -> VortexResult<Self> {
161        let range_begin = max(self.begin, begin);
162        let range_end = min(self.end, end);
163        Ok(RowMask::new(
164            if range_begin == self.begin && range_end == self.end {
165                self.mask.clone()
166            } else {
167                self.mask.slice(
168                    usize::try_from(range_begin - self.begin)
169                        .vortex_expect("we know this must fit into usize"),
170                    usize::try_from(range_end - range_begin)
171                        .vortex_expect("we know this must fit into usize"),
172                )
173            },
174            range_begin,
175        ))
176    }
177
178    /// Filter array with this `RowMask`.
179    ///
180    /// This function assumes that Array is no longer than the mask length and that the mask starts on same offset as the array,
181    /// i.e. the beginning of the array corresponds to the beginning of the mask with begin = 0
182    pub fn filter_array(&self, array: &dyn Array) -> VortexResult<Option<ArrayRef>> {
183        let true_count = self.mask.true_count();
184        if true_count == 0 {
185            return Ok(None);
186        }
187
188        let sliced = if self.len() == array.len() {
189            array
190        } else {
191            // TODO(ngates): I thought the point was the array only covers the valid row range of
192            //  the mask?
193            // FIXME(ngates): this is made more obvious by the unsafe u64 cast.
194            &slice(
195                array,
196                usize::try_from(self.begin).vortex_expect("TODO(ngates): fix this bad cast"),
197                usize::try_from(self.end).vortex_expect("TODO(ngates): fix this bad cast"),
198            )?
199        };
200
201        if true_count == sliced.len() {
202            return Ok(Some(sliced.to_array()));
203        }
204
205        filter(sliced, &self.mask).map(Some)
206    }
207
208    /// Shift the [`RowMask`] down by the given offset.
209    pub fn shift(self, offset: u64) -> VortexResult<RowMask> {
210        let valid_shift = self.begin >= offset;
211        if !valid_shift {
212            vortex_bail!(
213                "Can shift RowMask by at most {}, tried to shift by {offset}",
214                self.begin
215            )
216        }
217        Ok(RowMask::new(self.mask, self.begin - offset))
218    }
219
220    /// The number of masked rows within the range.
221    pub fn true_count(&self) -> usize {
222        self.mask.true_count()
223    }
224}
225
226pub fn range_intersection(range: &Range<u64>, row_indices: &Buffer<u64>) -> Option<Range<usize>> {
227    if row_indices.first().is_some_and(|&first| first >= range.end)
228        || row_indices.last().is_some_and(|&last| range.start >= last)
229    {
230        return None;
231    }
232
233    // For the given row range, find the indices that are within the row_indices.
234    let start_idx = row_indices
235        .binary_search(&range.start)
236        .unwrap_or_else(|x| x);
237    let end_idx = row_indices.binary_search(&range.end).unwrap_or_else(|x| x);
238    (start_idx != end_idx).then_some(start_idx..end_idx)
239}
240
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::{Buffer, buffer};
    use vortex_error::VortexUnwrap;
    use vortex_mask::Mask;

    use super::*;

    #[rstest]
    #[case(
        RowMask::new(Mask::from_iter([true, true, true, false, false, false, false, false, true, true]), 0), (0, 1),
        RowMask::new(Mask::from_iter([true]), 0))]
    #[case(
        RowMask::new(Mask::from_iter([false, false, false, false, false, true, true, true, true, true]), 0), (2, 5),
        RowMask::new(Mask::from_iter([false, false, false]), 2)
    )]
    #[case(
        RowMask::new(Mask::from_iter([true, true, true, true, false, false, false, false, false, false]), 0), (2, 5),
        RowMask::new(Mask::from_iter([true, true, false]), 2)
    )]
    #[case(
        RowMask::new(Mask::from_iter([true, true, true, false, false, true, true, false, false, false]), 0), (2, 6),
        RowMask::new(Mask::from_iter([true, false, false, true]), 2))]
    #[case(
        RowMask::new(Mask::from_iter([false, false, false, false, false, true, true, true, true, true]), 0), (7, 11),
        RowMask::new(Mask::from_iter([true, true, true]), 7))]
    #[case(
        RowMask::new(Mask::from_iter([false, true, true, true, true, true]), 3), (0, 5),
        RowMask::new(Mask::from_iter([false, true]), 3))]
    #[cfg_attr(miri, ignore)]
    fn slice(#[case] first: RowMask, #[case] range: (u64, u64), #[case] expected: RowMask) {
        assert_eq!(first.slice(range.0, range.1).vortex_unwrap(), expected);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn shift_invalid() {
        // Shifting below row 0 must surface as an error (not just "some panic somewhere").
        assert!(
            RowMask::new(Mask::from_iter([true, true, true, true, true]), 5)
                .shift(7)
                .is_err()
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn shift() {
        assert_eq!(
            RowMask::new(Mask::from_iter([true, true, true, true, true]), 5)
                .shift(5)
                .unwrap(),
            RowMask::new(Mask::from_iter([true, true, true, true, true]), 0)
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn filter_array() {
        let mask = RowMask::new(
            Mask::from_iter([
                false, false, false, false, false, true, true, true, true, true,
            ]),
            0,
        );
        let array = Buffer::from_iter(0..20).into_array();
        let filtered = mask.filter_array(&array).unwrap().unwrap();
        assert_eq!(
            filtered.to_primitive().unwrap().as_slice::<i32>(),
            (5..10).collect::<Vec<_>>()
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn is_disjoint() {
        let mask = RowMask::new_valid_between(10, 20);
        assert!(mask.is_disjoint(0..10));
        assert!(mask.is_disjoint(20..));
        assert!(!mask.is_disjoint(5..11));
        assert!(!mask.is_disjoint(19..=19));
        assert!(!mask.is_disjoint(..));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn range_intersection_basic() {
        let indices: Buffer<u64> = buffer![1u64, 5, 9];
        assert_eq!(range_intersection(&(2..9), &indices), Some(1..2));
        assert_eq!(range_intersection(&(10..20), &indices), None);
        assert_eq!(range_intersection(&(0..1), &indices), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn from_index_array() {
        let array = buffer![1u64, 3].into_array();
        let mask = RowMask::from_array(&array, 10, 15).vortex_unwrap();
        assert_eq!(mask.begin(), 10);
        assert_eq!(mask.end(), 15);
        assert_eq!(mask.true_count(), 2);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_row_mask_type_validation() {
        // Floating-point arrays are neither index nor boolean arrays and must be rejected.
        let array = PrimitiveArray::new(buffer![1.0, 2.0], Validity::AllInvalid).into_array();
        assert!(RowMask::from_array(&array, 0, 2).is_err());
    }
}