Skip to main content

vortex_sparse/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::dict::TakeExecute;
9use vortex_error::VortexResult;
10
11use crate::ConstantArray;
12use crate::Sparse;
13use crate::SparseExt as _;
14impl TakeExecute for Sparse {
15    fn take(
16        array: ArrayView<'_, Self>,
17        indices: &ArrayRef,
18        ctx: &mut ExecutionCtx,
19    ) -> VortexResult<Option<ArrayRef>> {
20        let patches_take = if array.fill_scalar().is_null() {
21            array.patches().take(indices, ctx)?
22        } else {
23            array.patches().take_with_nulls(indices, ctx)?
24        };
25
26        let Some(new_patches) = patches_take else {
27            let result_fill_scalar = array.fill_scalar().cast(
28                &array
29                    .dtype()
30                    .union_nullability(indices.dtype().nullability()),
31            )?;
32            return Ok(Some(
33                ConstantArray::new(result_fill_scalar, indices.len()).into_array(),
34            ));
35        };
36
37        // See `SparseEncoding::slice`.
38        if new_patches.array_len() == new_patches.values().len() {
39            return Ok(Some(new_patches.into_values()));
40        }
41
42        Ok(Some(
43            Sparse::try_new_from_patches(
44                new_patches,
45                array.fill_scalar().cast(
46                    &array
47                        .dtype()
48                        .union_nullability(indices.dtype().nullability()),
49                )?,
50            )?
51            .into_array(),
52        ))
53    }
54}
55
// Tests for `TakeExecute for Sparse`, covering null and non-null fill values.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::Sparse;
    use crate::SparseArray;

    /// Fill value shared by `sparse_array` and the tests that inspect it.
    fn test_array_fill_value() -> Scalar {
        // making this const is annoying
        Scalar::null_native::<f64>()
    }

    /// A 100-element f64 sparse array with patches at 0, 37, 47 and 99 and a null fill.
    fn sparse_array() -> ArrayRef {
        Sparse::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            test_array_fill_value(),
        )
        .unwrap()
        .into_array()
    }

    #[test]
    fn take_with_non_zero_offset() {
        let sparse = sparse_array();
        let sparse = sparse.slice(30..40).unwrap();
        // Indices 6, 7, 8 of the slice correspond to positions 36, 37, 38 of the original.
        let taken = sparse.take(buffer![6, 7, 8].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47), None]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let taken = sparse.take(buffer![0, 47, 47, 0, 99].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([
            Some(1.23f64),
            Some(9.99),
            Some(9.99),
            Some(1.23),
            Some(3.5),
        ]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        let taken = sparse.take(buffer![69].into_array()).unwrap();
        let expected = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        // The expected values are written as a canonical array. The encoding `take` returns
        // may differ: only one of the two taken positions is patched here, so the result
        // need not be canonical.
        let taken = sparse.take(buffer![69, 37].into_array()).unwrap();
        // Index 69 is not in sparse array (fill value is null), index 37 has value 0.47
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47f64)]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn nullable_take() {
        let arr = Sparse::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = arr
            .take(
                PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                    .into_array(),
            )
            .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = Sparse::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = arr
            .take(
                PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                    .into_array(),
            )
            .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn take_missing_all_patches_with_non_null_fill() {
        let arr = Sparse::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        // Neither index hits the single patch at position 1, so every output element
        // is the (non-null) fill value and the indices are non-nullable.
        let taken = arr.take(buffer![0u32, 2].into_array()).unwrap();

        assert_arrays_eq!(taken, buffer![1i32, 1].into_array());
    }

    #[rstest]
    #[case(Sparse::try_new(
        buffer![0u64, 37, 47, 99].into_array(),
        PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
        100,
        Scalar::null_native::<f64>(),
    ).unwrap())]
    #[case(Sparse::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        Sparse::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_native::<i64>(),
        ).unwrap()
    })]
    #[case(Sparse::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(&sparse.into_array());
    }
}