Skip to main content

vortex_alp/alp_rd/compute/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::arrays::ScalarFnArrayExt;
7use vortex_array::scalar_fn::EmptyOptions;
8use vortex_array::scalar_fn::fns::mask::Mask as MaskExpr;
9use vortex_array::scalar_fn::fns::mask::MaskReduce;
10use vortex_error::VortexResult;
11
12use crate::ALPRDArray;
13use crate::ALPRDVTable;
14
15impl MaskReduce for ALPRDVTable {
16    fn mask(array: &ALPRDArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
17        let masked_left_parts = MaskExpr.try_new_array(
18            array.left_parts().len(),
19            EmptyOptions,
20            [array.left_parts().clone(), mask.clone()],
21        )?;
22        Ok(Some(
23            ALPRDArray::try_new(
24                array.dtype().as_nullable(),
25                masked_left_parts,
26                array.left_parts_dictionary().clone(),
27                array.right_parts().clone(),
28                array.right_bit_width(),
29                array.left_parts_patches().cloned(),
30            )?
31            .into_array(),
32        ))
33    }
34}
35
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::mask::test_mask_conformance;

    use crate::ALPRDFloat;
    use crate::RDEncoder;

    /// Runs the shared mask conformance suite on an encoded array without nulls.
    ///
    /// The encoder is constructed from `a` and `b` only, while the encoded input additionally
    /// contains `outlier` twice.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_mask_simple<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let values = PrimitiveArray::from_iter([a, b, outlier, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&values).into_array();

        test_mask_conformance(&encoded);
    }

    /// Runs the shared mask conformance suite on an encoded array that already contains nulls.
    ///
    /// The encoder is constructed from `a` only; the encoded input mixes `a`, `outlier` and two
    /// null entries.
    #[rstest]
    #[case(0.1f32, 3e25f32)]
    #[case(0.5f64, 1e100f64)]
    fn test_mask_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] outlier: T) {
        let values =
            PrimitiveArray::from_option_iter([Some(a), None, Some(outlier), Some(a), None]);
        let encoded = RDEncoder::new(&[a]).encode(&values).into_array();

        test_mask_conformance(&encoded);
    }
}