vortex_alp/alp/compute/
mask.rs1use vortex_array::ArrayRef;
5use vortex_array::ExecutionCtx;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::scalar_fn::fns::mask::MaskKernel;
8use vortex_array::scalar_fn::fns::mask::MaskReduce;
9use vortex_array::validity::Validity;
10use vortex_error::VortexResult;
11
12use crate::ALPArray;
13use crate::ALPVTable;
14
15impl MaskReduce for ALPVTable {
16 fn mask(array: &ALPArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
17 if array.patches().is_some() {
19 return Ok(None);
20 }
21 let masked_encoded = array.encoded().clone().mask(mask.clone())?;
22 Ok(Some(
23 ALPArray::new(masked_encoded, array.exponents(), None).to_array(),
24 ))
25 }
26}
27
28impl MaskKernel for ALPVTable {
29 fn mask(
30 array: &ALPArray,
31 mask: &ArrayRef,
32 ctx: &mut ExecutionCtx,
33 ) -> VortexResult<Option<ArrayRef>> {
34 let vortex_mask = Validity::Array(mask.not()?).to_mask(array.len());
35 let masked_encoded = array.encoded().clone().mask(mask.clone())?;
36 let masked_patches = array
37 .patches()
38 .map(|p| p.mask(&vortex_mask, ctx))
39 .transpose()?
40 .flatten();
41 Ok(Some(
42 ALPArray::new(masked_encoded, array.exponents(), masked_patches).to_array(),
43 ))
44 }
45}
46
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_buffer::buffer;

    use crate::alp_encode;

    /// Runs the shared mask conformance suite over ALP-encoded float arrays:
    /// `f32` and `f64` inputs, an input with nulls, a single element, and a
    /// longer `f32` run.
    #[rstest]
    #[case(buffer![10.5f32, 20.5, 30.5, 40.5, 50.5].into_array())]
    #[case(buffer![1000.123f64, 2000.456, 3000.789, 4000.012, 5000.345].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]).into_array())]
    #[case(buffer![99.99f64].into_array())]
    #[case(buffer![
        0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
        1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0
    ].into_array())]
    fn test_mask_alp_conformance(#[case] input: ArrayRef) {
        let primitive = input.to_primitive();
        let encoded = alp_encode(&primitive, None).unwrap();
        test_mask_conformance(&encoded.to_array());
    }

    /// Runs the mask conformance suite over an ALP array that is asserted to
    /// carry patches.
    #[test]
    fn test_mask_alp_with_patches() {
        use std::f64::consts::PI;

        // Every fourth value is PI and the rest are 1.0; the assertion below
        // checks that this input actually yields patches when encoded.
        let values = (0..100)
            .map(|i| match i % 4 {
                3 => PI,
                _ => 1.0,
            })
            .collect::<Vec<f64>>();

        let primitive = PrimitiveArray::from_iter(values);
        let encoded = alp_encode(&primitive, None).unwrap();
        assert!(encoded.patches().is_some(), "expected patches");

        test_mask_conformance(&encoded.to_array());
    }
}