vortex_array/arrays/chunked/compute/
mask.rs1use itertools::Itertools as _;
2use vortex_dtype::DType;
3use vortex_error::{VortexExpect as _, VortexResult};
4use vortex_mask::{AllOr, Mask, MaskIter};
5use vortex_scalar::Scalar;
6
7use super::filter::{ChunkFilter, chunk_filters, find_chunk_idx};
8use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD;
9use crate::arrays::{ChunkedArray, ChunkedVTable, ConstantArray};
10use crate::compute::{MaskKernel, MaskKernelAdapter, cast, mask};
11use crate::{Array, ArrayRef, IntoArray, register_kernel};
12
13impl MaskKernel for ChunkedVTable {
14 fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult<ArrayRef> {
15 let new_dtype = array.dtype().as_nullable();
16 let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
17 AllOr::All => unreachable!("handled in top-level mask"),
18 AllOr::None => unreachable!("handled in top-level mask"),
19 AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype),
20 AllOr::Some(MaskIter::Slices(slices)) => {
21 mask_slices(array, slices.iter().cloned(), &new_dtype)
22 }
23 }?;
24 debug_assert_eq!(new_chunks.len(), array.nchunks());
25 debug_assert_eq!(
26 new_chunks.iter().map(|x| x.len()).sum::<usize>(),
27 array.len()
28 );
29 ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array())
30 }
31}
32
// Registers the `MaskKernel` implementation above for `ChunkedVTable` through
// `MaskKernelAdapter`. Presumably this is what lets the `mask` compute function find the
// kernel for chunked arrays — NOTE(review): confirm against `register_kernel!`.
register_kernel!(MaskKernelAdapter(ChunkedVTable).lift());
34
35fn mask_indices(
36 array: &ChunkedArray,
37 indices: &[usize],
38 new_dtype: &DType,
39) -> VortexResult<Vec<ArrayRef>> {
40 let mut new_chunks = Vec::with_capacity(array.nchunks());
41 let mut current_chunk_id = 0;
42 let mut chunk_indices = Vec::new();
43
44 let chunk_offsets = array.chunk_offsets();
45
46 for &set_index in indices {
47 let (chunk_id, index) = find_chunk_idx(set_index, chunk_offsets);
48 if chunk_id != current_chunk_id {
49 let chunk = array
50 .chunk(current_chunk_id)
51 .vortex_expect("find_chunk_idx must return valid chunk ID");
52 let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
53 chunk_indices = Vec::new();
55 new_chunks.push(masked_chunk);
56 current_chunk_id += 1;
57
58 while current_chunk_id < chunk_id {
59 let chunk = array
61 .chunk(current_chunk_id)
62 .vortex_expect("find_chunk_idx must return valid chunk ID");
63 new_chunks.push(cast(chunk, new_dtype)?);
64 current_chunk_id += 1;
65 }
66 }
67
68 chunk_indices.push(index);
69 }
70
71 if !chunk_indices.is_empty() {
72 let chunk = array
73 .chunk(current_chunk_id)
74 .vortex_expect("find_chunk_idx must return valid chunk ID");
75 let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
76 new_chunks.push(masked_chunk);
77 current_chunk_id += 1;
78 }
79
80 while current_chunk_id < array.nchunks() {
81 let chunk = array
82 .chunk(current_chunk_id)
83 .vortex_expect("find_chunk_idx must return valid chunk ID");
84 new_chunks.push(cast(chunk, new_dtype)?);
85 current_chunk_id += 1;
86 }
87
88 Ok(new_chunks)
89}
90
91fn mask_slices(
92 array: &ChunkedArray,
93 slices: impl Iterator<Item = (usize, usize)>,
94 new_dtype: &DType,
95) -> VortexResult<Vec<ArrayRef>> {
96 let chunked_filters = chunk_filters(array, slices)?;
97
98 array
99 .chunks()
100 .iter()
101 .zip_eq(chunked_filters)
102 .map(|(chunk, chunk_filter)| -> VortexResult<ArrayRef> {
103 Ok(match chunk_filter {
104 ChunkFilter::All => {
105 ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array()
107 }
108 ChunkFilter::None => {
109 chunk.clone()
111 }
112 ChunkFilter::Slices(slices) => {
113 mask(chunk, &Mask::from_slices(chunk.len(), slices))?
115 }
116 })
117 })
118 .process_results(|iter| iter.collect::<Vec<_>>())
119}
120
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{ChunkedArray, PrimitiveArray};
    use crate::compute::conformance::mask::test_mask;

    /// Runs the shared mask conformance suite against a five-element, non-nullable `u64`
    /// chunked array with uneven chunk lengths, including an empty chunk.
    #[test]
    fn test_mask_chunked_array() {
        let non_nullable_u64 = DType::Primitive(PType::U64, Nullability::NonNullable);

        // Chunk lengths are 2, 1, 0, 2, so the values 0..=4 are spread over four chunks.
        let chunks = vec![
            buffer![0u64, 1].into_array(),
            buffer![2_u64].into_array(),
            PrimitiveArray::empty::<u64>(non_nullable_u64.nullability()).to_array(),
            buffer![3_u64, 4].into_array(),
        ];
        let array = ChunkedArray::try_new(chunks, non_nullable_u64).unwrap();

        test_mask(array.as_ref());
    }
}