vortex_array/arrays/chunked/compute/
mask.rs1use itertools::Itertools as _;
5use vortex_dtype::DType;
6use vortex_error::VortexResult;
7use vortex_mask::{AllOr, Mask, MaskIter};
8use vortex_scalar::Scalar;
9
10use super::filter::{ChunkFilter, chunk_filters, find_chunk_idx};
11use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD;
12use crate::arrays::{ChunkedArray, ChunkedVTable, ConstantArray};
13use crate::compute::{MaskKernel, MaskKernelAdapter, cast, mask};
14use crate::{Array, ArrayRef, IntoArray, register_kernel};
15
16impl MaskKernel for ChunkedVTable {
17 fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult<ArrayRef> {
18 let new_dtype = array.dtype().as_nullable();
19 let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
20 AllOr::All => unreachable!("handled in top-level mask"),
21 AllOr::None => unreachable!("handled in top-level mask"),
22 AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype),
23 AllOr::Some(MaskIter::Slices(slices)) => {
24 mask_slices(array, slices.iter().cloned(), &new_dtype)
25 }
26 }?;
27 debug_assert_eq!(new_chunks.len(), array.nchunks());
28 debug_assert_eq!(
29 new_chunks.iter().map(|x| x.len()).sum::<usize>(),
30 array.len()
31 );
32 ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array())
33 }
34}
35
// NOTE(review): presumably makes the `MaskKernel` impl for `ChunkedVTable` (wrapped in
// `MaskKernelAdapter`) available to the `mask` compute dispatch; `register_kernel!` and
// `lift()` are defined elsewhere in the crate — confirm there.
36register_kernel!(MaskKernelAdapter(ChunkedVTable).lift());
37
38fn mask_indices(
39 array: &ChunkedArray,
40 indices: &[usize],
41 new_dtype: &DType,
42) -> VortexResult<Vec<ArrayRef>> {
43 let mut new_chunks = Vec::with_capacity(array.nchunks());
44 let mut current_chunk_id = 0;
45 let mut chunk_indices = Vec::new();
46
47 let chunk_offsets = array.chunk_offsets();
48
49 for &set_index in indices {
50 let (chunk_id, index) = find_chunk_idx(set_index, chunk_offsets);
51 if chunk_id != current_chunk_id {
52 let chunk = array.chunk(current_chunk_id);
53 let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
54 chunk_indices = Vec::new();
56 new_chunks.push(masked_chunk);
57 current_chunk_id += 1;
58
59 while current_chunk_id < chunk_id {
60 let chunk = array.chunk(current_chunk_id);
62 new_chunks.push(cast(chunk, new_dtype)?);
63 current_chunk_id += 1;
64 }
65 }
66
67 chunk_indices.push(index);
68 }
69
70 if !chunk_indices.is_empty() {
71 let chunk = array.chunk(current_chunk_id);
72 let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
73 new_chunks.push(masked_chunk);
74 current_chunk_id += 1;
75 }
76
77 while current_chunk_id < array.nchunks() {
78 let chunk = array.chunk(current_chunk_id);
79 new_chunks.push(cast(chunk, new_dtype)?);
80 current_chunk_id += 1;
81 }
82
83 Ok(new_chunks)
84}
85
86fn mask_slices(
87 array: &ChunkedArray,
88 slices: impl Iterator<Item = (usize, usize)>,
89 new_dtype: &DType,
90) -> VortexResult<Vec<ArrayRef>> {
91 let chunked_filters = chunk_filters(array, slices)?;
92
93 array
94 .chunks()
95 .iter()
96 .zip_eq(chunked_filters)
97 .map(|(chunk, chunk_filter)| -> VortexResult<ArrayRef> {
98 Ok(match chunk_filter {
99 ChunkFilter::All => {
100 ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array()
102 }
103 ChunkFilter::None => {
104 chunk.clone()
106 }
107 ChunkFilter::Slices(slices) => {
108 mask(chunk, &Mask::from_slices(chunk.len(), slices))?
110 }
111 })
112 })
113 .process_results(|iter| iter.collect::<Vec<_>>())
114}
115
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{ChunkedArray, PrimitiveArray};
    use crate::compute::conformance::mask::test_mask_conformance;

    // Runs the shared mask conformance checks against chunked arrays of several layouts.
    #[rstest]
    // Non-nullable u64 chunks of uneven length, including an empty chunk.
    #[case(ChunkedArray::try_new(
        vec![
            buffer![0_u64, 1_u64].into_array(),
            buffer![2_u64].into_array(),
            PrimitiveArray::empty::<u64>(Nullability::NonNullable).to_array(),
            buffer![3_u64, 4_u64].into_array(),
        ],
        DType::Primitive(PType::U64, Nullability::NonNullable),
    ).unwrap())]
    // Nullable i32 chunks that already contain a null.
    #[case(ChunkedArray::try_new(
        vec![
            PrimitiveArray::from_option_iter([Some(1_i32), None, Some(3_i32)]).to_array(),
            PrimitiveArray::from_option_iter([Some(4_i32), Some(5_i32)]).to_array(),
        ],
        DType::Primitive(PType::I32, Nullability::Nullable),
    ).unwrap())]
    // A single chunk holding a single element.
    #[case(ChunkedArray::try_new(
        vec![buffer![42_u8].into_array()],
        DType::Primitive(PType::U8, Nullability::NonNullable),
    ).unwrap())]
    // Twenty two-element f32 chunks.
    #[case(ChunkedArray::try_new(
        (0..20)
            .map(|i| {
                let base = i as f32;
                buffer![base, base + 0.5].into_array()
            })
            .collect(),
        DType::Primitive(PType::F32, Nullability::NonNullable),
    ).unwrap())]
    fn test_mask_chunked_conformance(#[case] input: ChunkedArray) {
        test_mask_conformance(input.as_ref());
    }
}