vortex_sparse/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::arrays::ConstantArray;
7use vortex_array::compute::FilterKernel;
8use vortex_array::compute::FilterKernelAdapter;
9use vortex_array::register_kernel;
10use vortex_error::VortexResult;
11use vortex_mask::Mask;
12
13use crate::SparseArray;
14use crate::SparseVTable;
15
16mod binary_numeric;
17mod cast;
18mod invert;
19mod take;
20
21impl FilterKernel for SparseVTable {
22    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
23        let new_length = mask.true_count();
24
25        let Some(new_patches) = array.patches().filter(mask)? else {
26            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
27        };
28
29        Ok(
30            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
31                .into_array(),
32        )
33    }
34}
35
36register_kernel!(FilterKernelAdapter(SparseVTable).lift());
37
#[cfg(test)]
mod test {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::filter;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::SparseArray;
    use crate::SparseVTable;

    /// Length-20 sparse array with patches `33`, `44`, `55` at indices 2, 9, 15 and a null
    /// `i32` fill value.
    #[fixture]
    fn array() -> ArrayRef {
        let indices = buffer![2u64, 9, 15].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        SparseArray::try_new(indices, values, 20, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    /// Length-5 sparse array with patches `100`, `200`, `300` at indices 1, 2, 4 and a null
    /// fill value. The patch values are cast to the fill value's nullable dtype.
    fn five_with_null_fill() -> SparseArray {
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let values = cast(&buffer![100i32, 200, 300].into_array(), fill.dtype()).unwrap();
        SparseArray::try_new(buffer![1u64, 2, 4].into_array(), values, 5, fill).unwrap()
    }

    /// Length-5 sparse array with patches `100`, `200`, `300` at indices 1, 2, 4 and fill `10`.
    fn five_with_ten_fill() -> SparseArray {
        SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            buffer![100i32, 200, 300].into_array(),
            5,
            Scalar::from(10i32),
        )
        .unwrap()
    }

    /// Selecting only the first patch (index 2) yields a one-element sparse array whose single
    /// patch is re-indexed to 0.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        let mask = Mask::from_iter((0..20).map(|position| position == 2));

        let result = filter(&array, &mask).unwrap();
        let sparse = result.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 1);
        assert_eq!(sparse.patches().values().len(), 1);
        assert_arrays_eq!(sparse.patches().indices(), PrimitiveArray::from_iter([0u64]));
    }

    /// Mask keeps positions 1, 3, 5, 6: two fill slots (1, 5) and two patches (3, 6).
    /// The surviving patches land at indices 1 and 3 of the four-element result.
    #[test]
    fn true_fill_value() {
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();

        let result = filter(&array, &mask).unwrap();
        let sparse = result.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 4);
        assert_arrays_eq!(sparse.patches().indices(), PrimitiveArray::from_iter([1u64, 3]));
    }

    /// Runs the shared binary-numeric conformance suite on the fixture array.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric_array(array)
    }

    /// Runs the shared mask conformance suite with both a null and a non-null fill value.
    #[test]
    fn test_mask_sparse_array() {
        test_mask_conformance(five_with_null_fill().as_ref());
        test_mask_conformance(five_with_ten_fill().as_ref())
    }

    /// Runs the shared filter conformance suite with both a null and a non-null fill value.
    #[test]
    fn test_filter_sparse_array() {
        test_filter_conformance(five_with_null_fill().as_ref());
        test_filter_conformance(five_with_ten_fill().as_ref())
    }
}
181
#[cfg(test)]
mod tests {
    use std::f32::consts::E;
    use std::f32::consts::PI;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Length-10 sparse array with patches `Some(100)`, `None`, `Some(300)` at indices 1, 4, 7
    /// and a null fill value. The patch values are cast to the fill value's nullable dtype.
    fn nullable_patches() -> SparseArray {
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let values = cast(
            &PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
            fill.dtype(),
        )
        .unwrap();
        SparseArray::try_new(buffer![1u64, 4, 7].into_array(), values, 10, fill).unwrap()
    }

    /// Runs the shared consistency conformance suite over a variety of sparse layouts.
    #[rstest]
    // Basic sparse arrays
    #[case::sparse_i32_null_fill(SparseArray::try_new(
        buffer![2u64, 5, 8].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(),
        10,
        Scalar::null_typed::<i32>()
    ).unwrap())]
    #[case::sparse_i32_value_fill(SparseArray::try_new(
        buffer![1u64, 3, 7].into_array(),
        buffer![42i32, 84, 126].into_array(),
        10,
        Scalar::from(0i32)
    ).unwrap())]
    // Different types
    #[case::sparse_u64(SparseArray::try_new(
        buffer![0u64, 4, 9].into_array(),
        buffer![1000u64, 2000, 3000].into_array(),
        10,
        Scalar::from(999u64)
    ).unwrap())]
    #[case::sparse_f32(SparseArray::try_new(
        buffer![2u64, 6].into_array(),
        buffer![PI, E].into_array(),
        8,
        Scalar::from(0.0f32)
    ).unwrap())]
    // Edge cases
    #[case::sparse_single_patch(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![42i32].into_array(),
        10,
        Scalar::from(-1i32)
    ).unwrap())]
    #[case::sparse_dense_patches(SparseArray::try_new(
        buffer![0u64, 1, 2, 3, 4].into_array(),
        PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)]).into_array(),
        5,
        Scalar::null_typed::<i32>()
    ).unwrap())]
    // Large sparse arrays
    #[case::sparse_large(SparseArray::try_new(
        buffer![100u64, 500, 900, 1500, 1999].into_array(),
        buffer![111i32, 222, 333, 444, 555].into_array(),
        2000,
        Scalar::from(0i32)
    ).unwrap())]
    // Nullable patches
    #[case::sparse_nullable_patches(nullable_patches())]
    fn test_sparse_consistency(#[case] array: SparseArray) {
        let as_dyn = array.as_ref();
        test_array_consistency(as_dyn);
    }

    /// Runs the shared binary-numeric conformance suite over sparse arrays of several
    /// primitive types, each with a non-null fill value.
    #[rstest]
    #[case::sparse_i32_basic(SparseArray::try_new(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32)
    ).unwrap())]
    #[case::sparse_u32_basic(SparseArray::try_new(
        buffer![1u64, 3, 7].into_array(),
        buffer![1000u32, 2000, 3000].into_array(),
        10,
        Scalar::from(100u32)
    ).unwrap())]
    #[case::sparse_i64_basic(SparseArray::try_new(
        buffer![0u64, 4, 9].into_array(),
        buffer![5000i64, 6000, 7000].into_array(),
        10,
        Scalar::from(1000i64)
    ).unwrap())]
    #[case::sparse_f32_basic(SparseArray::try_new(
        buffer![2u64, 6].into_array(),
        buffer![1.5f32, 2.5].into_array(),
        8,
        Scalar::from(0.5f32)
    ).unwrap())]
    #[case::sparse_f64_basic(SparseArray::try_new(
        buffer![1u64, 5, 9].into_array(),
        buffer![10.1f64, 20.2, 30.3].into_array(),
        10,
        Scalar::from(5.0f64)
    ).unwrap())]
    #[case::sparse_i32_large(SparseArray::try_new(
        buffer![10u64, 50, 90, 150, 199].into_array(),
        buffer![111i32, 222, 333, 444, 555].into_array(),
        200,
        Scalar::from(0i32)
    ).unwrap())]
    fn test_sparse_binary_numeric(#[case] array: SparseArray) {
        let erased = array.into_array();
        test_binary_numeric_array(erased);
    }
}
301        test_binary_numeric_array(array.into_array());
302    }
303}