1use std::cmp::{max, min};
2use std::fmt::{Display, Formatter};
3use std::ops::{Range, RangeBounds};
4
5use vortex_array::compute::{filter, slice, try_cast};
6use vortex_array::{Array, ArrayRef, ToCanonical};
7use vortex_buffer::Buffer;
8use vortex_dtype::Nullability::NonNullable;
9use vortex_dtype::{DType, PType};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
11use vortex_mask::Mask;
12
/// A selection mask over a contiguous range of rows.
///
/// The mask holds one entry per row in the half-open range `[begin, end)`;
/// `RowMask::new` derives `end` as `begin + mask.len()`.
#[derive(Debug, Clone)]
pub struct RowMask {
    /// Per-row selection flags; entry `i` corresponds to row `begin + i`.
    mask: Mask,
    /// Row number of the first entry in `mask` (inclusive).
    begin: u64,
    /// Row number one past the last entry in `mask` (exclusive).
    end: u64,
}
23
// Equality is only needed by the test suite, so the impl is compiled for tests only.
#[cfg(test)]
impl PartialEq for RowMask {
    /// Two row masks are equal when they cover the same row range and carry
    /// the same mask.
    fn eq(&self, other: &Self) -> bool {
        // Compare the cheap range bounds first; the mask comparison only runs
        // when both bounds already agree.
        let same_bounds = (self.begin, self.end) == (other.begin, other.end);
        same_bounds && self.mask == other.mask
    }
}
31
32impl Display for RowMask {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 write!(f, "RowSelector [{}..{}]", self.begin, self.end)
35 }
36}
37
38impl RowMask {
39 pub fn new(mask: Mask, begin: u64) -> Self {
41 let end = begin + (mask.len() as u64);
42 Self { mask, begin, end }
43 }
44
45 pub fn new_valid_between(begin: u64, end: u64) -> Self {
51 let length =
52 usize::try_from(end - begin).vortex_expect("Range length does not fit into a usize");
53 RowMask::new(Mask::new_true(length), begin)
54 }
55
56 pub fn new_invalid_between(begin: u64, end: u64) -> Self {
58 let length =
59 usize::try_from(end - begin).vortex_expect("Range length does not fit into a usize");
60 RowMask::new(Mask::new_false(length), begin)
61 }
62
63 pub fn from_array(array: &dyn Array, begin: u64, end: u64) -> VortexResult<Self> {
65 if array.dtype().is_int() {
66 Self::from_index_array(array, begin, end)
67 } else if array.dtype().is_boolean() {
68 Self::from_mask_array(array, begin)
69 } else {
70 vortex_bail!(
71 "RowMask can only be created from integer or boolean arrays, got {} instead.",
72 array.dtype()
73 );
74 }
75 }
76
77 fn from_mask_array(array: &dyn Array, begin: u64) -> VortexResult<Self> {
81 Ok(Self::new(array.validity_mask()?, begin))
82 }
83
84 #[allow(clippy::cast_possible_truncation)]
88 fn from_index_array(array: &dyn Array, begin: u64, end: u64) -> VortexResult<Self> {
89 let length = usize::try_from(end - begin)
90 .map_err(|_| vortex_err!("Range length does not fit into a usize"))?;
91
92 let indices =
93 try_cast(array, &DType::Primitive(PType::U64, NonNullable))?.to_primitive()?;
94
95 let mask = Mask::from_indices(
96 length,
97 indices
98 .as_slice::<u64>()
99 .iter()
100 .map(|i| *i as usize)
101 .collect(),
102 );
103
104 Ok(RowMask::new(mask, begin))
105 }
106
107 pub fn is_disjoint(&self, range: impl RangeBounds<u64>) -> bool {
113 use std::ops::Bound;
114
115 let start = match range.start_bound() {
117 Bound::Included(&n) => n,
118 Bound::Excluded(&n) => n + 1,
119 Bound::Unbounded => 0,
120 };
121
122 let end = match range.end_bound() {
124 Bound::Included(&n) => n + 1,
125 Bound::Excluded(&n) => n,
126 Bound::Unbounded => u64::MAX,
127 };
128
129 self.end <= start || end <= self.begin
131 }
132
133 #[inline]
135 pub fn begin(&self) -> u64 {
136 self.begin
137 }
138
139 #[inline]
141 pub fn end(&self) -> u64 {
142 self.end
143 }
144
145 #[inline]
148 #[allow(clippy::len_without_is_empty)]
150 pub fn len(&self) -> usize {
151 self.mask.len()
152 }
153
154 pub fn filter_mask(&self) -> &Mask {
156 &self.mask
157 }
158
159 pub fn slice(&self, begin: u64, end: u64) -> VortexResult<Self> {
161 let range_begin = max(self.begin, begin);
162 let range_end = min(self.end, end);
163 Ok(RowMask::new(
164 if range_begin == self.begin && range_end == self.end {
165 self.mask.clone()
166 } else {
167 self.mask.slice(
168 usize::try_from(range_begin - self.begin)
169 .vortex_expect("we know this must fit into usize"),
170 usize::try_from(range_end - range_begin)
171 .vortex_expect("we know this must fit into usize"),
172 )
173 },
174 range_begin,
175 ))
176 }
177
178 pub fn filter_array(&self, array: &dyn Array) -> VortexResult<Option<ArrayRef>> {
183 let true_count = self.mask.true_count();
184 if true_count == 0 {
185 return Ok(None);
186 }
187
188 let sliced = if self.len() == array.len() {
189 array
190 } else {
191 &slice(
195 array,
196 usize::try_from(self.begin).vortex_expect("TODO(ngates): fix this bad cast"),
197 usize::try_from(self.end).vortex_expect("TODO(ngates): fix this bad cast"),
198 )?
199 };
200
201 if true_count == sliced.len() {
202 return Ok(Some(sliced.to_array()));
203 }
204
205 filter(sliced, &self.mask).map(Some)
206 }
207
208 pub fn shift(self, offset: u64) -> VortexResult<RowMask> {
210 let valid_shift = self.begin >= offset;
211 if !valid_shift {
212 vortex_bail!(
213 "Can shift RowMask by at most {}, tried to shift by {offset}",
214 self.begin
215 )
216 }
217 Ok(RowMask::new(self.mask, self.begin - offset))
218 }
219
220 pub fn true_count(&self) -> usize {
222 self.mask.true_count()
223 }
224}
225
226pub fn range_intersection(range: &Range<u64>, row_indices: &Buffer<u64>) -> Option<Range<usize>> {
227 if row_indices.first().is_some_and(|&first| first >= range.end)
228 || row_indices.last().is_some_and(|&last| range.start >= last)
229 {
230 return None;
231 }
232
233 let start_idx = row_indices
235 .binary_search(&range.start)
236 .unwrap_or_else(|x| x);
237 let end_idx = row_indices.binary_search(&range.end).unwrap_or_else(|x| x);
238 (start_idx != end_idx).then_some(start_idx..end_idx)
239}
240
// Unit tests for `RowMask` slicing, shifting, filtering and construction.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::{Buffer, buffer};
    use vortex_error::VortexUnwrap;
    use vortex_mask::Mask;

    use super::*;

    // Each case is (input mask, (begin, end) passed to `slice`, expected mask).
    // The last two cases request a range that extends past the mask on one
    // side and expect the result to be clamped to the mask's own bounds.
    #[rstest]
    #[case(
        RowMask::new(Mask::from_iter([true, true, true, false, false, false, false, false, true, true]), 0), (0, 1),
        RowMask::new(Mask::from_iter([true]), 0))]
    #[case(
        RowMask::new(Mask::from_iter([false, false, false, false, false, true, true, true, true, true]), 0), (2, 5),
        RowMask::new(Mask::from_iter([false, false, false]), 2)
    )]
    #[case(
        RowMask::new(Mask::from_iter([true, true, true, true, false, false, false, false, false, false]), 0), (2, 5),
        RowMask::new(Mask::from_iter([true, true, false]), 2)
    )]
    #[case(
        RowMask::new(Mask::from_iter([true, true, true, false, false, true, true, false, false, false]), 0), (2, 6),
        RowMask::new(Mask::from_iter([true, false, false, true]), 2))]
    #[case(
        RowMask::new(Mask::from_iter([false, false, false, false, false, true, true, true, true, true]), 0), (7, 11),
        RowMask::new(Mask::from_iter([true, true, true]), 7))]
    #[case(
        RowMask::new(Mask::from_iter([false, true, true, true, true, true]), 3), (0, 5),
        RowMask::new(Mask::from_iter([false, true]), 3))]
    #[cfg_attr(miri, ignore)]
    fn slice(#[case] first: RowMask, #[case] range: (u64, u64), #[case] expected: RowMask) {
        assert_eq!(first.slice(range.0, range.1).vortex_unwrap(), expected);
    }

    // Shifting by more than `begin` (7 > 5) must fail; the `unwrap` turns the
    // returned error into the expected panic.
    #[test]
    #[should_panic]
    #[cfg_attr(miri, ignore)]
    fn shift_invalid() {
        RowMask::new(Mask::from_iter([true, true, true, true, true]), 5)
            .shift(7)
            .unwrap();
    }

    // Shifting by exactly `begin` moves the mask to start at row 0 and leaves
    // its contents untouched.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn shift() {
        assert_eq!(
            RowMask::new(Mask::from_iter([true, true, true, true, true]), 5)
                .shift(5)
                .unwrap(),
            RowMask::new(Mask::from_iter([true, true, true, true, true]), 0)
        );
    }

    // The mask covers 10 rows while the array has 20, so `filter_array` slices
    // the array before filtering; only rows 5..10 are selected.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn filter_array() {
        let mask = RowMask::new(
            Mask::from_iter([
                false, false, false, false, false, true, true, true, true, true,
            ]),
            0,
        );
        let array = Buffer::from_iter(0..20).into_array();
        let filtered = mask.filter_array(&array).unwrap().unwrap();
        assert_eq!(
            filtered.to_primitive().unwrap().as_slice::<i32>(),
            (5..10).collect::<Vec<_>>()
        );
    }

    // A floating-point array is neither integer nor boolean, so `from_array`
    // returns an error and the `unwrap` panics.
    #[test]
    #[should_panic]
    fn test_row_mask_type_validation() {
        let array = PrimitiveArray::new(buffer![1.0, 2.0], Validity::AllInvalid).into_array();
        RowMask::from_array(&array, 0, 2).unwrap();
    }
}