Skip to main content

vortex_alp/alp/compute/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ExecutionCtx;
6use vortex_array::IntoArray;
7use vortex_array::builtins::ArrayBuiltins;
8use vortex_array::scalar_fn::fns::mask::MaskKernel;
9use vortex_array::scalar_fn::fns::mask::MaskReduce;
10use vortex_array::validity::Validity;
11use vortex_error::VortexResult;
12
13use crate::ALPArray;
14use crate::ALPVTable;
15
16impl MaskReduce for ALPVTable {
17    fn mask(array: &ALPArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
18        // Masking sparse patches requires reading indices, fall back to kernel.
19        if array.patches().is_some() {
20            return Ok(None);
21        }
22        let masked_encoded = array.encoded().clone().mask(mask.clone())?;
23        Ok(Some(
24            ALPArray::new(masked_encoded, array.exponents(), None).into_array(),
25        ))
26    }
27}
28
29impl MaskKernel for ALPVTable {
30    fn mask(
31        array: &ALPArray,
32        mask: &ArrayRef,
33        ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        let vortex_mask = Validity::Array(mask.not()?).to_mask(array.len());
36        let masked_encoded = array.encoded().clone().mask(mask.clone())?;
37        let masked_patches = array
38            .patches()
39            .map(|p| p.mask(&vortex_mask, ctx))
40            .transpose()?
41            .flatten();
42        Ok(Some(
43            ALPArray::new(masked_encoded, array.exponents(), masked_patches).into_array(),
44        ))
45    }
46}
47
#[cfg(test)]
mod test {
    use std::f64::consts::PI;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_buffer::buffer;

    use crate::alp_encode;

    /// Runs the generic mask conformance suite over ALP encodings of assorted float inputs.
    #[rstest]
    #[case(buffer![10.5f32, 20.5, 30.5, 40.5, 50.5].into_array())]
    #[case(buffer![1000.123f64, 2000.456, 3000.789, 4000.012, 5000.345].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]).into_array())]
    #[case(buffer![99.99f64].into_array())]
    #[case(buffer![
        0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
        1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0
    ].into_array())]
    fn test_mask_alp_conformance(#[case] array: vortex_array::ArrayRef) {
        let alp = alp_encode(&array.to_primitive(), None).unwrap();
        test_mask_conformance(&alp.into_array());
    }

    /// Masks an ALP array that stores some of its values as patches.
    #[test]
    fn test_mask_alp_with_patches() {
        // PI doesn't encode cleanly with ALP, so it creates patches.
        let values: Vec<f64> = (0..100)
            .map(|i| if i % 4 == 3 { PI } else { 1.0 })
            .collect();
        let array = PrimitiveArray::from_iter(values);
        let alp = alp_encode(&array, None).unwrap();
        assert!(alp.patches().is_some(), "expected patches");
        test_mask_conformance(&alp.into_array());
    }

    /// Same shape as `test_mask_alp_with_patches`, but some inputs are null, so the patched
    /// masking path is also exercised on a nullable array.
    #[test]
    fn test_mask_alp_with_patches_and_nulls() {
        // Every 5th value is null; of the rest, PI again forces patches.
        let values = (0..100).map(|i| {
            if i % 5 == 0 {
                None
            } else if i % 4 == 3 {
                Some(PI)
            } else {
                Some(1.0f64)
            }
        });
        let array = PrimitiveArray::from_option_iter(values);
        let alp = alp_encode(&array, None).unwrap();
        assert!(alp.patches().is_some(), "expected patches");
        test_mask_conformance(&alp.into_array());
    }
}