vortex_sparse/compute/mod.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3    BinaryNumericFn, FilterKernel, FilterKernelAdapter, InvertFn, KernelRef, ScalarAtFn,
4    SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn,
5};
6use vortex_array::vtable::ComputeVTable;
7use vortex_array::{Array, ArrayComputeImpl, ArrayRef};
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::{SparseArray, SparseEncoding};
13
14mod binary_numeric;
15mod invert;
16mod search_sorted;
17mod slice;
18mod take;
19
20impl ArrayComputeImpl for SparseArray {
21    const FILTER: Option<KernelRef> = FilterKernelAdapter(SparseEncoding).some();
22}
23impl ComputeVTable for SparseEncoding {
24    fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
25        Some(self)
26    }
27
28    fn invert_fn(&self) -> Option<&dyn InvertFn<&dyn Array>> {
29        Some(self)
30    }
31
32    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
33        Some(self)
34    }
35
36    fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
37        Some(self)
38    }
39
40    fn search_sorted_usize_fn(&self) -> Option<&dyn SearchSortedUsizeFn<&dyn Array>> {
41        Some(self)
42    }
43
44    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
45        Some(self)
46    }
47
48    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
49        Some(self)
50    }
51}
52
53impl ScalarAtFn<&SparseArray> for SparseEncoding {
54    fn scalar_at(&self, array: &SparseArray, index: usize) -> VortexResult<Scalar> {
55        Ok(array
56            .patches()
57            .get_patched(index)?
58            .unwrap_or_else(|| array.fill_scalar().clone()))
59    }
60}
61
62impl FilterKernel for SparseEncoding {
63    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
64        let new_length = mask.true_count();
65
66        let Some(new_patches) = array.patches().filter(mask)? else {
67            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
68        };
69
70        Ok(
71            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
72                .into_array(),
73        )
74    }
75}
76
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric;
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{filter, scalar_at, try_cast};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Length-20 sparse `i32` array.
    ///
    /// Patches are `{2: 33, 9: 44, 15: 55}`; every other position holds the null fill value.
    #[fixture]
    fn array() -> ArrayRef {
        SparseArray::try_new(
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            20,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array()
    }

    /// Selecting only position 2 keeps exactly the first patch, re-indexed to 0.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        let mut predicate = vec![false, false, true];
        predicate.extend_from_slice(&[false; 17]);
        let mask = Mask::from_iter(predicate);

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = SparseArray::try_from(filtered_array).unwrap();

        assert_eq!(filtered_array.len(), 1);
        assert_eq!(filtered_array.patches().values().len(), 1);
        assert_eq!(filtered_array.patches().indices().len(), 1);

        // Verify *which* patch survived, not just how many.
        let indices = filtered_array.patches().indices().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0]);
        let values = filtered_array.patches().values().to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[33]);
    }

    /// This mask selects no patched position, which exercises the `ConstantArray`
    /// fallback. Every remaining element must be the null fill value.
    #[rstest]
    fn test_filter_no_patches_selected(array: ArrayRef) {
        let mut predicate = vec![true, true];
        predicate.extend_from_slice(&[false; 18]);
        let mask = Mask::from_iter(predicate);

        let filtered_array = filter(&array, &mask).unwrap();

        assert_eq!(filtered_array.len(), 2);
        for index in 0..filtered_array.len() {
            assert!(scalar_at(&filtered_array, index).unwrap().is_null());
        }
    }

    /// Patched positions 3 and 6 are kept and land at new positions 1 and 3.
    #[test]
    fn true_fill_value() {
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = SparseArray::try_from(filtered_array).unwrap();

        assert_eq!(filtered_array.len(), 4);
        let primitive = filtered_array.patches().indices().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u64>(), &[1, 3]);

        // The patch at old index 0 (33) is dropped; 44 and 55 remain.
        let values = filtered_array.patches().values().to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[44, 55]);
    }

    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric::<i32>(array)
    }

    /// Runs the shared mask conformance suite with both a null and a non-null fill value.
    #[test]
    fn test_mask_sparse_array() {
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        test_mask(
            &SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                try_cast(
                    &buffer![100i32, 200, 300].into_array(),
                    null_fill_value.dtype(),
                )
                .unwrap(),
                5,
                null_fill_value,
            )
            .unwrap(),
        );

        let ten_fill_value = Scalar::from(10i32);
        test_mask(
            &SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                buffer![100i32, 200, 300].into_array(),
                5,
                ten_fill_value,
            )
            .unwrap(),
        )
    }
}