// vortex_sparse/compute/mod.rs
use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3 FilterKernel, FilterKernelAdapter, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn,
4 TakeFn,
5};
6use vortex_array::vtable::ComputeVTable;
7use vortex_array::{Array, ArrayRef, register_kernel};
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::{SparseArray, SparseEncoding};
13
14mod binary_numeric;
15mod invert;
16mod search_sorted;
17mod slice;
18mod take;
19
20impl ComputeVTable for SparseEncoding {
21 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
22 Some(self)
23 }
24
25 fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
26 Some(self)
27 }
28
29 fn search_sorted_usize_fn(&self) -> Option<&dyn SearchSortedUsizeFn<&dyn Array>> {
30 Some(self)
31 }
32
33 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
34 Some(self)
35 }
36
37 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
38 Some(self)
39 }
40}
41
42impl ScalarAtFn<&SparseArray> for SparseEncoding {
43 fn scalar_at(&self, array: &SparseArray, index: usize) -> VortexResult<Scalar> {
44 Ok(array
45 .patches()
46 .get_patched(index)?
47 .unwrap_or_else(|| array.fill_scalar().clone()))
48 }
49}
50
51impl FilterKernel for SparseEncoding {
52 fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
53 let new_length = mask.true_count();
54
55 let Some(new_patches) = array.patches().filter(mask)? else {
56 return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
57 };
58
59 Ok(
60 SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
61 .into_array(),
62 )
63 }
64}
65
66register_kernel!(FilterKernelAdapter(SparseEncoding).lift());
67
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_numeric;
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Sparse `i32` array of length 20 with a typed-null fill value and the
    /// patches 33, 44, 55 at positions 2, 9, 15.
    #[fixture]
    fn array() -> ArrayRef {
        let patch_indices = buffer![2u64, 9, 15].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();

        SparseArray::try_new(patch_indices, patch_values, 20, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    /// Filtering down to a single patched position keeps exactly one patch.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        // Twenty entries, only position 2 (the first patch of the fixture) is selected.
        let mask = Mask::from_iter((0..20).map(|position| position == 2));

        let filtered_array = SparseArray::try_from(filter(&array, &mask).unwrap()).unwrap();

        assert_eq!(filtered_array.len(), 1);
        assert_eq!(filtered_array.patches().values().len(), 1);
        assert_eq!(filtered_array.patches().indices().len(), 1);
    }

    /// Filtering re-bases the indices of the surviving patches.
    ///
    /// Note: despite the test name, the fill value used here is a typed null.
    #[test]
    fn true_fill_value() {
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);

        let filtered_array = SparseArray::try_from(filter(&array, &mask).unwrap()).unwrap();
        assert_eq!(filtered_array.len(), 4);

        // Positions 1, 3, 5, 6 are selected, so the patches originally at
        // positions 3 and 6 end up at positions 1 and 3.
        let kept_indices = filtered_array.patches().indices().to_primitive().unwrap();
        assert_eq!(kept_indices.as_slice::<u64>(), &[1, 3]);
    }

    /// Runs the shared binary-numeric conformance suite against the fixture.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_numeric::<i32>(array)
    }

    /// Runs the shared mask conformance suite with a null fill and a non-null fill.
    #[test]
    fn test_mask_sparse_array() {
        // Case 1: nullable values with a null fill value.
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let nullable_values = cast(
            &buffer![100i32, 200, 300].into_array(),
            null_fill_value.dtype(),
        )
        .unwrap();
        let null_filled = SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            nullable_values,
            5,
            null_fill_value,
        )
        .unwrap();
        test_mask(&null_filled);

        // Case 2: plain values with the fill value 10.
        let ten_fill_value = Scalar::from(10i32);
        let ten_filled = SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            buffer![100i32, 200, 300].into_array(),
            5,
            ten_fill_value,
        )
        .unwrap();
        test_mask(&ten_filled)
    }
}