vortex_array/arrays/chunked/compute/
mask.rs1use itertools::Itertools as _;
2use vortex_dtype::DType;
3use vortex_error::{VortexExpect as _, VortexResult};
4use vortex_mask::{AllOr, Mask, MaskIter};
5use vortex_scalar::Scalar;
6
7use super::filter::{ChunkFilter, chunk_filters, find_chunk_idx};
8use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD;
9use crate::arrays::{ChunkedArray, ChunkedEncoding, ConstantArray};
10use crate::compute::{MaskFn, mask, try_cast};
11use crate::{Array, ArrayRef};
12
13impl MaskFn<&ChunkedArray> for ChunkedEncoding {
14 fn mask(&self, array: &ChunkedArray, mask: Mask) -> VortexResult<ArrayRef> {
15 let new_dtype = array.dtype().as_nullable();
16 let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
17 AllOr::All => unreachable!("handled in top-level mask"),
18 AllOr::None => unreachable!("handled in top-level mask"),
19 AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype),
20 AllOr::Some(MaskIter::Slices(slices)) => {
21 mask_slices(array, slices.iter().cloned(), &new_dtype)
22 }
23 }?;
24 debug_assert_eq!(new_chunks.len(), array.nchunks());
25 debug_assert_eq!(
26 new_chunks.iter().map(|x| x.len()).sum::<usize>(),
27 array.len()
28 );
29 ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array())
30 }
31}
32
33fn mask_indices(
34 array: &ChunkedArray,
35 indices: &[usize],
36 new_dtype: &DType,
37) -> VortexResult<Vec<ArrayRef>> {
38 let mut new_chunks = Vec::with_capacity(array.nchunks());
39 let mut current_chunk_id = 0;
40 let mut chunk_indices = Vec::new();
41
42 let chunk_offsets = array.chunk_offsets();
43
44 for &set_index in indices {
45 let (chunk_id, index) = find_chunk_idx(set_index, chunk_offsets);
46 if chunk_id != current_chunk_id {
47 let chunk = array
48 .chunk(current_chunk_id)
49 .vortex_expect("find_chunk_idx must return valid chunk ID");
50 let masked_chunk = mask(chunk, Mask::from_indices(chunk.len(), chunk_indices))?;
51 chunk_indices = Vec::new();
53 new_chunks.push(masked_chunk);
54 current_chunk_id += 1;
55
56 while current_chunk_id < chunk_id {
57 let chunk = array
59 .chunk(current_chunk_id)
60 .vortex_expect("find_chunk_idx must return valid chunk ID");
61 new_chunks.push(try_cast(chunk, new_dtype)?);
62 current_chunk_id += 1;
63 }
64 }
65
66 chunk_indices.push(index);
67 }
68
69 if !chunk_indices.is_empty() {
70 let chunk = array
71 .chunk(current_chunk_id)
72 .vortex_expect("find_chunk_idx must return valid chunk ID");
73 let masked_chunk = mask(chunk, Mask::from_indices(chunk.len(), chunk_indices))?;
74 new_chunks.push(masked_chunk);
75 current_chunk_id += 1;
76 }
77
78 while current_chunk_id < array.nchunks() {
79 let chunk = array
80 .chunk(current_chunk_id)
81 .vortex_expect("find_chunk_idx must return valid chunk ID");
82 new_chunks.push(try_cast(chunk, new_dtype)?);
83 current_chunk_id += 1;
84 }
85
86 Ok(new_chunks)
87}
88
89fn mask_slices(
90 array: &ChunkedArray,
91 slices: impl Iterator<Item = (usize, usize)>,
92 new_dtype: &DType,
93) -> VortexResult<Vec<ArrayRef>> {
94 let chunked_filters = chunk_filters(array, slices)?;
95
96 array
97 .chunks()
98 .iter()
99 .zip_eq(chunked_filters)
100 .map(|(chunk, chunk_filter)| -> VortexResult<ArrayRef> {
101 Ok(match chunk_filter {
102 ChunkFilter::All => {
103 ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array()
105 }
106 ChunkFilter::None => {
107 chunk.clone()
109 }
110 ChunkFilter::Slices(slices) => {
111 mask(chunk, Mask::from_slices(chunk.len(), slices))?
113 }
114 })
115 })
116 .process_results(|iter| iter.collect::<Vec<_>>())
117}
118
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::{ChunkedArray, PrimitiveArray};
    use crate::compute::test_harness::test_mask;

    /// Runs the `test_mask` harness over a non-nullable `u64` chunked array whose chunks
    /// have lengths 2, 1, 0 and 2, so that an empty chunk is part of the input.
    #[test]
    fn test_mask_chunked_array() {
        let u64_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);

        let chunks = vec![
            buffer![0u64, 1].into_array(),
            buffer![2_u64].into_array(),
            PrimitiveArray::empty::<u64>(u64_dtype.nullability()).to_array(),
            buffer![3_u64, 4].into_array(),
        ];
        let chunked = ChunkedArray::try_new(chunks, u64_dtype).unwrap();

        test_mask(&chunked);
    }
}