vortex_sparse/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{FilterKernel, FilterKernelAdapter};
6use vortex_array::{ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9
10use crate::{SparseArray, SparseVTable};
11
12mod binary_numeric;
13mod invert;
14mod take;
15
16impl FilterKernel for SparseVTable {
17    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
18        let new_length = mask.true_count();
19
20        let Some(new_patches) = array.patches().filter(mask)? else {
21            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
22        };
23
24        Ok(
25            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
26                .into_array(),
27        )
28    }
29}
30
31register_kernel!(FilterKernelAdapter(SparseVTable).lift());
32
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_numeric;
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Length-20 sparse array with a null `i32` fill and the patches
    /// 33, 44, 55 at positions 2, 9 and 15.
    #[fixture]
    fn array() -> ArrayRef {
        let indices = buffer![2u64, 9, 15].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();

        SparseArray::try_new(indices, values, 20, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    /// Selecting only position 2 keeps exactly one element and one patch.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        // Twenty entries, true only at position 2 (the first patch).
        let predicate: Vec<bool> = (0..20).map(|position| position == 2).collect();
        let mask = Mask::from_iter(predicate);

        let result = filter(&array, &mask).unwrap();
        let sparse = result.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 1);
        let patches = sparse.patches();
        assert_eq!(patches.values().len(), 1);
        assert_eq!(patches.indices().len(), 1);
    }

    /// Surviving patch indices are renumbered relative to the filtered array.
    #[test]
    fn true_fill_value() {
        let indices = buffer![0_u64, 3, 6].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        let array = SparseArray::try_new(indices, values, 7, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array();

        // Keeps positions 1, 3, 5 and 6; of the patches only 3 and 6 survive.
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);

        let result = filter(&array, &mask).unwrap();
        let sparse = result.as_::<SparseVTable>();
        assert_eq!(sparse.len(), 4);

        let kept_indices = sparse.patches().indices().to_primitive().unwrap();
        assert_eq!(kept_indices.as_slice::<u64>(), &[1, 3]);
    }

    /// Runs the shared binary-numeric conformance suite on the fixture.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_numeric::<i32>(array);
    }

    /// Runs the shared mask conformance suite with a null fill and with a
    /// non-null fill.
    #[test]
    fn test_mask_sparse_array() {
        // Case 1: null fill value, with the patch values cast to the fill's dtype.
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let nullable_values =
            cast(&buffer![100i32, 200, 300].into_array(), null_fill_value.dtype()).unwrap();
        let with_null_fill = SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            nullable_values,
            5,
            null_fill_value,
        )
        .unwrap();
        test_mask(with_null_fill.as_ref());

        // Case 2: fill value of 10 with plain `i32` patch values.
        let with_ten_fill = SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            buffer![100i32, 200, 300].into_array(),
            5,
            Scalar::from(10i32),
        )
        .unwrap();
        test_mask(with_ten_fill.as_ref());
    }
}
131}