vortex_sparse/compute/
mod.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3    FilterKernel, FilterKernelAdapter, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn,
4    TakeFn,
5};
6use vortex_array::vtable::ComputeVTable;
7use vortex_array::{Array, ArrayRef, register_kernel};
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::{SparseArray, SparseEncoding};
13
14mod binary_numeric;
15mod invert;
16mod search_sorted;
17mod slice;
18mod take;
19
20impl ComputeVTable for SparseEncoding {
21    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
22        Some(self)
23    }
24
25    fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
26        Some(self)
27    }
28
29    fn search_sorted_usize_fn(&self) -> Option<&dyn SearchSortedUsizeFn<&dyn Array>> {
30        Some(self)
31    }
32
33    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
34        Some(self)
35    }
36
37    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
38        Some(self)
39    }
40}
41
42impl ScalarAtFn<&SparseArray> for SparseEncoding {
43    fn scalar_at(&self, array: &SparseArray, index: usize) -> VortexResult<Scalar> {
44        Ok(array
45            .patches()
46            .get_patched(index)?
47            .unwrap_or_else(|| array.fill_scalar().clone()))
48    }
49}
50
51impl FilterKernel for SparseEncoding {
52    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
53        let new_length = mask.true_count();
54
55        let Some(new_patches) = array.patches().filter(mask)? else {
56            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
57        };
58
59        Ok(
60            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
61                .into_array(),
62        )
63    }
64}
65
66register_kernel!(FilterKernelAdapter(SparseEncoding).lift());
67
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_numeric;
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// A length-20 sparse `i32` array with a null fill scalar and the values
    /// 33, 44 and 55 patched in at indices 2, 9 and 15.
    #[fixture]
    fn array() -> ArrayRef {
        let indices = buffer![2u64, 9, 15].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();

        SparseArray::try_new(indices, values, 20, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    /// Keeping only position 2 of the fixture leaves a one-element sparse array
    /// with a single patch.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        // Twenty entries, true only at position 2.
        let mask = Mask::from_iter((0..20).map(|position| position == 2));

        let result = SparseArray::try_from(filter(&array, &mask).unwrap()).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result.patches().values().len(), 1);
        assert_eq!(result.patches().indices().len(), 1);
    }

    /// Filtering re-bases the surviving patch indices onto the filtered array.
    #[test]
    fn true_fill_value() {
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);

        let indices = buffer![0_u64, 3, 6].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        let array = SparseArray::try_new(indices, values, 7, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array();

        let result = SparseArray::try_from(filter(&array, &mask).unwrap()).unwrap();
        assert_eq!(result.len(), 4);

        // Patches at original positions 3 and 6 survive and land at 1 and 3.
        let patch_indices = result.patches().indices().to_primitive().unwrap();
        assert_eq!(patch_indices.as_slice::<u64>(), &[1, 3]);
    }

    /// Runs the shared binary-numeric conformance suite against the fixture.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_numeric::<i32>(array);
    }

    /// Runs the shared mask conformance suite against a sparse array with a
    /// null fill scalar and against one with a fill scalar of 10.
    #[test]
    fn test_mask_sparse_array() {
        // Case 1: nullable i32 values with a null fill scalar.
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let nullable_indices = buffer![1u64, 2, 4].into_array();
        let nullable_values = cast(
            &buffer![100i32, 200, 300].into_array(),
            null_fill_value.dtype(),
        )
        .unwrap();
        let nullable_sparse =
            SparseArray::try_new(nullable_indices, nullable_values, 5, null_fill_value).unwrap();
        test_mask(&nullable_sparse);

        // Case 2: plain i32 values with a fill scalar of 10.
        let ten_fill_value = Scalar::from(10i32);
        let plain_sparse = SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            buffer![100i32, 200, 300].into_array(),
            5,
            ten_fill_value,
        )
        .unwrap();
        test_mask(&plain_sparse);
    }
}