1use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{FilterKernel, FilterKernelAdapter};
6use vortex_array::{ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9
10use crate::{SparseArray, SparseVTable};
11
12mod binary_numeric;
13mod invert;
14mod take;
15
16impl FilterKernel for SparseVTable {
17 fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
18 let new_length = mask.true_count();
19
20 let Some(new_patches) = array.patches().filter(mask)? else {
21 return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
22 };
23
24 Ok(
25 SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
26 .into_array(),
27 )
28 }
29}
30
31register_kernel!(FilterKernelAdapter(SparseVTable).lift());
32
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Builds a sparse array of length `len` with a null `i32` fill and the values
    /// 33, 44, 55 patched in at `positions`.
    fn null_filled_i32(positions: ArrayRef, len: usize) -> ArrayRef {
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        SparseArray::try_new(positions, values, len, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    /// Five-element sparse array with a null `i32` fill and 100, 200, 300 patched in at
    /// positions 1, 2 and 4.
    fn null_fill_sample() -> SparseArray {
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        // The patch values are cast to the fill scalar's dtype before construction.
        let values = cast(&buffer![100i32, 200, 300].into_array(), fill.dtype()).unwrap();
        SparseArray::try_new(buffer![1u64, 2, 4].into_array(), values, 5, fill).unwrap()
    }

    /// Same layout as [`null_fill_sample`], but filled with the `i32` value 10.
    fn ten_fill_sample() -> SparseArray {
        SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            buffer![100i32, 200, 300].into_array(),
            5,
            Scalar::from(10i32),
        )
        .unwrap()
    }

    /// Twenty-element null-filled sparse array with patches at positions 2, 9 and 15.
    #[fixture]
    fn array() -> ArrayRef {
        null_filled_i32(buffer![2u64, 9, 15].into_array(), 20)
    }

    /// Selecting only position 2 (a patched position) leaves a one-element sparse array
    /// carrying a single patch.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        let mut selection = vec![false; 20];
        selection[2] = true;
        let mask = Mask::from_iter(selection);

        let result = filter(&array, &mask).unwrap();
        let sparse = result.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 1);
        assert_eq!(sparse.patches().values().len(), 1);
        assert_eq!(sparse.patches().indices().len(), 1);
    }

    /// Patches sit at 0, 3 and 6 and the mask keeps positions 1, 3, 5 and 6, so the
    /// patches from 3 and 6 are expected at output indices 1 and 3.
    #[test]
    fn true_fill_value() {
        let array = null_filled_i32(buffer![0_u64, 3, 6].into_array(), 7);
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);

        let result = filter(&array, &mask).unwrap();
        let sparse = result.as_::<SparseVTable>();
        assert_eq!(sparse.len(), 4);

        let indices = sparse.patches().indices().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[1, 3]);
    }

    /// Runs the shared binary-numeric conformance suite against the fixture array.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric_array(array);
    }

    /// Runs the mask conformance suite for both a null fill and a non-null fill.
    #[test]
    fn test_mask_sparse_array() {
        test_mask_conformance(null_fill_sample().as_ref());
        test_mask_conformance(ten_fill_sample().as_ref());
    }

    /// Runs the filter conformance suite for both a null fill and a non-null fill.
    #[test]
    fn test_filter_sparse_array() {
        test_filter_conformance(null_fill_sample().as_ref());
        test_filter_conformance(ten_fill_sample().as_ref());
    }
}
164
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::{ArrayRef, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Shorthand for `SparseArray::try_new(..).unwrap()`, used by the case tables below.
    fn sparse(indices: ArrayRef, values: ArrayRef, len: usize, fill: Scalar) -> SparseArray {
        SparseArray::try_new(indices, values, len, fill).unwrap()
    }

    /// Runs the array consistency suite over a range of sparse layouts: null and
    /// non-null fills, several primitive types, a single patch, fully patched, a
    /// 2000-element array, and patches that themselves contain a null.
    #[rstest]
    #[case::sparse_i32_null_fill(sparse(
        buffer![2u64, 5, 8].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(),
        10,
        Scalar::null_typed::<i32>(),
    ))]
    #[case::sparse_i32_value_fill(sparse(
        buffer![1u64, 3, 7].into_array(),
        buffer![42i32, 84, 126].into_array(),
        10,
        Scalar::from(0i32),
    ))]
    #[case::sparse_u64(sparse(
        buffer![0u64, 4, 9].into_array(),
        buffer![1000u64, 2000, 3000].into_array(),
        10,
        Scalar::from(999u64),
    ))]
    #[case::sparse_f32(sparse(
        buffer![2u64, 6].into_array(),
        buffer![std::f32::consts::PI, std::f32::consts::E].into_array(),
        8,
        Scalar::from(0.0f32),
    ))]
    #[case::sparse_single_patch(sparse(
        buffer![5u64].into_array(),
        buffer![42i32].into_array(),
        10,
        Scalar::from(-1i32),
    ))]
    #[case::sparse_dense_patches(sparse(
        buffer![0u64, 1, 2, 3, 4].into_array(),
        PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)])
            .into_array(),
        5,
        Scalar::null_typed::<i32>(),
    ))]
    #[case::sparse_large(sparse(
        buffer![100u64, 500, 900, 1500, 1999].into_array(),
        buffer![111i32, 222, 333, 444, 555].into_array(),
        2000,
        Scalar::from(0i32),
    ))]
    #[case::sparse_nullable_patches({
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        // The patch values are cast to the fill scalar's dtype before construction.
        let values = cast(
            &PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
            fill.dtype(),
        )
        .unwrap();
        sparse(buffer![1u64, 4, 7].into_array(), values, 10, fill)
    })]
    fn test_sparse_consistency(#[case] array: SparseArray) {
        test_array_consistency(array.as_ref());
    }

    /// Runs the binary-numeric conformance suite over sparse arrays of several numeric
    /// primitive types, each with a non-null fill value.
    #[rstest]
    #[case::sparse_i32_basic(sparse(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32),
    ))]
    #[case::sparse_u32_basic(sparse(
        buffer![1u64, 3, 7].into_array(),
        buffer![1000u32, 2000, 3000].into_array(),
        10,
        Scalar::from(100u32),
    ))]
    #[case::sparse_i64_basic(sparse(
        buffer![0u64, 4, 9].into_array(),
        buffer![5000i64, 6000, 7000].into_array(),
        10,
        Scalar::from(1000i64),
    ))]
    #[case::sparse_f32_basic(sparse(
        buffer![2u64, 6].into_array(),
        buffer![1.5f32, 2.5].into_array(),
        8,
        Scalar::from(0.5f32),
    ))]
    #[case::sparse_f64_basic(sparse(
        buffer![1u64, 5, 9].into_array(),
        buffer![10.1f64, 20.2, 30.3].into_array(),
        10,
        Scalar::from(5.0f64),
    ))]
    #[case::sparse_i32_large(sparse(
        buffer![10u64, 50, 90, 150, 199].into_array(),
        buffer![111i32, 222, 333, 444, 555].into_array(),
        200,
        Scalar::from(0i32),
    ))]
    fn test_sparse_binary_numeric(#[case] array: SparseArray) {
        test_binary_numeric_array(array.into_array());
    }
}