vortex_sparse/compute/
mod.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{FilterKernel, FilterKernelAdapter};
3use vortex_array::{ArrayRef, IntoArray, register_kernel};
4use vortex_error::VortexResult;
5use vortex_mask::Mask;
6
7use crate::{SparseArray, SparseVTable};
8
9mod binary_numeric;
10mod invert;
11mod take;
12
13impl FilterKernel for SparseVTable {
14    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
15        let new_length = mask.true_count();
16
17        let Some(new_patches) = array.patches().filter(mask)? else {
18            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
19        };
20
21        Ok(
22            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
23                .into_array(),
24        )
25    }
26}
27
28register_kernel!(FilterKernelAdapter(SparseVTable).lift());
29
/// Tests for the sparse compute kernels: filter, binary numeric and mask.
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_numeric;
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// A 20-element sparse `i32` array with a typed-null fill and patches
    /// `33`, `44`, `55` at positions `2`, `9`, `15`.
    #[fixture]
    fn array() -> ArrayRef {
        SparseArray::try_new(
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            20,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array()
    }

    /// Selecting only position 2 must keep exactly the patch stored there,
    /// re-indexed to position 0 of the one-element result.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        // Only position 2 (which holds the patch value 33) is selected.
        let mut predicate = vec![false, false, true];
        predicate.extend_from_slice(&[false; 17]);
        let mask = Mask::from_iter(predicate);

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = filtered_array.as_::<SparseVTable>();

        assert_eq!(filtered_array.len(), 1);
        assert_eq!(filtered_array.patches().values().len(), 1);
        assert_eq!(filtered_array.patches().indices().len(), 1);

        // Lengths alone cannot tell *which* patch survived, so also pin the
        // rebased index and the retained value.
        let indices = filtered_array.patches().indices().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0]);

        let values = filtered_array.patches().values().to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[33]);
    }

    /// Filtering with a mask that keeps a mix of patched and fill positions.
    ///
    /// NOTE(review): despite the name, the fill scalar used here is a typed
    /// null, not `true` — consider renaming.
    #[test]
    fn true_fill_value() {
        // Selected positions are 1, 3, 5 and 6; of those, 3 and 6 are patched.
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = filtered_array.as_::<SparseVTable>();

        assert_eq!(filtered_array.len(), 4);

        // Patches originally at 3 and 6 land at positions 1 and 3 of the result.
        let primitive = filtered_array.patches().indices().to_primitive().unwrap();
        assert_eq!(primitive.as_slice::<u64>(), &[1, 3]);

        // The patch at position 0 (value 33) is dropped; 44 and 55 remain in order.
        let values = filtered_array.patches().values().to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[44, 55]);
    }

    /// Runs the shared binary-numeric conformance checks against the fixture.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_numeric::<i32>(array)
    }

    /// Runs the shared mask conformance checks against sparse arrays with a
    /// null fill value and with a non-null fill value.
    #[test]
    fn test_mask_sparse_array() {
        // Nullable i32 with a null fill; the patch values are cast to the
        // fill's dtype so both sides agree.
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        test_mask(
            SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                cast(
                    &buffer![100i32, 200, 300].into_array(),
                    null_fill_value.dtype(),
                )
                .unwrap(),
                5,
                null_fill_value,
            )
            .unwrap()
            .as_ref(),
        );

        // i32 with the concrete fill value 10.
        let ten_fill_value = Scalar::from(10i32);
        test_mask(
            SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                buffer![100i32, 200, 300].into_array(),
                5,
                ten_fill_value,
            )
            .unwrap()
            .as_ref(),
        )
    }
}