Skip to main content

vortex_array/arrays/chunked/compute/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::arrays::ChunkedArray;
10use crate::arrays::ChunkedVTable;
11use crate::arrays::ScalarFnArrayExt;
12use crate::compute::MaskKernel;
13use crate::expr::EmptyOptions;
14use crate::expr::mask::Mask as MaskExpr;
15
16impl MaskKernel for ChunkedVTable {
17    fn mask(
18        array: &ChunkedArray,
19        mask: &ArrayRef,
20        _ctx: &mut ExecutionCtx,
21    ) -> VortexResult<Option<ArrayRef>> {
22        let chunk_offsets = array.chunk_offsets();
23        let new_chunks: Vec<ArrayRef> = array
24            .chunks()
25            .iter()
26            .enumerate()
27            .map(|(i, chunk)| {
28                let start: usize = chunk_offsets[i].try_into()?;
29                let end: usize = chunk_offsets[i + 1].try_into()?;
30                let chunk_mask = mask.slice(start..end)?;
31                MaskExpr.try_new_array(chunk.len(), EmptyOptions, [chunk.clone(), chunk_mask])
32            })
33            .collect::<VortexResult<_>>()?;
34
35        Ok(Some(
36            ChunkedArray::try_new(new_chunks, array.dtype().as_nullable())?.into_array(),
37        ))
38    }
39}
40
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::mask::test_mask_conformance;

    /// Runs the shared `test_mask_conformance` helper against several chunked layouts.
    #[rstest]
    // Non-nullable u64 chunks of uneven sizes, including an empty chunk in the middle.
    #[case(ChunkedArray::try_new(
        vec![
            buffer![0u64, 1].into_array(),
            buffer![2_u64].into_array(),
            PrimitiveArray::empty::<u64>(Nullability::NonNullable).to_array(),
            buffer![3_u64, 4].into_array(),
        ],
        DType::Primitive(PType::U64, Nullability::NonNullable),
    ).unwrap())]
    // Nullable i32 chunks, where the first chunk already contains a null value.
    #[case(ChunkedArray::try_new(
        vec![
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).to_array(),
            PrimitiveArray::from_option_iter([Some(4i32), Some(5)]).to_array(),
        ],
        DType::Primitive(PType::I32, Nullability::Nullable),
    ).unwrap())]
    // A single chunk holding a single u8 element.
    #[case(ChunkedArray::try_new(
        vec![
            buffer![42u8].into_array(),
        ],
        DType::Primitive(PType::U8, Nullability::NonNullable),
    ).unwrap())]
    // Many small chunks: 20 chunks of two f32 values each.
    #[case(ChunkedArray::try_new(
        (0..20).map(|i| buffer![i as f32, i as f32 + 0.5].into_array()).collect(),
        DType::Primitive(PType::F32, Nullability::NonNullable),
    ).unwrap())]
    fn test_mask_chunked_conformance(#[case] chunked: ChunkedArray) {
        // Delegates to the shared conformance helper; presumably it exercises the mask
        // kernel with a range of masks — see `compute::conformance::mask` to confirm.
        test_mask_conformance(chunked.as_ref());
    }
}