// vortex_sparse/compute/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{FilterKernel, FilterKernelAdapter};
6use vortex_array::{ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9
10use crate::{SparseArray, SparseVTable};
11
12mod binary_numeric;
13mod invert;
14mod take;
15
16impl FilterKernel for SparseVTable {
17    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
18        let new_length = mask.true_count();
19
20        let Some(new_patches) = array.patches().filter(mask)? else {
21            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
22        };
23
24        Ok(
25            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
26                .into_array(),
27        )
28    }
29}
30
31register_kernel!(FilterKernelAdapter(SparseVTable).lift());
32
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::{cast, filter};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Length-20 sparse `i32` array with patches `33, 44, 55` at positions 2, 9 and 15,
    /// and a null fill value everywhere else.
    #[fixture]
    fn array() -> ArrayRef {
        SparseArray::try_new(
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            20,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array()
    }

    /// Length-5 sparse array with patches `100, 200, 300` at positions 1, 2 and 4.
    /// The fill is null of type `I32` (nullable); the patch values are cast to that dtype.
    fn null_fill_sparse() -> SparseArray {
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            cast(
                &buffer![100i32, 200, 300].into_array(),
                null_fill_value.dtype(),
            )
            .unwrap(),
            5,
            null_fill_value,
        )
        .unwrap()
    }

    /// Length-5 sparse `i32` array with patches `100, 200, 300` at positions 1, 2 and 4,
    /// and a fill value of `10`.
    fn ten_fill_sparse() -> SparseArray {
        SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            buffer![100i32, 200, 300].into_array(),
            5,
            Scalar::from(10i32),
        )
        .unwrap()
    }

    #[rstest]
    fn test_filter(array: ArrayRef) {
        // Keep only position 2, which holds the first patch.
        let mut predicate = vec![false, false, true];
        predicate.extend_from_slice(&[false; 17]);
        let mask = Mask::from_iter(predicate);

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = filtered_array.as_::<SparseVTable>();

        assert_eq!(filtered_array.len(), 1);
        assert_eq!(filtered_array.patches().values().len(), 1);
        assert_eq!(filtered_array.patches().indices().len(), 1);

        // The surviving patch is re-indexed to output position 0 and keeps its value.
        let indices = filtered_array.patches().indices().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0]);
        let values = filtered_array.patches().values().to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[33]);
    }

    #[test]
    fn filter_remaps_patch_indices() {
        // Selected positions are 1, 3, 5 and 6.
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);
        let array = SparseArray::try_new(
            buffer![0_u64, 3, 6].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
            7,
            Scalar::null_typed::<i32>(),
        )
        .unwrap()
        .into_array();

        let filtered_array = filter(&array, &mask).unwrap();
        let filtered_array = filtered_array.as_::<SparseVTable>();

        assert_eq!(filtered_array.len(), 4);

        // The patch at 0 is dropped. The patches at 3 and 6 land at output positions 1 and 3,
        // carrying their original values.
        let indices = filtered_array.patches().indices().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[1, 3]);
        let values = filtered_array.patches().values().to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[44, 55]);
    }

    #[rstest]
    fn test_sparse_binary_numeric(array: ArrayRef) {
        test_binary_numeric_array(array)
    }

    #[test]
    fn test_mask_sparse_array() {
        test_mask_conformance(null_fill_sparse().as_ref());
        test_mask_conformance(ten_fill_sparse().as_ref());
    }

    #[test]
    fn test_filter_sparse_array() {
        test_filter_conformance(null_fill_sparse().as_ref());
        test_filter_conformance(ten_fill_sparse().as_ref());
    }
}
164
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Builds a [`SparseArray`] for a test case, panicking if construction fails.
    fn sparse(
        indices: vortex_array::ArrayRef,
        values: vortex_array::ArrayRef,
        len: usize,
        fill: Scalar,
    ) -> SparseArray {
        SparseArray::try_new(indices, values, len, fill).unwrap()
    }

    #[rstest]
    // Basic sparse arrays
    #[case::sparse_i32_null_fill(sparse(
        buffer![2u64, 5, 8].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(),
        10,
        Scalar::null_typed::<i32>(),
    ))]
    #[case::sparse_i32_value_fill(sparse(
        buffer![1u64, 3, 7].into_array(),
        buffer![42i32, 84, 126].into_array(),
        10,
        Scalar::from(0i32),
    ))]
    // Different types
    #[case::sparse_u64(sparse(
        buffer![0u64, 4, 9].into_array(),
        buffer![1000u64, 2000, 3000].into_array(),
        10,
        Scalar::from(999u64),
    ))]
    #[case::sparse_f32(sparse(
        buffer![2u64, 6].into_array(),
        buffer![std::f32::consts::PI, std::f32::consts::E].into_array(),
        8,
        Scalar::from(0.0f32),
    ))]
    // Edge cases
    #[case::sparse_single_patch(sparse(
        buffer![5u64].into_array(),
        buffer![42i32].into_array(),
        10,
        Scalar::from(-1i32),
    ))]
    #[case::sparse_dense_patches(sparse(
        buffer![0u64, 1, 2, 3, 4].into_array(),
        PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)])
            .into_array(),
        5,
        Scalar::null_typed::<i32>(),
    ))]
    // Large sparse arrays
    #[case::sparse_large(sparse(
        buffer![100u64, 500, 900, 1500, 1999].into_array(),
        buffer![111i32, 222, 333, 444, 555].into_array(),
        2000,
        Scalar::from(0i32),
    ))]
    // Nullable patches: the values contain a null and are cast to the fill's nullable dtype.
    #[case::sparse_nullable_patches({
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let values = cast(
            &PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
            null_fill_value.dtype(),
        )
        .unwrap();
        sparse(buffer![1u64, 4, 7].into_array(), values, 10, null_fill_value)
    })]
    fn test_sparse_consistency(#[case] array: SparseArray) {
        test_array_consistency(array.as_ref());
    }

    #[rstest]
    #[case::sparse_i32_basic(sparse(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32),
    ))]
    #[case::sparse_u32_basic(sparse(
        buffer![1u64, 3, 7].into_array(),
        buffer![1000u32, 2000, 3000].into_array(),
        10,
        Scalar::from(100u32),
    ))]
    #[case::sparse_i64_basic(sparse(
        buffer![0u64, 4, 9].into_array(),
        buffer![5000i64, 6000, 7000].into_array(),
        10,
        Scalar::from(1000i64),
    ))]
    #[case::sparse_f32_basic(sparse(
        buffer![2u64, 6].into_array(),
        buffer![1.5f32, 2.5].into_array(),
        8,
        Scalar::from(0.5f32),
    ))]
    #[case::sparse_f64_basic(sparse(
        buffer![1u64, 5, 9].into_array(),
        buffer![10.1f64, 20.2, 30.3].into_array(),
        10,
        Scalar::from(5.0f64),
    ))]
    #[case::sparse_i32_large(sparse(
        buffer![10u64, 50, 90, 150, 199].into_array(),
        buffer![111i32, 222, 333, 444, 555].into_array(),
        200,
        Scalar::from(0i32),
    ))]
    fn test_sparse_binary_numeric(#[case] array: SparseArray) {
        test_binary_numeric_array(array.into_array());
    }
}
284}