Skip to main content

vortex_alp/alp/compute/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::builtins::ArrayBuiltins;
9use vortex_array::scalar_fn::fns::mask::MaskKernel;
10use vortex_array::scalar_fn::fns::mask::MaskReduce;
11use vortex_array::validity::Validity;
12use vortex_error::VortexResult;
13
14use crate::ALP;
15use crate::ALPArrayExt;
16use crate::ALPArraySlotsExt;
17
18impl MaskReduce for ALP {
19    fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
20        // Masking sparse patches requires reading indices, fall back to kernel.
21        if array.patches().is_some() {
22            return Ok(None);
23        }
24        let masked_encoded = array.encoded().clone().mask(mask.clone())?;
25        Ok(Some(
26            ALP::new(masked_encoded, array.exponents(), None).into_array(),
27        ))
28    }
29}
30
31impl MaskKernel for ALP {
32    fn mask(
33        array: ArrayView<'_, Self>,
34        mask: &ArrayRef,
35        ctx: &mut ExecutionCtx,
36    ) -> VortexResult<Option<ArrayRef>> {
37        let vortex_mask = Validity::Array(mask.not()?).execute_mask(array.len(), ctx)?;
38        let masked_encoded = array.encoded().clone().mask(mask.clone())?;
39        let masked_dtype = array
40            .dtype()
41            .with_nullability(masked_encoded.dtype().nullability());
42        let masked_patches = array
43            .patches()
44            .map(|p| p.mask(&vortex_mask, ctx))
45            .transpose()?
46            .flatten()
47            .map(|patches| patches.cast_values(&masked_dtype))
48            .transpose()?;
49        Ok(Some(
50            ALP::new(masked_encoded, array.exponents(), masked_patches).into_array(),
51        ))
52    }
53}
54
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::mask::MaskKernel;
    use vortex_buffer::buffer;

    use crate::alp::array::ALPArrayExt;
    use crate::alp_encode;

    /// Runs the shared mask conformance suite over ALP-encoded versions of several inputs.
    #[rstest]
    #[case(buffer![10.5f32, 20.5, 30.5, 40.5, 50.5].into_array())]
    #[case(buffer![1000.123f64, 2000.456, 3000.789, 4000.012, 5000.345].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]).into_array())]
    #[case(buffer![99.99f64].into_array())]
    #[case(buffer![
        0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
        1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0
    ].into_array())]
    fn test_mask_alp_conformance(#[case] array: vortex_array::ArrayRef) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array_primitive = array.execute::<PrimitiveArray>(&mut ctx).unwrap();
        let alp = alp_encode(array_primitive.as_view(), None, &mut ctx).unwrap();
        test_mask_conformance(&alp.into_array());
    }

    /// Conformance suite over an ALP array that is known to carry patches.
    #[test]
    fn test_mask_alp_with_patches() {
        use std::f64::consts::PI;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        // PI doesn't encode cleanly with ALP, so it creates patches.
        let values: Vec<f64> = (0..100)
            .map(|i| if i % 4 == 3 { PI } else { 1.0 })
            .collect();
        let array = PrimitiveArray::from_iter(values);
        let alp = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(alp.patches().is_some(), "expected patches");
        test_mask_conformance(&alp.into_array());
    }

    /// Patches that survive the kernel must be cast to the masked (nullable) dtype.
    #[test]
    fn test_mask_alp_with_patches_casts_surviving_patch_values_to_nullable() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_iter([1.234f32, f32::NAN, 2.345, f32::INFINITY, 3.456]);
        let alp = alp_encode(values.as_view(), None, &mut ctx).unwrap();
        assert!(alp.patches().is_some(), "expected patches");

        let keep_mask = BoolArray::from_iter([false, true, true, true, true]).into_array();
        let masked = <crate::ALP as MaskKernel>::mask(alp.as_view(), &keep_mask, &mut ctx)
            .unwrap()
            .unwrap();

        let masked_alp = masked.as_opt::<crate::ALP>().unwrap();
        let masked_patches = masked_alp.patches().unwrap();

        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);
        assert_eq!(masked_patches.dtype().nullability(), Nullability::Nullable);
    }

    /// The kernel must also work (and stay ALP-encoded) for arrays without patches; the
    /// conformance suite alone routes such arrays through `MaskReduce` instead.
    #[test]
    fn test_mask_kernel_without_patches() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_iter([1.5f64, 2.5, 3.5, 4.5]);
        let alp = alp_encode(values.as_view(), None, &mut ctx).unwrap();
        assert!(alp.patches().is_none(), "expected no patches");

        let keep_mask = BoolArray::from_iter([true, false, true, false]).into_array();
        let masked = <crate::ALP as MaskKernel>::mask(alp.as_view(), &keep_mask, &mut ctx)
            .unwrap()
            .unwrap();

        let masked_alp = masked.as_opt::<crate::ALP>().unwrap();
        assert!(masked_alp.patches().is_none());
        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);
    }
}