vortex_array/arrays/chunked/compute/mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools as _;
5use vortex_dtype::DType;
6use vortex_error::VortexResult;
7use vortex_mask::AllOr;
8use vortex_mask::Mask;
9use vortex_mask::MaskIter;
10use vortex_scalar::Scalar;
11
12use super::filter::ChunkFilter;
13use super::filter::chunk_filters;
14use super::filter::find_chunk_idx;
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::arrays::ChunkedArray;
19use crate::arrays::ChunkedVTable;
20use crate::arrays::ConstantArray;
21use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD;
22use crate::compute::MaskKernel;
23use crate::compute::MaskKernelAdapter;
24use crate::compute::cast;
25use crate::compute::mask;
26use crate::register_kernel;
27
28impl MaskKernel for ChunkedVTable {
29    fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult<ArrayRef> {
30        let new_dtype = array.dtype().as_nullable();
31        let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
32            AllOr::All => unreachable!("handled in top-level mask"),
33            AllOr::None => unreachable!("handled in top-level mask"),
34            AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype),
35            AllOr::Some(MaskIter::Slices(slices)) => {
36                mask_slices(array, slices.iter().cloned(), &new_dtype)
37            }
38        }?;
39        debug_assert_eq!(new_chunks.len(), array.nchunks());
40        debug_assert_eq!(
41            new_chunks.iter().map(|x| x.len()).sum::<usize>(),
42            array.len()
43        );
44        ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array())
45    }
46}
47
48register_kernel!(MaskKernelAdapter(ChunkedVTable).lift());
49
50fn mask_indices(
51    array: &ChunkedArray,
52    indices: &[usize],
53    new_dtype: &DType,
54) -> VortexResult<Vec<ArrayRef>> {
55    let mut new_chunks = Vec::with_capacity(array.nchunks());
56    let mut current_chunk_id = 0;
57    let mut chunk_indices = Vec::new();
58
59    let chunk_offsets = array.chunk_offsets();
60
61    for &set_index in indices {
62        let (chunk_id, index) = find_chunk_idx(set_index, &chunk_offsets);
63        if chunk_id != current_chunk_id {
64            let chunk = array.chunk(current_chunk_id);
65            let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
66            // Advance the chunk forward, reset the chunk indices buffer.
67            chunk_indices = Vec::new();
68            new_chunks.push(masked_chunk);
69            current_chunk_id += 1;
70
71            while current_chunk_id < chunk_id {
72                // Chunks that are not affected by the mask, must still be casted to the correct dtype.
73                let chunk = array.chunk(current_chunk_id);
74                new_chunks.push(cast(chunk, new_dtype)?);
75                current_chunk_id += 1;
76            }
77        }
78
79        chunk_indices.push(index);
80    }
81
82    if !chunk_indices.is_empty() {
83        let chunk = array.chunk(current_chunk_id);
84        let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
85        new_chunks.push(masked_chunk);
86        current_chunk_id += 1;
87    }
88
89    while current_chunk_id < array.nchunks() {
90        let chunk = array.chunk(current_chunk_id);
91        new_chunks.push(cast(chunk, new_dtype)?);
92        current_chunk_id += 1;
93    }
94
95    Ok(new_chunks)
96}
97
98fn mask_slices(
99    array: &ChunkedArray,
100    slices: impl Iterator<Item = (usize, usize)>,
101    new_dtype: &DType,
102) -> VortexResult<Vec<ArrayRef>> {
103    let chunked_filters = chunk_filters(array, slices)?;
104
105    array
106        .chunks()
107        .iter()
108        .zip_eq(chunked_filters)
109        .map(|(chunk, chunk_filter)| -> VortexResult<ArrayRef> {
110            match chunk_filter {
111                ChunkFilter::All => {
112                    // entire chunk is masked out
113                    Ok(
114                        ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len())
115                            .into_array(),
116                    )
117                }
118                ChunkFilter::None => {
119                    // entire chunk is not affected by mask
120                    cast(chunk, new_dtype)
121                }
122                ChunkFilter::Slices(slices) => {
123                    // Slices of indices that must be set to null
124                    mask(chunk, &Mask::from_slices(chunk.len(), slices))
125                }
126            }
127        })
128        .process_results(|iter| iter.collect::<Vec<_>>())
129}
130
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::mask::test_mask_conformance;

    // Runs the shared mask conformance suite over a variety of chunk layouts.
    #[rstest]
    // Non-nullable chunks with an empty chunk in the middle.
    #[case(ChunkedArray::try_new(
        vec![
            buffer![0u64, 1].into_array(),
            buffer![2_u64].into_array(),
            PrimitiveArray::empty::<u64>(Nullability::NonNullable).to_array(),
            buffer![3_u64, 4].into_array(),
        ],
        DType::Primitive(PType::U64, Nullability::NonNullable),
    ).unwrap())]
    // Leading and trailing empty chunks: exercises the paths where the first set position lies
    // beyond chunk 0 and where trailing chunks are only cast.
    #[case(ChunkedArray::try_new(
        vec![
            PrimitiveArray::empty::<u64>(Nullability::NonNullable).to_array(),
            buffer![10u64, 11, 12].into_array(),
            PrimitiveArray::empty::<u64>(Nullability::NonNullable).to_array(),
        ],
        DType::Primitive(PType::U64, Nullability::NonNullable),
    ).unwrap())]
    // Nullable chunks that already contain nulls.
    #[case(ChunkedArray::try_new(
        vec![
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).to_array(),
            PrimitiveArray::from_option_iter([Some(4i32), Some(5)]).to_array(),
        ],
        DType::Primitive(PType::I32, Nullability::Nullable),
    ).unwrap())]
    // Nullable chunks where one chunk is entirely null.
    #[case(ChunkedArray::try_new(
        vec![
            PrimitiveArray::from_option_iter([None::<i32>, None]).to_array(),
            PrimitiveArray::from_option_iter([Some(6i32), None, Some(8)]).to_array(),
        ],
        DType::Primitive(PType::I32, Nullability::Nullable),
    ).unwrap())]
    // A single one-element chunk.
    #[case(ChunkedArray::try_new(
        vec![
            buffer![42u8].into_array(),
        ],
        DType::Primitive(PType::U8, Nullability::NonNullable),
    ).unwrap())]
    // Many small chunks.
    #[case(ChunkedArray::try_new(
        (0..20).map(|i| buffer![i as f32, i as f32 + 0.5].into_array()).collect(),
        DType::Primitive(PType::F32, Nullability::NonNullable),
    ).unwrap())]
    fn test_mask_chunked_conformance(#[case] chunked: ChunkedArray) {
        test_mask_conformance(chunked.as_ref());
    }
}