1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3 BinaryNumericFn, FilterFn, InvertFn, ScalarAtFn, SearchResult, SearchSortedFn,
4 SearchSortedSide, SearchSortedUsizeFn, SliceFn, TakeFn,
5};
6use vortex_array::vtable::ComputeVTable;
7use vortex_array::{Array, ArrayRef};
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::{SparseArray, SparseEncoding};
13
14mod binary_numeric;
15mod invert;
16mod slice;
17mod take;
18
19impl ComputeVTable for SparseEncoding {
20 fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
21 Some(self)
22 }
23
24 fn filter_fn(&self) -> Option<&dyn FilterFn<&dyn Array>> {
25 Some(self)
26 }
27
28 fn invert_fn(&self) -> Option<&dyn InvertFn<&dyn Array>> {
29 Some(self)
30 }
31
32 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
33 Some(self)
34 }
35
36 fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
37 Some(self)
38 }
39
40 fn search_sorted_usize_fn(&self) -> Option<&dyn SearchSortedUsizeFn<&dyn Array>> {
41 Some(self)
42 }
43
44 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
45 Some(self)
46 }
47
48 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
49 Some(self)
50 }
51}
52
53impl ScalarAtFn<&SparseArray> for SparseEncoding {
54 fn scalar_at(&self, array: &SparseArray, index: usize) -> VortexResult<Scalar> {
55 Ok(array
56 .patches()
57 .get_patched(index)?
58 .unwrap_or_else(|| array.fill_scalar().clone()))
59 }
60}
61
62impl SearchSortedFn<&SparseArray> for SparseEncoding {
64 fn search_sorted(
65 &self,
66 array: &SparseArray,
67 value: &Scalar,
68 side: SearchSortedSide,
69 ) -> VortexResult<SearchResult> {
70 array.patches().search_sorted(value.clone(), side)
71 }
72}
73
74impl SearchSortedUsizeFn<&SparseArray> for SparseEncoding {
76 fn search_sorted_usize(
77 &self,
78 array: &SparseArray,
79 value: usize,
80 side: SearchSortedSide,
81 ) -> VortexResult<SearchResult> {
82 let Ok(target) = Scalar::from(value).cast(array.dtype()) else {
83 return Ok(SearchResult::NotFound(array.len()));
85 };
86 SearchSortedFn::search_sorted(self, array, &target, side)
87 }
88}
89
90impl FilterFn<&SparseArray> for SparseEncoding {
91 fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
92 let new_length = mask.true_count();
93
94 let Some(new_patches) = array.patches().filter(mask)? else {
95 return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
96 };
97
98 Ok(
99 SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
100 .into_array(),
101 )
102 }
103}
104
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::test_harness::{test_binary_numeric, test_mask};
    use vortex_array::compute::{
        SearchResult, SearchSortedSide, filter, search_sorted, slice, try_cast,
    };
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Shared fixture: a length-20 `i32` sparse array with a null fill value and
    /// non-null patches `33`, `44`, `55` at indices `2`, `9`, `15`.
    ///
    /// rstest injects this into any `#[rstest]` test whose parameter is named `array`.
    #[fixture]
    fn array() -> ArrayRef {
        SparseArray::try_new(
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            20,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array()
    }

    /// 66 exceeds every patch value, so the insertion point is just after the
    /// last patch (index 15).
    #[rstest]
    fn search_larger_than(array: ArrayRef) {
        let res = search_sorted(&array, 66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    /// 22 is below every patch value, so the insertion point is the first
    /// patch's index (2).
    #[rstest]
    fn search_less_than(array: ArrayRef) {
        let res = search_sorted(&array, 22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    /// 44 is the patch stored at index 9.
    #[rstest]
    fn search_found(array: ArrayRef) {
        let res = search_sorted(&array, 44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    /// 56 is just above the largest patch (55 at index 15), so a right-side search
    /// lands one past it.
    #[rstest]
    fn search_not_found_right(array: ArrayRef) {
        let res = search_sorted(&array, 56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    /// After slicing to `[7, 20)` the patch at 2 is dropped and the patches at
    /// 9 and 15 shift to 2 and 8. 22 is below both remaining values, so the
    /// insertion point is 2.
    #[rstest]
    fn search_sliced(array: ArrayRef) {
        let array = slice(&array, 7, 20).unwrap();
        assert_eq!(
            search_sorted(&array, 22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    /// Right-side searches on a length-2 `u8` array with a single patch `0` at
    /// index 0 and a null fill.
    #[test]
    fn search_right() {
        let array = SparseArray::try_new(
            buffer![0u64].into_array(),
            PrimitiveArray::new(buffer![0u8], Validity::AllValid).into_array(),
            2,
            Scalar::null_typed::<u8>(),
        )
        .unwrap()
        .into_array();

        // A right-side match reports the position just after the matching patch.
        assert_eq!(
            search_sorted(&array, 0, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(1)
        );
        // 1 is larger than the only patch value, so it would be inserted after it.
        assert_eq!(
            search_sorted(&array, 1, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(1)
        );
    }

    /// A 20-entry mask selecting only index 2 (the first patch) must yield a
    /// one-element sparse array carrying exactly that one patch.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        let mut predicate = vec![false, false, true];
        predicate.extend_from_slice(&[false; 17]);
        let mask = Mask::from_iter(predicate);

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = SparseArray::try_from(filtered_array).unwrap();

        assert_eq!(filtered_array.len(), 1);
        assert_eq!(filtered_array.patches().values().len(), 1);
        assert_eq!(filtered_array.patches().indices().len(), 1);
    }

    /// The mask selects positions 1, 3, 5 and 6. Of the patches at 0, 3 and 6, only
    /// 3 and 6 survive, and they are renumbered to 1 and 3 in the filtered array.
    ///
    /// NOTE(review): despite the name, the fill value here is null, not `true`;
    /// consider renaming.
    #[test]
    fn true_fill_value() {
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = SparseArray::try_from(filtered_array).unwrap();

        assert_eq!(filtered_array.len(), 4);
        let primitive = filtered_array.patches().indices().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u64>(), &[1, 3]);
    }

    /// Runs the shared binary-numeric test harness against the fixture.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric::<i32>(array)
    }

    /// Runs the shared mask harness on two sparse arrays.
    ///
    /// The first has a null fill, with its patch values cast to nullable `i32` so
    /// they match the fill's dtype. The second has a non-null fill of `10` with
    /// non-nullable values.
    #[test]
    fn test_mask_sparse_array() {
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        test_mask(
            &SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                try_cast(
                    &buffer![100i32, 200, 300].into_array(),
                    null_fill_value.dtype(),
                )
                .unwrap(),
                5,
                null_fill_value,
            )
            .unwrap(),
        );

        let ten_fill_value = Scalar::from(10i32);
        test_mask(
            &SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                buffer![100i32, 200, 300].into_array(),
                5,
                ten_fill_value,
            )
            .unwrap(),
        )
    }
}