1use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{FilterKernel, FilterKernelAdapter};
6use vortex_array::{ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9
10use crate::{SparseArray, SparseVTable};
11
12mod binary_numeric;
13mod cast;
14mod invert;
15mod take;
16
17impl FilterKernel for SparseVTable {
18 fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
19 let new_length = mask.true_count();
20
21 let Some(new_patches) = array.patches().filter(mask)? else {
22 return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
23 };
24
25 Ok(
26 SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
27 .into_array(),
28 )
29 }
30}
31
32register_kernel!(FilterKernelAdapter(SparseVTable).lift());
33
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, assert_arrays_eq};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// A length-20 sparse array with a null `i32` fill value and patches
    /// `{2: 33, 9: 44, 15: 55}` (values built with `Validity::AllValid`).
    #[fixture]
    fn array() -> ArrayRef {
        SparseArray::try_new(
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            20,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array()
    }

    /// Selecting only position 2 keeps exactly one patch, re-indexed to 0,
    /// carrying the value originally stored there.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        let mut predicate = vec![false, false, true];
        predicate.extend_from_slice(&[false; 17]);
        let mask = Mask::from_iter(predicate);

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = filtered_array.as_::<SparseVTable>();

        assert_eq!(filtered_array.len(), 1);
        assert_eq!(filtered_array.patches().values().len(), 1);
        assert_arrays_eq!(
            filtered_array.patches().indices(),
            PrimitiveArray::from_iter([0u64])
        );
        // Also check that the surviving patch's value came from index 2.
        assert_arrays_eq!(
            filtered_array.patches().values(),
            PrimitiveArray::from_option_iter([Some(33i32)])
        );
    }

    /// Filtering a null-filled sparse array drops unselected patches and
    /// remaps the surviving patch indices into the filtered coordinate space.
    ///
    /// Kept positions are `[1, 3, 5, 6]`. Patch 0 is dropped, and patches at
    /// 3 and 6 land at output positions 1 and 3.
    #[test]
    fn test_filter_null_fill_value() {
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = filtered_array.as_::<SparseVTable>();

        assert_eq!(filtered_array.len(), 4);
        assert_arrays_eq!(
            filtered_array.patches().indices(),
            PrimitiveArray::from_iter([1u64, 3])
        );
        // The values of the two surviving patches (originally at 3 and 6).
        assert_arrays_eq!(
            filtered_array.patches().values(),
            PrimitiveArray::from_option_iter([Some(44i32), Some(55)])
        );
    }

    /// Runs the shared binary-numeric conformance suite on the fixture array.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric_array(array)
    }

    /// Runs the shared mask conformance suite with a null fill and a non-null fill.
    #[test]
    fn test_mask_sparse_array() {
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        test_mask_conformance(
            SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                cast(
                    &buffer![100i32, 200, 300].into_array(),
                    null_fill_value.dtype(),
                )
                .unwrap(),
                5,
                null_fill_value,
            )
            .unwrap()
            .as_ref(),
        );

        let ten_fill_value = Scalar::from(10i32);
        test_mask_conformance(
            SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                buffer![100i32, 200, 300].into_array(),
                5,
                ten_fill_value,
            )
            .unwrap()
            .as_ref(),
        )
    }

    /// Runs the shared filter conformance suite with a null fill and a non-null fill.
    #[test]
    fn test_filter_sparse_array() {
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        test_filter_conformance(
            SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                cast(
                    &buffer![100i32, 200, 300].into_array(),
                    null_fill_value.dtype(),
                )
                .unwrap(),
                5,
                null_fill_value,
            )
            .unwrap()
            .as_ref(),
        );

        let ten_fill_value = Scalar::from(10i32);
        test_filter_conformance(
            SparseArray::try_new(
                buffer![1u64, 2, 4].into_array(),
                buffer![100i32, 200, 300].into_array(),
                5,
                ten_fill_value,
            )
            .unwrap()
            .as_ref(),
        )
    }
}
169
#[cfg(test)]
mod tests {
    //! Parameterized conformance checks (array consistency and binary numeric
    //! operations) over `SparseArray`s of several dtypes, lengths, patch
    //! densities and fill values.

    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::{ArrayRef, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Builds a `SparseArray` from its parts, panicking if
    /// `SparseArray::try_new` rejects them.
    fn sparse(indices: ArrayRef, values: ArrayRef, len: usize, fill: Scalar) -> SparseArray {
        SparseArray::try_new(indices, values, len, fill).unwrap()
    }

    /// Runs the generic array-consistency suite over each sparse layout.
    #[rstest]
    #[case::sparse_i32_null_fill(sparse(
        buffer![2u64, 5, 8].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(),
        10,
        Scalar::null_typed::<i32>(),
    ))]
    #[case::sparse_i32_value_fill(sparse(
        buffer![1u64, 3, 7].into_array(),
        buffer![42i32, 84, 126].into_array(),
        10,
        Scalar::from(0i32),
    ))]
    #[case::sparse_u64(sparse(
        buffer![0u64, 4, 9].into_array(),
        buffer![1000u64, 2000, 3000].into_array(),
        10,
        Scalar::from(999u64),
    ))]
    #[case::sparse_f32(sparse(
        buffer![2u64, 6].into_array(),
        buffer![std::f32::consts::PI, std::f32::consts::E].into_array(),
        8,
        Scalar::from(0.0f32),
    ))]
    #[case::sparse_single_patch(sparse(
        buffer![5u64].into_array(),
        buffer![42i32].into_array(),
        10,
        Scalar::from(-1i32),
    ))]
    #[case::sparse_dense_patches(sparse(
        buffer![0u64, 1, 2, 3, 4].into_array(),
        PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)])
            .into_array(),
        5,
        Scalar::null_typed::<i32>(),
    ))]
    #[case::sparse_large(sparse(
        buffer![100u64, 500, 900, 1500, 1999].into_array(),
        buffer![111i32, 222, 333, 444, 555].into_array(),
        2000,
        Scalar::from(0i32),
    ))]
    #[case::sparse_nullable_patches({
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let nullable_values = cast(
            &PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
            null_fill_value.dtype(),
        )
        .unwrap();
        sparse(buffer![1u64, 4, 7].into_array(), nullable_values, 10, null_fill_value)
    })]
    fn test_sparse_consistency(#[case] array: SparseArray) {
        test_array_consistency(array.as_ref());
    }

    /// Runs the binary-numeric conformance suite over non-null sparse arrays
    /// of several primitive types.
    #[rstest]
    #[case::sparse_i32_basic(sparse(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32),
    ))]
    #[case::sparse_u32_basic(sparse(
        buffer![1u64, 3, 7].into_array(),
        buffer![1000u32, 2000, 3000].into_array(),
        10,
        Scalar::from(100u32),
    ))]
    #[case::sparse_i64_basic(sparse(
        buffer![0u64, 4, 9].into_array(),
        buffer![5000i64, 6000, 7000].into_array(),
        10,
        Scalar::from(1000i64),
    ))]
    #[case::sparse_f32_basic(sparse(
        buffer![2u64, 6].into_array(),
        buffer![1.5f32, 2.5].into_array(),
        8,
        Scalar::from(0.5f32),
    ))]
    #[case::sparse_f64_basic(sparse(
        buffer![1u64, 5, 9].into_array(),
        buffer![10.1f64, 20.2, 30.3].into_array(),
        10,
        Scalar::from(5.0f64),
    ))]
    #[case::sparse_i32_large(sparse(
        buffer![10u64, 50, 90, 150, 199].into_array(),
        buffer![111i32, 222, 333, 444, 555].into_array(),
        200,
        Scalar::from(0i32),
    ))]
    fn test_sparse_binary_numeric(#[case] array: SparseArray) {
        test_binary_numeric_array(array.into_array());
    }
}