vortex_array/arrays/chunked/compute/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools as _;
5use vortex_dtype::DType;
6use vortex_error::{VortexExpect as _, VortexResult};
7use vortex_mask::{AllOr, Mask, MaskIter};
8use vortex_scalar::Scalar;
9
10use super::filter::{ChunkFilter, chunk_filters, find_chunk_idx};
11use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD;
12use crate::arrays::{ChunkedArray, ChunkedVTable, ConstantArray};
13use crate::compute::{MaskKernel, MaskKernelAdapter, cast, mask};
14use crate::{Array, ArrayRef, IntoArray, register_kernel};
15
16impl MaskKernel for ChunkedVTable {
17    fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult<ArrayRef> {
18        let new_dtype = array.dtype().as_nullable();
19        let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
20            AllOr::All => unreachable!("handled in top-level mask"),
21            AllOr::None => unreachable!("handled in top-level mask"),
22            AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype),
23            AllOr::Some(MaskIter::Slices(slices)) => {
24                mask_slices(array, slices.iter().cloned(), &new_dtype)
25            }
26        }?;
27        debug_assert_eq!(new_chunks.len(), array.nchunks());
28        debug_assert_eq!(
29            new_chunks.iter().map(|x| x.len()).sum::<usize>(),
30            array.len()
31        );
32        ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array())
33    }
34}
35
// Registers the chunked mask kernel, wrapped in `MaskKernelAdapter` and lifted.
// NOTE(review): presumably this is what lets the top-level `mask` compute
// function dispatch to `ChunkedVTable` — confirm against `register_kernel!`.
register_kernel!(MaskKernelAdapter(ChunkedVTable).lift());
37
38fn mask_indices(
39    array: &ChunkedArray,
40    indices: &[usize],
41    new_dtype: &DType,
42) -> VortexResult<Vec<ArrayRef>> {
43    let mut new_chunks = Vec::with_capacity(array.nchunks());
44    let mut current_chunk_id = 0;
45    let mut chunk_indices = Vec::new();
46
47    let chunk_offsets = array.chunk_offsets();
48
49    for &set_index in indices {
50        let (chunk_id, index) = find_chunk_idx(set_index, chunk_offsets);
51        if chunk_id != current_chunk_id {
52            let chunk = array
53                .chunk(current_chunk_id)
54                .vortex_expect("find_chunk_idx must return valid chunk ID");
55            let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
56            // Advance the chunk forward, reset the chunk indices buffer.
57            chunk_indices = Vec::new();
58            new_chunks.push(masked_chunk);
59            current_chunk_id += 1;
60
61            while current_chunk_id < chunk_id {
62                // Chunks that are not affected by the mask, must still be casted to the correct dtype.
63                let chunk = array
64                    .chunk(current_chunk_id)
65                    .vortex_expect("find_chunk_idx must return valid chunk ID");
66                new_chunks.push(cast(chunk, new_dtype)?);
67                current_chunk_id += 1;
68            }
69        }
70
71        chunk_indices.push(index);
72    }
73
74    if !chunk_indices.is_empty() {
75        let chunk = array
76            .chunk(current_chunk_id)
77            .vortex_expect("find_chunk_idx must return valid chunk ID");
78        let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
79        new_chunks.push(masked_chunk);
80        current_chunk_id += 1;
81    }
82
83    while current_chunk_id < array.nchunks() {
84        let chunk = array
85            .chunk(current_chunk_id)
86            .vortex_expect("find_chunk_idx must return valid chunk ID");
87        new_chunks.push(cast(chunk, new_dtype)?);
88        current_chunk_id += 1;
89    }
90
91    Ok(new_chunks)
92}
93
94fn mask_slices(
95    array: &ChunkedArray,
96    slices: impl Iterator<Item = (usize, usize)>,
97    new_dtype: &DType,
98) -> VortexResult<Vec<ArrayRef>> {
99    let chunked_filters = chunk_filters(array, slices)?;
100
101    array
102        .chunks()
103        .iter()
104        .zip_eq(chunked_filters)
105        .map(|(chunk, chunk_filter)| -> VortexResult<ArrayRef> {
106            Ok(match chunk_filter {
107                ChunkFilter::All => {
108                    // entire chunk is masked out
109                    ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array()
110                }
111                ChunkFilter::None => {
112                    // entire chunk is not affected by mask
113                    chunk.clone()
114                }
115                ChunkFilter::Slices(slices) => {
116                    // Slices of indices that must be set to null
117                    mask(chunk, &Mask::from_slices(chunk.len(), slices))?
118                }
119            })
120        })
121        .process_results(|iter| iter.collect::<Vec<_>>())
122}
123
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{ChunkedArray, PrimitiveArray};
    use crate::compute::conformance::mask::test_mask_conformance;

    /// Non-nullable `u64` chunks of lengths 2, 1, 0 and 2.
    fn u64_with_empty_chunk() -> ChunkedArray {
        let chunks = vec![
            buffer![0u64, 1].into_array(),
            buffer![2_u64].into_array(),
            PrimitiveArray::empty::<u64>(Nullability::NonNullable).to_array(),
            buffer![3_u64, 4].into_array(),
        ];
        ChunkedArray::try_new(
            chunks,
            DType::Primitive(PType::U64, Nullability::NonNullable),
        )
        .unwrap()
    }

    /// Nullable `i32` chunks, the first of which already contains a null.
    fn nullable_i32() -> ChunkedArray {
        let chunks = vec![
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).to_array(),
            PrimitiveArray::from_option_iter([Some(4i32), Some(5)]).to_array(),
        ];
        ChunkedArray::try_new(chunks, DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap()
    }

    /// A single chunk holding a single non-nullable `u8`.
    fn single_u8() -> ChunkedArray {
        let chunks = vec![buffer![42u8].into_array()];
        ChunkedArray::try_new(chunks, DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
    }

    /// Twenty non-nullable `f32` chunks of two elements each.
    fn many_f32_chunks() -> ChunkedArray {
        let chunks: Vec<_> = (0..20)
            .map(|i| buffer![i as f32, i as f32 + 0.5].into_array())
            .collect();
        ChunkedArray::try_new(
            chunks,
            DType::Primitive(PType::F32, Nullability::NonNullable),
        )
        .unwrap()
    }

    /// Runs the shared mask conformance suite against each chunked fixture.
    #[rstest]
    #[case(u64_with_empty_chunk())]
    #[case(nullable_i32())]
    #[case(single_u8())]
    #[case(many_f32_chunks())]
    fn test_mask_chunked_conformance(#[case] chunked: ChunkedArray) {
        test_mask_conformance(chunked.as_ref());
    }
}