vortex_sparse/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{FilterKernel, FilterKernelAdapter};
6use vortex_array::{ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9
10use crate::{SparseArray, SparseVTable};
11
12mod binary_numeric;
13mod cast;
14mod invert;
15mod take;
16
17impl FilterKernel for SparseVTable {
18    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
19        let new_length = mask.true_count();
20
21        let Some(new_patches) = array.patches().filter(mask)? else {
22            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
23        };
24
25        Ok(
26            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
27                .into_array(),
28        )
29    }
30}
31
32register_kernel!(FilterKernelAdapter(SparseVTable).lift());
33
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, assert_arrays_eq};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Shared fixture: length 20, null fill, patches 33/44/55 at positions 2, 9 and 15.
    #[fixture]
    fn array() -> ArrayRef {
        let positions = buffer![2u64, 9, 15].into_array();
        let patches = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();

        SparseArray::try_new(positions, patches, 20, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    /// Length-5 sparse array with a nullable-i32 null fill and patches at 1, 2 and 4.
    fn null_filled_sparse() -> SparseArray {
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let positions = buffer![1u64, 2, 4].into_array();
        // The patch values are cast to the fill scalar's dtype before construction.
        let patches = cast(&buffer![100i32, 200, 300].into_array(), fill.dtype()).unwrap();

        SparseArray::try_new(positions, patches, 5, fill).unwrap()
    }

    /// Length-5 sparse array filled with `10i32` and patches at 1, 2 and 4.
    fn ten_filled_sparse() -> SparseArray {
        let positions = buffer![1u64, 2, 4].into_array();
        let patches = buffer![100i32, 200, 300].into_array();

        SparseArray::try_new(positions, patches, 5, Scalar::from(10i32)).unwrap()
    }

    /// Selecting only position 2 keeps exactly one patch, re-indexed to 0.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        let mask = Mask::from_iter((0..20).map(|position| position == 2));

        let filtered = filter(&array, &mask).unwrap();
        let sparse = filtered.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 1);
        assert_eq!(sparse.patches().values().len(), 1);
        assert_arrays_eq!(
            sparse.patches().indices(),
            PrimitiveArray::from_iter([0u64])
        );
    }

    /// Filtering keeps the patches at 3 and 6, which land at output positions 1 and 3.
    ///
    /// Note: despite the test name, the fill scalar used here is a typed null.
    #[test]
    fn true_fill_value() {
        let positions = buffer![0_u64, 3, 6].into_array();
        let patches = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        let array = SparseArray::try_new(positions, patches, 7, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array();
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);

        let filtered = filter(&array, &mask).unwrap();
        let sparse = filtered.as_::<SparseVTable>();

        assert_eq!(sparse.len(), 4);
        assert_arrays_eq!(
            sparse.patches().indices(),
            PrimitiveArray::from_iter([1u64, 3])
        );
    }

    /// Runs the binary-numeric conformance suite against the shared fixture.
    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric_array(array)
    }

    /// Runs the mask conformance suite for both a null fill and a non-null fill.
    #[test]
    fn test_mask_sparse_array() {
        test_mask_conformance(null_filled_sparse().as_ref());
        test_mask_conformance(ten_filled_sparse().as_ref());
    }

    /// Runs the filter conformance suite for both a null fill and a non-null fill.
    #[test]
    fn test_filter_sparse_array() {
        test_filter_conformance(null_filled_sparse().as_ref());
        test_filter_conformance(ten_filled_sparse().as_ref());
    }
}
169
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Runs the generic array consistency suite over a range of sparse arrays.
    #[rstest]
    // i32 patches over a null fill.
    #[case::sparse_i32_null_fill({
        let positions = buffer![2u64, 5, 8].into_array();
        let patches =
            PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array();
        SparseArray::try_new(positions, patches, 10, Scalar::null_typed::<i32>()).unwrap()
    })]
    // i32 patches over a concrete fill value.
    #[case::sparse_i32_value_fill({
        let positions = buffer![1u64, 3, 7].into_array();
        let patches = buffer![42i32, 84, 126].into_array();
        SparseArray::try_new(positions, patches, 10, Scalar::from(0i32)).unwrap()
    })]
    // Other primitive element types.
    #[case::sparse_u64({
        let positions = buffer![0u64, 4, 9].into_array();
        let patches = buffer![1000u64, 2000, 3000].into_array();
        SparseArray::try_new(positions, patches, 10, Scalar::from(999u64)).unwrap()
    })]
    #[case::sparse_f32({
        let positions = buffer![2u64, 6].into_array();
        let patches = buffer![std::f32::consts::PI, std::f32::consts::E].into_array();
        SparseArray::try_new(positions, patches, 8, Scalar::from(0.0f32)).unwrap()
    })]
    // Boundary shapes: a single patch, and a patch at every position.
    #[case::sparse_single_patch({
        let positions = buffer![5u64].into_array();
        let patches = buffer![42i32].into_array();
        SparseArray::try_new(positions, patches, 10, Scalar::from(-1i32)).unwrap()
    })]
    #[case::sparse_dense_patches({
        let positions = buffer![0u64, 1, 2, 3, 4].into_array();
        let patches = PrimitiveArray::from_option_iter([
            Some(10i32),
            Some(20),
            Some(30),
            Some(40),
            Some(50),
        ])
        .into_array();
        SparseArray::try_new(positions, patches, 5, Scalar::null_typed::<i32>()).unwrap()
    })]
    // A longer array with few patches.
    #[case::sparse_large({
        let positions = buffer![100u64, 500, 900, 1500, 1999].into_array();
        let patches = buffer![111i32, 222, 333, 444, 555].into_array();
        SparseArray::try_new(positions, patches, 2000, Scalar::from(0i32)).unwrap()
    })]
    // Patch values that themselves contain a null.
    #[case::sparse_nullable_patches({
        let fill = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let positions = buffer![1u64, 4, 7].into_array();
        let raw = PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array();
        // The patch values are cast to the fill scalar's dtype before construction.
        let patches = cast(&raw, fill.dtype()).unwrap();
        SparseArray::try_new(positions, patches, 10, fill).unwrap()
    })]
    fn test_sparse_consistency(#[case] array: SparseArray) {
        test_array_consistency(array.as_ref());
    }

    /// Runs the binary-numeric conformance suite over non-null-filled sparse arrays of
    /// several primitive types.
    #[rstest]
    #[case::sparse_i32_basic({
        let positions = buffer![2u64, 5, 8].into_array();
        let patches = buffer![100i32, 200, 300].into_array();
        SparseArray::try_new(positions, patches, 10, Scalar::from(0i32)).unwrap()
    })]
    #[case::sparse_u32_basic({
        let positions = buffer![1u64, 3, 7].into_array();
        let patches = buffer![1000u32, 2000, 3000].into_array();
        SparseArray::try_new(positions, patches, 10, Scalar::from(100u32)).unwrap()
    })]
    #[case::sparse_i64_basic({
        let positions = buffer![0u64, 4, 9].into_array();
        let patches = buffer![5000i64, 6000, 7000].into_array();
        SparseArray::try_new(positions, patches, 10, Scalar::from(1000i64)).unwrap()
    })]
    #[case::sparse_f32_basic({
        let positions = buffer![2u64, 6].into_array();
        let patches = buffer![1.5f32, 2.5].into_array();
        SparseArray::try_new(positions, patches, 8, Scalar::from(0.5f32)).unwrap()
    })]
    #[case::sparse_f64_basic({
        let positions = buffer![1u64, 5, 9].into_array();
        let patches = buffer![10.1f64, 20.2, 30.3].into_array();
        SparseArray::try_new(positions, patches, 10, Scalar::from(5.0f64)).unwrap()
    })]
    #[case::sparse_i32_large({
        let positions = buffer![10u64, 50, 90, 150, 199].into_array();
        let patches = buffer![111i32, 222, 333, 444, 555].into_array();
        SparseArray::try_new(positions, patches, 200, Scalar::from(0i32)).unwrap()
    })]
    fn test_sparse_binary_numeric(#[case] array: SparseArray) {
        test_binary_numeric_array(array.into_array());
    }
}