1use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{FilterKernel, FilterKernelAdapter};
6use vortex_array::{ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9
10use crate::{SparseArray, SparseVTable};
11
12mod binary_numeric;
13mod cast;
14mod invert;
15mod take;
16
17impl FilterKernel for SparseVTable {
18 fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
19 let new_length = mask.true_count();
20
21 let Some(new_patches) = array.patches().filter(mask)? else {
22 return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
23 };
24
25 Ok(
26 SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
27 .into_array(),
28 )
29 }
30}
31
32register_kernel!(FilterKernelAdapter(SparseVTable).lift());
33
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Twenty-element sparse `i32` array with a null fill and patches at positions 2, 9 and 15.
    #[fixture]
    fn array() -> ArrayRef {
        let indices = buffer![2u64, 9, 15].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();

        SparseArray::try_new(indices, values, 20, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    /// Five-element sparse array with patches at 1, 2 and 4 and a null `i32` fill value.
    ///
    /// The patch values are cast to the fill scalar's dtype before construction.
    fn null_filled_sparse() -> SparseArray {
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let indices = buffer![1u64, 2, 4].into_array();
        let values = cast(&buffer![100i32, 200, 300].into_array(), fill.dtype()).unwrap();

        SparseArray::try_new(indices, values, 5, fill).unwrap()
    }

    /// Five-element sparse array with patches at 1, 2 and 4 and a fill value of `10i32`.
    fn ten_filled_sparse() -> SparseArray {
        let fill = Scalar::from(10i32);
        let indices = buffer![1u64, 2, 4].into_array();
        let values = buffer![100i32, 200, 300].into_array();

        SparseArray::try_new(indices, values, 5, fill).unwrap()
    }

    /// Selecting only position 2 (a patched position) leaves a one-element sparse array that
    /// still holds exactly one patch.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        // Twenty slots, all rejected except index 2.
        let mut selection = vec![false; 20];
        selection[2] = true;
        let mask = Mask::from_iter(selection);

        let filtered = filter(&array, &mask).unwrap();
        let sparse = filtered.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 1);
        let patches = sparse.patches();
        assert_eq!(patches.values().len(), 1);
        assert_eq!(patches.indices().len(), 1);
    }

    /// Filtering re-bases the surviving patch indices onto the filtered positions.
    ///
    /// NOTE(review): despite the test name, the fill value used here is a null `i32`.
    #[test]
    fn true_fill_value() {
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);

        let filtered = filter(&array, &mask).unwrap();
        let sparse = filtered.as_::<SparseVTable>();

        // Four mask entries are set, so four elements remain.
        assert_eq!(sparse.len(), 4);

        // Original patches at 3 and 6 land on filtered positions 1 and 3; the patch at 0 is dropped.
        let kept_indices = sparse.patches().indices().to_primitive();
        assert_eq!(kept_indices.as_slice::<u64>(), &[1, 3]);
    }

    /// Runs the shared binary-numeric conformance suite against the fixture array.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric_array(array);
    }

    /// Runs the mask conformance suite for both a null fill and a non-null fill.
    #[test]
    fn test_mask_sparse_array() {
        test_mask_conformance(null_filled_sparse().as_ref());
        test_mask_conformance(ten_filled_sparse().as_ref());
    }

    /// Runs the filter conformance suite for both a null fill and a non-null fill.
    #[test]
    fn test_filter_sparse_array() {
        test_filter_conformance(null_filled_sparse().as_ref());
        test_filter_conformance(ten_filled_sparse().as_ref());
    }
}
165
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Runs the generic array consistency suite over a range of sparse layouts.
    #[rstest]
    // Nullable i32 patches over a null fill.
    #[case::sparse_i32_null_fill(
        SparseArray::try_new(
            buffer![2_u64, 5, 8].into_array(),
            PrimitiveArray::from_option_iter([Some(100_i32), Some(200), Some(300)]).into_array(),
            10,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
    )]
    // Non-nullable i32 patches over a zero fill.
    #[case::sparse_i32_value_fill(
        SparseArray::try_new(
            buffer![1_u64, 3, 7].into_array(),
            buffer![42_i32, 84, 126].into_array(),
            10,
            Scalar::from(0_i32),
        )
        .unwrap()
    )]
    // u64 values, with patches on the first and last positions.
    #[case::sparse_u64(
        SparseArray::try_new(
            buffer![0_u64, 4, 9].into_array(),
            buffer![1000_u64, 2000, 3000].into_array(),
            10,
            Scalar::from(999_u64),
        )
        .unwrap()
    )]
    // f32 values.
    #[case::sparse_f32(
        SparseArray::try_new(
            buffer![2_u64, 6].into_array(),
            buffer![std::f32::consts::PI, std::f32::consts::E].into_array(),
            8,
            Scalar::from(0.0_f32),
        )
        .unwrap()
    )]
    // Exactly one patch.
    #[case::sparse_single_patch(
        SparseArray::try_new(
            buffer![5_u64].into_array(),
            buffer![42_i32].into_array(),
            10,
            Scalar::from(-1_i32),
        )
        .unwrap()
    )]
    // A patch at every position, so the fill value never shows through.
    #[case::sparse_dense_patches(
        SparseArray::try_new(
            buffer![0_u64, 1, 2, 3, 4].into_array(),
            PrimitiveArray::from_option_iter([Some(10_i32), Some(20), Some(30), Some(40), Some(50)])
                .into_array(),
            5,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
    )]
    // Few patches spread over a long array.
    #[case::sparse_large(
        SparseArray::try_new(
            buffer![100_u64, 500, 900, 1500, 1999].into_array(),
            buffer![111_i32, 222, 333, 444, 555].into_array(),
            2000,
            Scalar::from(0_i32),
        )
        .unwrap()
    )]
    // Patch values that themselves contain a null, over a null fill.
    #[case::sparse_nullable_patches({
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let raw_values =
            PrimitiveArray::from_option_iter([Some(100_i32), None, Some(300)]).into_array();
        let values = cast(&raw_values, fill.dtype()).unwrap();
        SparseArray::try_new(buffer![1_u64, 4, 7].into_array(), values, 10, fill).unwrap()
    })]
    fn test_sparse_consistency(#[case] sparse: SparseArray) {
        test_array_consistency(sparse.as_ref());
    }

    /// Runs the binary-numeric conformance suite over sparse arrays of several numeric types.
    #[rstest]
    #[case::sparse_i32_basic(
        SparseArray::try_new(
            buffer![2_u64, 5, 8].into_array(),
            buffer![100_i32, 200, 300].into_array(),
            10,
            Scalar::from(0_i32),
        )
        .unwrap()
    )]
    #[case::sparse_u32_basic(
        SparseArray::try_new(
            buffer![1_u64, 3, 7].into_array(),
            buffer![1000_u32, 2000, 3000].into_array(),
            10,
            Scalar::from(100_u32),
        )
        .unwrap()
    )]
    #[case::sparse_i64_basic(
        SparseArray::try_new(
            buffer![0_u64, 4, 9].into_array(),
            buffer![5000_i64, 6000, 7000].into_array(),
            10,
            Scalar::from(1000_i64),
        )
        .unwrap()
    )]
    #[case::sparse_f32_basic(
        SparseArray::try_new(
            buffer![2_u64, 6].into_array(),
            buffer![1.5_f32, 2.5].into_array(),
            8,
            Scalar::from(0.5_f32),
        )
        .unwrap()
    )]
    #[case::sparse_f64_basic(
        SparseArray::try_new(
            buffer![1_u64, 5, 9].into_array(),
            buffer![10.1_f64, 20.2, 30.3].into_array(),
            10,
            Scalar::from(5.0_f64),
        )
        .unwrap()
    )]
    #[case::sparse_i32_large(
        SparseArray::try_new(
            buffer![10_u64, 50, 90, 150, 199].into_array(),
            buffer![111_i32, 222, 333, 444, 555].into_array(),
            200,
            Scalar::from(0_i32),
        )
        .unwrap()
    )]
    fn test_sparse_binary_numeric(#[case] sparse: SparseArray) {
        test_binary_numeric_array(sparse.into_array());
    }
}