Skip to main content

vortex_alp/alp/compute/
mask.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_array::ArrayRef;
5use vortex_array::ExecutionCtx;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::scalar_fn::fns::mask::MaskKernel;
8use vortex_array::scalar_fn::fns::mask::MaskReduce;
9use vortex_array::validity::Validity;
10use vortex_error::VortexResult;
11
12use crate::ALPArray;
13use crate::ALPVTable;
14
15impl MaskReduce for ALPVTable {
16    fn mask(array: &ALPArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
17        // Masking sparse patches requires reading indices, fall back to kernel.
18        if array.patches().is_some() {
19            return Ok(None);
20        }
21        let masked_encoded = array.encoded().clone().mask(mask.clone())?;
22        Ok(Some(
23            ALPArray::new(masked_encoded, array.exponents(), None).to_array(),
24        ))
25    }
26}
27
28impl MaskKernel for ALPVTable {
29    fn mask(
30        array: &ALPArray,
31        mask: &ArrayRef,
32        ctx: &mut ExecutionCtx,
33    ) -> VortexResult<Option<ArrayRef>> {
34        let vortex_mask = Validity::Array(mask.not()?).to_mask(array.len());
35        let masked_encoded = array.encoded().clone().mask(mask.clone())?;
36        let masked_patches = array
37            .patches()
38            .map(|p| p.mask(&vortex_mask, ctx))
39            .transpose()?
40            .flatten();
41        Ok(Some(
42            ALPArray::new(masked_encoded, array.exponents(), masked_patches).to_array(),
43        ))
44    }
45}
46
#[cfg(test)]
mod test {
    use std::f64::consts::PI;

    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_buffer::buffer;

    use crate::alp_encode;

    /// Runs the shared mask conformance suite over ALP encodings of assorted float inputs:
    /// f32 and f64, nullable and non-nullable, single-element and longer arrays.
    #[rstest]
    #[case(buffer![10.5f32, 20.5, 30.5, 40.5, 50.5].into_array())]
    #[case(buffer![1000.123f64, 2000.456, 3000.789, 4000.012, 5000.345].into_array())]
    #[case(
        PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None])
            .into_array()
    )]
    #[case(buffer![99.99f64].into_array())]
    #[case(buffer![
        0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
        1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0
    ].into_array())]
    fn test_mask_alp_conformance(#[case] array: ArrayRef) {
        let primitive = array.to_primitive();
        let encoded = alp_encode(&primitive, None).unwrap();
        test_mask_conformance(&encoded.to_array());
    }

    /// Masks an ALP array that is guaranteed to carry patches, covering the case that
    /// `MaskReduce` declines.
    #[test]
    fn test_mask_alp_with_patches() {
        // Every fourth value is PI, which doesn't encode cleanly with ALP and so becomes a patch.
        let values = (0..100).map(|i| if i % 4 == 3 { PI } else { 1.0 });
        let array = PrimitiveArray::from_iter(values);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_some(), "expected patches");
        test_mask_conformance(&encoded.to_array());
    }
}