Skip to main content

vortex_sparse/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::dict::TakeExecute;
9use vortex_error::VortexResult;
10
11use crate::ConstantArray;
12use crate::Sparse;
13impl TakeExecute for Sparse {
14    fn take(
15        array: ArrayView<'_, Self>,
16        indices: &ArrayRef,
17        ctx: &mut ExecutionCtx,
18    ) -> VortexResult<Option<ArrayRef>> {
19        let patches_take = if array.fill_scalar().is_null() {
20            array.patches().take(indices, ctx)?
21        } else {
22            array.patches().take_with_nulls(indices, ctx)?
23        };
24
25        let Some(new_patches) = patches_take else {
26            let result_fill_scalar = array.fill_scalar().cast(
27                &array
28                    .dtype()
29                    .union_nullability(indices.dtype().nullability()),
30            )?;
31            return Ok(Some(
32                ConstantArray::new(result_fill_scalar, indices.len()).into_array(),
33            ));
34        };
35
36        // See `SparseEncoding::slice`.
37        if new_patches.array_len() == new_patches.values().len() {
38            return Ok(Some(new_patches.into_values()));
39        }
40
41        Ok(Some(
42            Sparse::try_new_from_patches(
43                new_patches,
44                array.fill_scalar().cast(
45                    &array
46                        .dtype()
47                        .union_nullability(indices.dtype().nullability()),
48                )?,
49            )?
50            .into_array(),
51        ))
52    }
53}
54
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::Sparse;
    use crate::SparseArray;

    /// Fill value of the shared fixture built by `sparse_array`: a null `f64`.
    fn test_array_fill_value() -> Scalar {
        // making this const is annoying
        Scalar::null_native::<f64>()
    }

    /// Length-100 `f64` sparse array with patches at indices 0, 37, 47 and 99 and a
    /// null fill value.
    fn sparse_array() -> ArrayRef {
        Sparse::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            test_array_fill_value(),
        )
        .unwrap()
        .into_array()
    }

    #[test]
    fn take_with_non_zero_offset() {
        let sparse = sparse_array();
        let sparse = sparse.slice(30..40).unwrap();
        let taken = sparse.take(buffer![6, 7, 8].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47), None]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let taken = sparse.take(buffer![0, 47, 47, 0, 99].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([
            Some(1.23f64),
            Some(9.99),
            Some(9.99),
            Some(1.23),
            Some(3.5),
        ]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        let taken = sparse.take(buffer![69].into_array()).unwrap();
        let expected = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        // Only one of the two output positions is patched, so the kernel keeps the result
        // sparse; `assert_arrays_eq!` compares logical contents, not the encoding.
        let taken = sparse.take(buffer![69, 37].into_array()).unwrap();
        // Index 69 is not in sparse array (fill value is null), index 37 has value 0.47
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47f64)]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn nullable_take() {
        let arr = Sparse::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = arr
            .take(
                PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                    .into_array(),
            )
            .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    /// Non-null fill with nullable indices that hit no patch: the result must be the
    /// fill value, widened to the nullability of the indices.
    #[test]
    fn nullable_take_misses_all_patches() {
        let arr = Sparse::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = arr
            .take(PrimitiveArray::from_option_iter([Some(2u32), Some(5u32)]).into_array())
            .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1i32), Some(1)]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = Sparse::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = arr
            .take(
                PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                    .into_array(),
            )
            .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.into_array());
    }

    /// Runs the shared take conformance suite over sparse arrays with null and non-null
    /// fills, nullable patch values, and a single patch.
    #[rstest]
    #[case(Sparse::try_new(
        buffer![0u64, 37, 47, 99].into_array(),
        PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
        100,
        Scalar::null_native::<f64>(),
    ).unwrap())]
    #[case(Sparse::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        Sparse::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_native::<i64>(),
        ).unwrap()
    })]
    #[case(Sparse::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(&sparse.into_array());
    }
}