vortex_sparse/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{FilterKernel, FilterKernelAdapter};
6use vortex_array::{ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9
10use crate::{SparseArray, SparseVTable};
11
12mod binary_numeric;
13mod cast;
14mod invert;
15mod take;
16
17impl FilterKernel for SparseVTable {
18    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
19        let new_length = mask.true_count();
20
21        let Some(new_patches) = array.patches().filter(mask)? else {
22            return Ok(ConstantArray::new(array.fill_scalar().clone(), new_length).into_array());
23        };
24
25        Ok(
26            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
27                .into_array(),
28        )
29    }
30}
31
32register_kernel!(FilterKernelAdapter(SparseVTable).lift());
33
34#[cfg(test)]
35mod test {
36    use rstest::{fixture, rstest};
37    use vortex_array::arrays::PrimitiveArray;
38    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
39    use vortex_array::compute::conformance::filter::test_filter_conformance;
40    use vortex_array::compute::conformance::mask::test_mask_conformance;
41    use vortex_array::compute::{cast, filter};
42    use vortex_array::validity::Validity;
43    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
44    use vortex_buffer::buffer;
45    use vortex_dtype::{DType, Nullability, PType};
46    use vortex_mask::Mask;
47    use vortex_scalar::Scalar;
48
49    use crate::{SparseArray, SparseVTable};
50
51    #[fixture]
52    fn array() -> ArrayRef {
53        SparseArray::try_new(
54            buffer![2u64, 9, 15].into_array(),
55            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
56            20,
57            Scalar::null_typed::<i32>(),
58        )
59        .unwrap()
60        .into_array()
61    }
62
63    #[rstest]
64    fn test_filter(array: ArrayRef) {
65        let mut predicate = vec![false, false, true];
66        predicate.extend_from_slice(&[false; 17]);
67        let mask = Mask::from_iter(predicate);
68
69        let filtered_array = filter(&array, &mask).unwrap();
70        let filtered_array = filtered_array.as_::<SparseVTable>();
71
72        assert_eq!(filtered_array.len(), 1);
73        assert_eq!(filtered_array.patches().values().len(), 1);
74        assert_eq!(filtered_array.patches().indices().len(), 1);
75    }
76
77    #[test]
78    fn true_fill_value() {
79        let mask = Mask::from_iter([false, true, false, true, false, true, true]);
80        let array = SparseArray::try_new(
81            buffer![0_u64, 3, 6].into_array(),
82            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
83            7,
84            Scalar::null_typed::<i32>(),
85        )
86        .unwrap()
87        .into_array();
88
89        let filtered_array = filter(&array, &mask).unwrap();
90        let filtered_array = filtered_array.as_::<SparseVTable>();
91
92        assert_eq!(filtered_array.len(), 4);
93        let primitive = filtered_array.patches().indices().to_primitive();
94
95        assert_eq!(primitive.as_slice::<u64>(), &[1, 3]);
96    }
97
98    #[rstest]
99    fn test_sparse_binary_numeric(array: ArrayRef) {
100        test_binary_numeric_array(array)
101    }
102
103    #[test]
104    fn test_mask_sparse_array() {
105        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
106        test_mask_conformance(
107            SparseArray::try_new(
108                buffer![1u64, 2, 4].into_array(),
109                cast(
110                    &buffer![100i32, 200, 300].into_array(),
111                    null_fill_value.dtype(),
112                )
113                .unwrap(),
114                5,
115                null_fill_value,
116            )
117            .unwrap()
118            .as_ref(),
119        );
120
121        let ten_fill_value = Scalar::from(10i32);
122        test_mask_conformance(
123            SparseArray::try_new(
124                buffer![1u64, 2, 4].into_array(),
125                buffer![100i32, 200, 300].into_array(),
126                5,
127                ten_fill_value,
128            )
129            .unwrap()
130            .as_ref(),
131        )
132    }
133
134    #[test]
135    fn test_filter_sparse_array() {
136        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
137        test_filter_conformance(
138            SparseArray::try_new(
139                buffer![1u64, 2, 4].into_array(),
140                cast(
141                    &buffer![100i32, 200, 300].into_array(),
142                    null_fill_value.dtype(),
143                )
144                .unwrap(),
145                5,
146                null_fill_value,
147            )
148            .unwrap()
149            .as_ref(),
150        );
151
152        let ten_fill_value = Scalar::from(10i32);
153        test_filter_conformance(
154            SparseArray::try_new(
155                buffer![1u64, 2, 4].into_array(),
156                buffer![100i32, 200, 300].into_array(),
157                5,
158                ten_fill_value,
159            )
160            .unwrap()
161            .as_ref(),
162        )
163    }
164}
165
166#[cfg(test)]
167mod tests {
168    use rstest::rstest;
169    use vortex_array::IntoArray;
170    use vortex_array::arrays::PrimitiveArray;
171    use vortex_array::compute::cast;
172    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
173    use vortex_array::compute::conformance::consistency::test_array_consistency;
174    use vortex_buffer::buffer;
175    use vortex_dtype::{DType, Nullability, PType};
176    use vortex_scalar::Scalar;
177
178    use crate::SparseArray;
179
180    #[rstest]
181    // Basic sparse arrays
182    #[case::sparse_i32_null_fill(SparseArray::try_new(
183        buffer![2u64, 5, 8].into_array(),
184        PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(),
185        10,
186        Scalar::null_typed::<i32>()
187    ).unwrap())]
188    #[case::sparse_i32_value_fill(SparseArray::try_new(
189        buffer![1u64, 3, 7].into_array(),
190        buffer![42i32, 84, 126].into_array(),
191        10,
192        Scalar::from(0i32)
193    ).unwrap())]
194    // Different types
195    #[case::sparse_u64(SparseArray::try_new(
196        buffer![0u64, 4, 9].into_array(),
197        buffer![1000u64, 2000, 3000].into_array(),
198        10,
199        Scalar::from(999u64)
200    ).unwrap())]
201    #[case::sparse_f32(SparseArray::try_new(
202        buffer![2u64, 6].into_array(),
203        buffer![std::f32::consts::PI, std::f32::consts::E].into_array(),
204        8,
205        Scalar::from(0.0f32)
206    ).unwrap())]
207    // Edge cases
208    #[case::sparse_single_patch(SparseArray::try_new(
209        buffer![5u64].into_array(),
210        buffer![42i32].into_array(),
211        10,
212        Scalar::from(-1i32)
213    ).unwrap())]
214    #[case::sparse_dense_patches(SparseArray::try_new(
215        buffer![0u64, 1, 2, 3, 4].into_array(),
216        PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)]).into_array(),
217        5,
218        Scalar::null_typed::<i32>()
219    ).unwrap())]
220    // Large sparse arrays
221    #[case::sparse_large(SparseArray::try_new(
222        buffer![100u64, 500, 900, 1500, 1999].into_array(),
223        buffer![111i32, 222, 333, 444, 555].into_array(),
224        2000,
225        Scalar::from(0i32)
226    ).unwrap())]
227    // Nullable patches
228    #[case::sparse_nullable_patches({
229        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
230        SparseArray::try_new(
231            buffer![1u64, 4, 7].into_array(),
232            cast(
233                &PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
234                null_fill_value.dtype()
235            ).unwrap(),
236            10,
237            null_fill_value
238        ).unwrap()
239    })]
240
241    fn test_sparse_consistency(#[case] array: SparseArray) {
242        test_array_consistency(array.as_ref());
243    }
244
245    #[rstest]
246    #[case::sparse_i32_basic(SparseArray::try_new(
247        buffer![2u64, 5, 8].into_array(),
248        buffer![100i32, 200, 300].into_array(),
249        10,
250        Scalar::from(0i32)
251    ).unwrap())]
252    #[case::sparse_u32_basic(SparseArray::try_new(
253        buffer![1u64, 3, 7].into_array(),
254        buffer![1000u32, 2000, 3000].into_array(),
255        10,
256        Scalar::from(100u32)
257    ).unwrap())]
258    #[case::sparse_i64_basic(SparseArray::try_new(
259        buffer![0u64, 4, 9].into_array(),
260        buffer![5000i64, 6000, 7000].into_array(),
261        10,
262        Scalar::from(1000i64)
263    ).unwrap())]
264    #[case::sparse_f32_basic(SparseArray::try_new(
265        buffer![2u64, 6].into_array(),
266        buffer![1.5f32, 2.5].into_array(),
267        8,
268        Scalar::from(0.5f32)
269    ).unwrap())]
270    #[case::sparse_f64_basic(SparseArray::try_new(
271        buffer![1u64, 5, 9].into_array(),
272        buffer![10.1f64, 20.2, 30.3].into_array(),
273        10,
274        Scalar::from(5.0f64)
275    ).unwrap())]
276    #[case::sparse_i32_large(SparseArray::try_new(
277        buffer![10u64, 50, 90, 150, 199].into_array(),
278        buffer![111i32, 222, 333, 444, 555].into_array(),
279        200,
280        Scalar::from(0i32)
281    ).unwrap())]
282    fn test_sparse_binary_numeric(#[case] array: SparseArray) {
283        test_binary_numeric_array(array.into_array());
284    }
285}