1use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::arrays::ConstantArray;
7use vortex_array::compute::FilterKernel;
8use vortex_array::compute::FilterKernelAdapter;
9use vortex_array::register_kernel;
10use vortex_error::VortexResult;
11use vortex_mask::Mask;
12
13use crate::SparseArray;
14use crate::SparseVTable;
15
16mod binary_numeric;
17mod cast;
18mod invert;
19mod take;
20
21impl FilterKernel for SparseVTable {
22 fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
23 let new_length = mask.true_count();
24
25 let Some(new_patches) = array.patches().filter(mask)? else {
26 return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
27 };
28
29 Ok(
30 SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
31 .into_array(),
32 )
33 }
34}
35
36register_kernel!(FilterKernelAdapter(SparseVTable).lift());
37
#[cfg(test)]
mod test {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::filter;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::SparseArray;
    use crate::SparseVTable;

    /// Length-5 sparse array with patches `100, 200, 300` at indices `1, 2, 4` and a null
    /// fill of nullable `i32` type; the patch values are cast to the fill's dtype.
    fn sparse_with_null_fill() -> SparseArray {
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let indices = buffer![1u64, 2, 4].into_array();
        let values = cast(&buffer![100i32, 200, 300].into_array(), fill.dtype()).unwrap();
        SparseArray::try_new(indices, values, 5, fill).unwrap()
    }

    /// Length-5 sparse array with patches `100, 200, 300` at indices `1, 2, 4` and a fill
    /// value of `10i32`.
    fn sparse_with_fill_of_ten() -> SparseArray {
        let fill = Scalar::from(10i32);
        let indices = buffer![1u64, 2, 4].into_array();
        let values = buffer![100i32, 200, 300].into_array();
        SparseArray::try_new(indices, values, 5, fill).unwrap()
    }

    /// Shared fixture: length-20 sparse array with patches `33, 44, 55` at indices
    /// `2, 9, 15` and a typed-null `i32` fill.
    #[fixture]
    fn array() -> ArrayRef {
        let indices = buffer![2u64, 9, 15].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        SparseArray::try_new(indices, values, 20, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    // Selecting only position 2 (the first patch) leaves a one-element sparse array whose
    // single patch sits at index 0.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        let mut keep = vec![false; 20];
        keep[2] = true;
        let mask = Mask::from_iter(keep);

        let result = filter(&array, &mask).unwrap();
        let sparse = result.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 1);
        assert_eq!(sparse.patches().values().len(), 1);
        assert_arrays_eq!(
            sparse.patches().indices(),
            PrimitiveArray::from_iter([0u64])
        );
    }

    // The mask keeps positions 1, 3, 5 and 6; the patches at 3 and 6 survive and land at
    // indices 1 and 3 of the four-element result.
    #[test]
    fn true_fill_value() {
        let indices = buffer![0_u64, 3, 6].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        let array = SparseArray::try_new(indices, values, 7, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array();
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);

        let result = filter(&array, &mask).unwrap();
        let sparse = result.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 4);
        assert_arrays_eq!(
            sparse.patches().indices(),
            PrimitiveArray::from_iter([1u64, 3])
        );
    }

    // Runs the shared binary-numeric conformance helper over the fixture array.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric_array(array)
    }

    // Mask conformance for both a null-filled and a value-filled sparse array.
    #[test]
    fn test_mask_sparse_array() {
        test_mask_conformance(sparse_with_null_fill().as_ref());
        test_mask_conformance(sparse_with_fill_of_ten().as_ref())
    }

    // Filter conformance for both a null-filled and a value-filled sparse array.
    #[test]
    fn test_filter_sparse_array() {
        test_filter_conformance(sparse_with_null_fill().as_ref());
        test_filter_conformance(sparse_with_fill_of_ten().as_ref())
    }
}
181
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    // Runs the shared array-consistency helper over a range of sparse arrays: different
    // element types, null and non-null fills, a single patch, fully patched, a long array,
    // and patches that themselves contain a null.
    #[rstest]
    // i32 patches, null fill and non-null fill.
    #[case::sparse_i32_null_fill(
        SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(),
            10,
            Scalar::null_typed::<i32>()
        )
        .unwrap()
    )]
    #[case::sparse_i32_value_fill(
        SparseArray::try_new(
            buffer![1u64, 3, 7].into_array(),
            buffer![42i32, 84, 126].into_array(),
            10,
            Scalar::from(0i32)
        )
        .unwrap()
    )]
    // Other element types.
    #[case::sparse_u64(
        SparseArray::try_new(
            buffer![0u64, 4, 9].into_array(),
            buffer![1000u64, 2000, 3000].into_array(),
            10,
            Scalar::from(999u64)
        )
        .unwrap()
    )]
    #[case::sparse_f32(
        SparseArray::try_new(
            buffer![2u64, 6].into_array(),
            buffer![std::f32::consts::PI, std::f32::consts::E].into_array(),
            8,
            Scalar::from(0.0f32)
        )
        .unwrap()
    )]
    // Patch-count extremes: exactly one patch, and a patch at every position.
    #[case::sparse_single_patch(
        SparseArray::try_new(
            buffer![5u64].into_array(),
            buffer![42i32].into_array(),
            10,
            Scalar::from(-1i32)
        )
        .unwrap()
    )]
    #[case::sparse_dense_patches(
        SparseArray::try_new(
            buffer![0u64, 1, 2, 3, 4].into_array(),
            PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)])
                .into_array(),
            5,
            Scalar::null_typed::<i32>()
        )
        .unwrap()
    )]
    // Long array with only a handful of patches.
    #[case::sparse_large(
        SparseArray::try_new(
            buffer![100u64, 500, 900, 1500, 1999].into_array(),
            buffer![111i32, 222, 333, 444, 555].into_array(),
            2000,
            Scalar::from(0i32)
        )
        .unwrap()
    )]
    // Patch values that include a null, cast to the null fill's dtype.
    #[case::sparse_nullable_patches({
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        SparseArray::try_new(
            buffer![1u64, 4, 7].into_array(),
            cast(
                &PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
                null_fill_value.dtype()
            )
            .unwrap(),
            10,
            null_fill_value
        )
        .unwrap()
    })]
    fn test_sparse_consistency(#[case] array: SparseArray) {
        test_array_consistency(array.as_ref());
    }

    // Runs the shared binary-numeric helper over sparse arrays of several numeric element
    // types, each with a non-null fill value.
    #[rstest]
    #[case::sparse_i32_basic(
        SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            buffer![100i32, 200, 300].into_array(),
            10,
            Scalar::from(0i32)
        )
        .unwrap()
    )]
    #[case::sparse_u32_basic(
        SparseArray::try_new(
            buffer![1u64, 3, 7].into_array(),
            buffer![1000u32, 2000, 3000].into_array(),
            10,
            Scalar::from(100u32)
        )
        .unwrap()
    )]
    #[case::sparse_i64_basic(
        SparseArray::try_new(
            buffer![0u64, 4, 9].into_array(),
            buffer![5000i64, 6000, 7000].into_array(),
            10,
            Scalar::from(1000i64)
        )
        .unwrap()
    )]
    #[case::sparse_f32_basic(
        SparseArray::try_new(
            buffer![2u64, 6].into_array(),
            buffer![1.5f32, 2.5].into_array(),
            8,
            Scalar::from(0.5f32)
        )
        .unwrap()
    )]
    #[case::sparse_f64_basic(
        SparseArray::try_new(
            buffer![1u64, 5, 9].into_array(),
            buffer![10.1f64, 20.2, 30.3].into_array(),
            10,
            Scalar::from(5.0f64)
        )
        .unwrap()
    )]
    #[case::sparse_i32_large(
        SparseArray::try_new(
            buffer![10u64, 50, 90, 150, 199].into_array(),
            buffer![111i32, 222, 333, 444, 555].into_array(),
            200,
            Scalar::from(0i32)
        )
        .unwrap()
    )]
    fn test_sparse_binary_numeric(#[case] array: SparseArray) {
        test_binary_numeric_array(array.into_array());
    }
}