Skip to main content

vortex_alp/alp/compute/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::builtins::ArrayBuiltins;
9use vortex_array::scalar_fn::fns::mask::MaskKernel;
10use vortex_array::scalar_fn::fns::mask::MaskReduce;
11use vortex_array::validity::Validity;
12use vortex_error::VortexResult;
13
14use crate::ALP;
15use crate::ALPArrayExt;
16use crate::ALPArraySlotsExt;
17
18impl MaskReduce for ALP {
19    fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
20        // Masking sparse patches requires reading indices, fall back to kernel.
21        if array.patches().is_some() {
22            return Ok(None);
23        }
24        let masked_encoded = array.encoded().clone().mask(mask.clone())?;
25        Ok(Some(
26            ALP::new(masked_encoded, array.exponents(), None).into_array(),
27        ))
28    }
29}
30
31impl MaskKernel for ALP {
32    fn mask(
33        array: ArrayView<'_, Self>,
34        mask: &ArrayRef,
35        ctx: &mut ExecutionCtx,
36    ) -> VortexResult<Option<ArrayRef>> {
37        let vortex_mask = Validity::Array(mask.not()?).execute_mask(array.len(), ctx)?;
38        let masked_encoded = array.encoded().clone().mask(mask.clone())?;
39        let masked_dtype = array
40            .dtype()
41            .with_nullability(masked_encoded.dtype().nullability());
42        let masked_patches = array
43            .patches()
44            .map(|p| p.mask(&vortex_mask, ctx))
45            .transpose()?
46            .flatten()
47            .map(|patches| patches.cast_values(&masked_dtype))
48            .transpose()?;
49        Ok(Some(
50            ALP::new(masked_encoded, array.exponents(), masked_patches).into_array(),
51        ))
52    }
53}
54
#[cfg(test)]
mod test {
    use std::f64::consts::PI;

    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ToCanonical;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::mask::MaskKernel;
    use vortex_buffer::buffer;

    use crate::ALP;
    use crate::alp::array::ALPArrayExt;
    use crate::alp_encode;

    /// ALP-encodes each input and runs the shared mask conformance checks on the result.
    #[rstest]
    #[case(buffer![10.5f32, 20.5, 30.5, 40.5, 50.5].into_array())]
    #[case(buffer![1000.123f64, 2000.456, 3000.789, 4000.012, 5000.345].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]).into_array())]
    #[case(buffer![99.99f64].into_array())]
    #[case(buffer![
        0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
        1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0
    ].into_array())]
    fn test_mask_alp_conformance(#[case] array: ArrayRef) {
        let primitive = array.to_primitive();
        let encoded = alp_encode(&primitive, None).unwrap();
        test_mask_conformance(&encoded.into_array());
    }

    /// Runs the mask conformance checks on an ALP array that is known to carry patches.
    #[test]
    fn test_mask_alp_with_patches() {
        // PI doesn't encode cleanly with ALP, so it creates patches.
        let values: Vec<f64> = (0..100)
            .map(|i| match i % 4 {
                3 => PI,
                _ => 1.0,
            })
            .collect();
        let encoded = alp_encode(&PrimitiveArray::from_iter(values), None).unwrap();
        assert!(encoded.patches().is_some(), "expected patches");
        test_mask_conformance(&encoded.into_array());
    }

    /// After masking through the kernel, both the result and its surviving patches
    /// must report a nullable dtype.
    #[test]
    fn test_mask_alp_with_patches_casts_surviving_patch_values_to_nullable() {
        // NaN and infinity are expected to end up as patches (asserted below).
        let values = PrimitiveArray::from_iter([1.234f32, f32::NAN, 2.345, f32::INFINITY, 3.456]);
        let alp = alp_encode(&values, None).unwrap();
        assert!(alp.patches().is_some(), "expected patches");

        // Only index 0 is masked out; the test expects the result to still carry patches.
        let keep_mask = BoolArray::from_iter([false, true, true, true, true]).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let masked = <ALP as MaskKernel>::mask(alp.as_view(), &keep_mask, &mut ctx)
            .unwrap()
            .unwrap();

        let masked_alp = masked.as_opt::<ALP>().unwrap();
        let surviving_patches = masked_alp.patches().unwrap();

        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);
        assert_eq!(surviving_patches.dtype().nullability(), Nullability::Nullable);
    }
}