vortex_array/arrays/chunked/compute/
mask.rs

1use itertools::Itertools as _;
2use vortex_dtype::DType;
3use vortex_error::{VortexExpect as _, VortexResult};
4use vortex_mask::{AllOr, Mask, MaskIter};
5use vortex_scalar::Scalar;
6
7use super::filter::{ChunkFilter, chunk_filters, find_chunk_idx};
8use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD;
9use crate::arrays::{ChunkedArray, ChunkedEncoding, ConstantArray};
10use crate::compute::{MaskFn, mask, try_cast};
11use crate::{Array, ArrayRef};
12
13impl MaskFn<&ChunkedArray> for ChunkedEncoding {
14    fn mask(&self, array: &ChunkedArray, mask: Mask) -> VortexResult<ArrayRef> {
15        let new_dtype = array.dtype().as_nullable();
16        let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
17            AllOr::All => unreachable!("handled in top-level mask"),
18            AllOr::None => unreachable!("handled in top-level mask"),
19            AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype),
20            AllOr::Some(MaskIter::Slices(slices)) => {
21                mask_slices(array, slices.iter().cloned(), &new_dtype)
22            }
23        }?;
24        debug_assert_eq!(new_chunks.len(), array.nchunks());
25        debug_assert_eq!(
26            new_chunks.iter().map(|x| x.len()).sum::<usize>(),
27            array.len()
28        );
29        ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array())
30    }
31}
32
33fn mask_indices(
34    array: &ChunkedArray,
35    indices: &[usize],
36    new_dtype: &DType,
37) -> VortexResult<Vec<ArrayRef>> {
38    let mut new_chunks = Vec::with_capacity(array.nchunks());
39    let mut current_chunk_id = 0;
40    let mut chunk_indices = Vec::new();
41
42    let chunk_offsets = array.chunk_offsets();
43
44    for &set_index in indices {
45        let (chunk_id, index) = find_chunk_idx(set_index, chunk_offsets);
46        if chunk_id != current_chunk_id {
47            let chunk = array
48                .chunk(current_chunk_id)
49                .vortex_expect("find_chunk_idx must return valid chunk ID");
50            let masked_chunk = mask(chunk, Mask::from_indices(chunk.len(), chunk_indices))?;
51            // Advance the chunk forward, reset the chunk indices buffer.
52            chunk_indices = Vec::new();
53            new_chunks.push(masked_chunk);
54            current_chunk_id += 1;
55
56            while current_chunk_id < chunk_id {
57                // Chunks that are not affected by the mask, must still be casted to the correct dtype.
58                let chunk = array
59                    .chunk(current_chunk_id)
60                    .vortex_expect("find_chunk_idx must return valid chunk ID");
61                new_chunks.push(try_cast(chunk, new_dtype)?);
62                current_chunk_id += 1;
63            }
64        }
65
66        chunk_indices.push(index);
67    }
68
69    if !chunk_indices.is_empty() {
70        let chunk = array
71            .chunk(current_chunk_id)
72            .vortex_expect("find_chunk_idx must return valid chunk ID");
73        let masked_chunk = mask(chunk, Mask::from_indices(chunk.len(), chunk_indices))?;
74        new_chunks.push(masked_chunk);
75        current_chunk_id += 1;
76    }
77
78    while current_chunk_id < array.nchunks() {
79        let chunk = array
80            .chunk(current_chunk_id)
81            .vortex_expect("find_chunk_idx must return valid chunk ID");
82        new_chunks.push(try_cast(chunk, new_dtype)?);
83        current_chunk_id += 1;
84    }
85
86    Ok(new_chunks)
87}
88
89fn mask_slices(
90    array: &ChunkedArray,
91    slices: impl Iterator<Item = (usize, usize)>,
92    new_dtype: &DType,
93) -> VortexResult<Vec<ArrayRef>> {
94    let chunked_filters = chunk_filters(array, slices)?;
95
96    array
97        .chunks()
98        .iter()
99        .zip_eq(chunked_filters)
100        .map(|(chunk, chunk_filter)| -> VortexResult<ArrayRef> {
101            Ok(match chunk_filter {
102                ChunkFilter::All => {
103                    // entire chunk is masked out
104                    ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array()
105                }
106                ChunkFilter::None => {
107                    // entire chunk is not affected by mask
108                    chunk.clone()
109                }
110                ChunkFilter::Slices(slices) => {
111                    // Slices of indices that must be set to null
112                    mask(chunk, Mask::from_slices(chunk.len(), slices))?
113                }
114            })
115        })
116        .process_results(|iter| iter.collect::<Vec<_>>())
117}
118
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::{ChunkedArray, PrimitiveArray};
    use crate::compute::test_harness::test_mask;

    /// Runs the shared mask test harness against a non-nullable `u64` chunked array
    /// whose chunks have lengths 2, 1, 0 and 2 (i.e. including an empty chunk).
    #[test]
    fn test_mask_chunked_array() {
        let dtype = DType::Primitive(PType::U64, Nullability::NonNullable);

        let empty_chunk = PrimitiveArray::empty::<u64>(dtype.nullability()).to_array();
        let chunks = vec![
            buffer![0u64, 1].into_array(),
            buffer![2_u64].into_array(),
            empty_chunk,
            buffer![3_u64, 4].into_array(),
        ];

        let chunked = ChunkedArray::try_new(chunks, dtype).unwrap();

        test_mask(&chunked);
    }
}