Skip to main content

vortex_sparse/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::arrays::TakeExecute;
10use vortex_error::VortexResult;
11
12use crate::SparseArray;
13use crate::SparseVTable;
14
15impl TakeExecute for SparseVTable {
16    fn take(
17        array: &SparseArray,
18        indices: &dyn Array,
19        _ctx: &mut ExecutionCtx,
20    ) -> VortexResult<Option<ArrayRef>> {
21        let patches_take = if array.fill_scalar().is_null() {
22            array.patches().take(indices)?
23        } else {
24            array.patches().take_with_nulls(indices)?
25        };
26
27        let Some(new_patches) = patches_take else {
28            let result_fill_scalar = array.fill_scalar().cast(
29                &array
30                    .dtype()
31                    .union_nullability(indices.dtype().nullability()),
32            )?;
33            return Ok(Some(
34                ConstantArray::new(result_fill_scalar, indices.len()).into_array(),
35            ));
36        };
37
38        // See `SparseEncoding::slice`.
39        if new_patches.array_len() == new_patches.values().len() {
40            return Ok(Some(new_patches.into_values()));
41        }
42
43        Ok(Some(
44            SparseArray::try_new_from_patches(
45                new_patches,
46                array.fill_scalar().cast(
47                    &array
48                        .dtype()
49                        .union_nullability(indices.dtype().nullability()),
50                )?,
51            )?
52            .into_array(),
53        ))
54    }
55}
56
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;

    use crate::SparseArray;

    /// Fill value used by [`sparse_array`]: a null `f64`.
    ///
    /// Kept as a function because making it a `const` is awkward for `Scalar`.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_native::<f64>()
    }

    /// A 100-element `f64` sparse array with patches at indices 0, 37, 47 and 99
    /// (values 1.23, 0.47, 9.99, 3.5) and a null fill value.
    fn sparse_array() -> ArrayRef {
        SparseArray::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            test_array_fill_value(),
        )
        .unwrap()
        .into_array()
    }

    #[test]
    fn take_with_non_zero_offset() {
        let sparse = sparse_array();
        // After slicing 30..40, local indices 6, 7, 8 map to original indices 36, 37, 38.
        let sparse = sparse.slice(30..40).unwrap();
        let taken = sparse.take(buffer![6, 7, 8].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47), None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        // Every requested index (including repeats) is patched.
        let taken = sparse.take(buffer![0, 47, 47, 0, 99].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([
            Some(1.23f64),
            Some(9.99),
            Some(9.99),
            Some(1.23),
            Some(3.5),
        ]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        // Index 69 is not patched, so only the fill value is observed.
        let taken = sparse.take(buffer![69].into_array()).unwrap();
        let expected = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        // Mixes an unpatched index (69, yielding the null fill value) with a patched one
        // (37, value 0.47). The comparison below is by value, so the encoding of the
        // result does not matter.
        let taken = sparse.take(buffer![69, 37].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47f64)]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn take_only_fill_values_with_non_null_fill() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        // None of these indices are patched and the indices are non-nullable, so every
        // output element is the non-null fill value.
        let taken = arr.take(buffer![0, 2, 9].into_array()).unwrap();

        let expected = buffer![1i32, 1, 1].into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn nullable_take() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        // A null index must produce a null output even though the fill value is non-null.
        let taken = arr
            .take(
                PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                    .to_array(),
            )
            .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = SparseArray::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = arr
            .take(
                PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                    .to_array(),
            )
            .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[rstest]
    #[case(SparseArray::try_new(
        buffer![0u64, 37, 47, 99].into_array(),
        PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
        100,
        Scalar::null_native::<f64>(),
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        SparseArray::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_native::<i64>(),
        ).unwrap()
    })]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(sparse.as_ref());
    }
}