vortex_sparse/compute/
mod.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3    BinaryNumericFn, FilterFn, InvertFn, ScalarAtFn, SearchResult, SearchSortedFn,
4    SearchSortedSide, SearchSortedUsizeFn, SliceFn, TakeFn,
5};
6use vortex_array::vtable::ComputeVTable;
7use vortex_array::{Array, ArrayRef};
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::{SparseArray, SparseEncoding};
13
14mod binary_numeric;
15mod invert;
16mod slice;
17mod take;
18
19impl ComputeVTable for SparseEncoding {
20    fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
21        Some(self)
22    }
23
24    fn filter_fn(&self) -> Option<&dyn FilterFn<&dyn Array>> {
25        Some(self)
26    }
27
28    fn invert_fn(&self) -> Option<&dyn InvertFn<&dyn Array>> {
29        Some(self)
30    }
31
32    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
33        Some(self)
34    }
35
36    fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
37        Some(self)
38    }
39
40    fn search_sorted_usize_fn(&self) -> Option<&dyn SearchSortedUsizeFn<&dyn Array>> {
41        Some(self)
42    }
43
44    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
45        Some(self)
46    }
47
48    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
49        Some(self)
50    }
51}
52
53impl ScalarAtFn<&SparseArray> for SparseEncoding {
54    fn scalar_at(&self, array: &SparseArray, index: usize) -> VortexResult<Scalar> {
55        Ok(array
56            .patches()
57            .get_patched(index)?
58            .unwrap_or_else(|| array.fill_scalar().clone()))
59    }
60}
61
62// FIXME(ngates): these are broken in a way that works for array patches, this will be fixed soon.
63impl SearchSortedFn<&SparseArray> for SparseEncoding {
64    fn search_sorted(
65        &self,
66        array: &SparseArray,
67        value: &Scalar,
68        side: SearchSortedSide,
69    ) -> VortexResult<SearchResult> {
70        array.patches().search_sorted(value.clone(), side)
71    }
72}
73
74// FIXME(ngates): these are broken in a way that works for array patches, this will be fixed soon.
75impl SearchSortedUsizeFn<&SparseArray> for SparseEncoding {
76    fn search_sorted_usize(
77        &self,
78        array: &SparseArray,
79        value: usize,
80        side: SearchSortedSide,
81    ) -> VortexResult<SearchResult> {
82        let Ok(target) = Scalar::from(value).cast(array.dtype()) else {
83            // If the downcast fails, then the target is too large for the dtype.
84            return Ok(SearchResult::NotFound(array.len()));
85        };
86        SearchSortedFn::search_sorted(self, array, &target, side)
87    }
88}
89
90impl FilterFn<&SparseArray> for SparseEncoding {
91    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
92        let new_length = mask.true_count();
93
94        let Some(new_patches) = array.patches().filter(mask)? else {
95            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
96        };
97
98        Ok(
99            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
100                .into_array(),
101        )
102    }
103}
104
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::test_harness::{test_binary_numeric, test_mask};
    use vortex_array::compute::{
        SearchResult, SearchSortedSide, filter, scalar_at, search_sorted, slice, try_cast,
    };
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Length-20 sparse `i32` array: values 33, 44, 55 at indices 2, 9, 15, null fill elsewhere.
    #[fixture]
    fn array() -> ArrayRef {
        SparseArray::try_new(
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            20,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array()
    }

    #[rstest]
    fn search_larger_than(array: ArrayRef) {
        let res = search_sorted(&array, 66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_less_than(array: ArrayRef) {
        let res = search_sorted(&array, 22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    #[rstest]
    fn search_found(array: ArrayRef) {
        let res = search_sorted(&array, 44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    #[rstest]
    fn search_not_found_right(array: ArrayRef) {
        let res = search_sorted(&array, 56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_sliced(array: ArrayRef) {
        let array = slice(&array, 7, 20).unwrap();
        assert_eq!(
            search_sorted(&array, 22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    #[test]
    fn search_right() {
        let array = SparseArray::try_new(
            buffer![0u64].into_array(),
            PrimitiveArray::new(buffer![0u8], Validity::AllValid).into_array(),
            2,
            Scalar::null_typed::<u8>(),
        )
        .unwrap()
        .into_array();

        assert_eq!(
            search_sorted(&array, 0, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(1)
        );
        assert_eq!(
            search_sorted(&array, 1, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(1)
        );
    }

    #[rstest]
    fn scalar_at_patch_and_fill(array: ArrayRef) {
        // Index 9 holds the patch value 44; index 0 falls back to the null fill.
        assert!(!scalar_at(&array, 9).unwrap().is_null());
        assert!(scalar_at(&array, 0).unwrap().is_null());
    }

    #[rstest]
    fn test_filter(array: ArrayRef) {
        let mut predicate = vec![false, false, true];
        predicate.extend_from_slice(&[false; 17]);
        let mask = Mask::from_iter(predicate);

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = SparseArray::try_from(filtered_array).unwrap();

        assert_eq!(filtered_array.len(), 1);
        assert_eq!(filtered_array.patches().values().len(), 1);
        assert_eq!(filtered_array.patches().indices().len(), 1);
    }

    #[rstest]
    fn filter_without_patches(array: ArrayRef) {
        // Positions 0 and 1 carry no patches, so the filter takes the "no surviving patches"
        // path and every kept element must be the (null) fill value.
        let mask = Mask::from_iter((0..20).map(|idx| idx < 2));

        let filtered = filter(&array, &mask).unwrap();

        assert_eq!(filtered.len(), 2);
        assert!(scalar_at(&filtered, 0).unwrap().is_null());
        assert!(scalar_at(&filtered, 1).unwrap().is_null());
    }

    #[test]
    fn true_fill_value() {
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = SparseArray::try_from(filtered_array).unwrap();

        assert_eq!(filtered_array.len(), 4);
        let primitive = filtered_array.patches().indices().to_primitive().unwrap();

        // Original patch positions 3 and 6 land at positions 1 and 3 of the filtered output.
        assert_eq!(primitive.as_slice::<u64>(), &[1, 3]);
    }

    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric::<i32>(array)
    }

    #[test]
    fn test_mask_sparse_array() {
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        test_mask(
            &SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                try_cast(
                    &buffer![100i32, 200, 300].into_array(),
                    null_fill_value.dtype(),
                )
                .unwrap(),
                5,
                null_fill_value,
            )
            .unwrap(),
        );

        let ten_fill_value = Scalar::from(10i32);
        test_mask(
            &SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                buffer![100i32, 200, 300].into_array(),
                5,
                ten_fill_value,
            )
            .unwrap(),
        );
    }
}