vortex_array/arrays/chunked/compute/mask.rs

1use itertools::Itertools as _;
2use vortex_dtype::DType;
3use vortex_error::{VortexExpect as _, VortexResult};
4use vortex_mask::{AllOr, Mask, MaskIter};
5use vortex_scalar::Scalar;
6
7use super::filter::{ChunkFilter, chunk_filters, find_chunk_idx};
8use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD;
9use crate::arrays::{ChunkedArray, ChunkedVTable, ConstantArray};
10use crate::compute::{MaskKernel, MaskKernelAdapter, cast, mask};
11use crate::{Array, ArrayRef, IntoArray, register_kernel};
12
13impl MaskKernel for ChunkedVTable {
14    fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult<ArrayRef> {
15        let new_dtype = array.dtype().as_nullable();
16        let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
17            AllOr::All => unreachable!("handled in top-level mask"),
18            AllOr::None => unreachable!("handled in top-level mask"),
19            AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype),
20            AllOr::Some(MaskIter::Slices(slices)) => {
21                mask_slices(array, slices.iter().cloned(), &new_dtype)
22            }
23        }?;
24        debug_assert_eq!(new_chunks.len(), array.nchunks());
25        debug_assert_eq!(
26            new_chunks.iter().map(|x| x.len()).sum::<usize>(),
27            array.len()
28        );
29        ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array())
30    }
31}
32
33register_kernel!(MaskKernelAdapter(ChunkedVTable).lift());
34
35fn mask_indices(
36    array: &ChunkedArray,
37    indices: &[usize],
38    new_dtype: &DType,
39) -> VortexResult<Vec<ArrayRef>> {
40    let mut new_chunks = Vec::with_capacity(array.nchunks());
41    let mut current_chunk_id = 0;
42    let mut chunk_indices = Vec::new();
43
44    let chunk_offsets = array.chunk_offsets();
45
46    for &set_index in indices {
47        let (chunk_id, index) = find_chunk_idx(set_index, chunk_offsets);
48        if chunk_id != current_chunk_id {
49            let chunk = array
50                .chunk(current_chunk_id)
51                .vortex_expect("find_chunk_idx must return valid chunk ID");
52            let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
53            // Advance the chunk forward, reset the chunk indices buffer.
54            chunk_indices = Vec::new();
55            new_chunks.push(masked_chunk);
56            current_chunk_id += 1;
57
58            while current_chunk_id < chunk_id {
59                // Chunks that are not affected by the mask, must still be casted to the correct dtype.
60                let chunk = array
61                    .chunk(current_chunk_id)
62                    .vortex_expect("find_chunk_idx must return valid chunk ID");
63                new_chunks.push(cast(chunk, new_dtype)?);
64                current_chunk_id += 1;
65            }
66        }
67
68        chunk_indices.push(index);
69    }
70
71    if !chunk_indices.is_empty() {
72        let chunk = array
73            .chunk(current_chunk_id)
74            .vortex_expect("find_chunk_idx must return valid chunk ID");
75        let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
76        new_chunks.push(masked_chunk);
77        current_chunk_id += 1;
78    }
79
80    while current_chunk_id < array.nchunks() {
81        let chunk = array
82            .chunk(current_chunk_id)
83            .vortex_expect("find_chunk_idx must return valid chunk ID");
84        new_chunks.push(cast(chunk, new_dtype)?);
85        current_chunk_id += 1;
86    }
87
88    Ok(new_chunks)
89}
90
91fn mask_slices(
92    array: &ChunkedArray,
93    slices: impl Iterator<Item = (usize, usize)>,
94    new_dtype: &DType,
95) -> VortexResult<Vec<ArrayRef>> {
96    let chunked_filters = chunk_filters(array, slices)?;
97
98    array
99        .chunks()
100        .iter()
101        .zip_eq(chunked_filters)
102        .map(|(chunk, chunk_filter)| -> VortexResult<ArrayRef> {
103            Ok(match chunk_filter {
104                ChunkFilter::All => {
105                    // entire chunk is masked out
106                    ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array()
107                }
108                ChunkFilter::None => {
109                    // entire chunk is not affected by mask
110                    chunk.clone()
111                }
112                ChunkFilter::Slices(slices) => {
113                    // Slices of indices that must be set to null
114                    mask(chunk, &Mask::from_slices(chunk.len(), slices))?
115                }
116            })
117        })
118        .process_results(|iter| iter.collect::<Vec<_>>())
119}
120
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{ChunkedArray, PrimitiveArray};
    use crate::compute::conformance::mask::test_mask;

    /// Runs the shared mask conformance suite against a non-nullable `u64` chunked array whose
    /// chunks have lengths 2, 1, 0 and 2, so an empty chunk sits between populated ones.
    #[test]
    fn test_mask_chunked_array() {
        let dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
        let empty_chunk = PrimitiveArray::empty::<u64>(dtype.nullability()).to_array();
        let chunks = vec![
            buffer![0u64, 1].into_array(),
            buffer![2_u64].into_array(),
            empty_chunk,
            buffer![3_u64, 4].into_array(),
        ];

        let chunked = ChunkedArray::try_new(chunks, dtype).unwrap();
        test_mask(chunked.as_ref());
    }
}