vortex_sparse/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8
9use crate::{SparseArray, SparseVTable};
10
11impl TakeKernel for SparseVTable {
12    fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
13        let patches_take = if array.fill_scalar().is_null() {
14            array.patches().take(take_indices)?
15        } else {
16            array.patches().take_with_nulls(take_indices)?
17        };
18
19        let Some(new_patches) = patches_take else {
20            let result_fill_scalar = array.fill_scalar().cast(
21                &array
22                    .dtype()
23                    .union_nullability(take_indices.dtype().nullability()),
24            )?;
25            return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
26        };
27
28        // See `SparseEncoding::slice`.
29        if new_patches.array_len() == new_patches.values().len() {
30            return Ok(new_patches.into_values());
31        }
32
33        Ok(SparseArray::try_new_from_patches(
34            new_patches,
35            array.fill_scalar().cast(
36                &array
37                    .dtype()
38                    .union_nullability(take_indices.dtype().nullability()),
39            )?,
40        )?
41        .into_array())
42    }
43}
44
45register_kernel!(TakeKernelAdapter(SparseVTable).lift());
46
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Fill value used by [`sparse_array`]: a typed `f64` null.
    fn test_array_fill_value() -> Scalar {
        // making this const is annoying
        Scalar::null_typed::<f64>()
    }

    /// A length-100 `f64` sparse array with patches `{0: 1.23, 37: 0.47, 47: 9.99, 99: 3.5}`
    /// and a null fill.
    fn sparse_array() -> ArrayRef {
        SparseArray::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            test_array_fill_value(),
        )
        .unwrap()
        .into_array()
    }

    /// Takes the nullable indices `[2, 1, null]` from `arr` and checks the result.
    ///
    /// `arr` must have a non-null `i32` fill of `1`, the patch value `10` at index 1, and no
    /// patch at index 2. Because the indices are nullable, every result scalar is nullable, and
    /// the null index yields a null even though the fill value itself is non-null.
    fn assert_nullable_take(arr: SparseArray) {
        let indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None]);
        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::primitive(1, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::primitive(10, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(2).unwrap(),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable))
        );
    }

    #[test]
    fn take_with_non_zero_offset() {
        let sparse = sparse_array();
        let sparse = sparse.slice(30, 40).unwrap();
        let sparse = take(&sparse, &buffer![6, 7, 8].into_array()).unwrap();
        assert_eq!(sparse.scalar_at(0).unwrap(), test_array_fill_value());
        assert_eq!(sparse.scalar_at(1).unwrap(), Scalar::from(Some(0.47)));
        assert_eq!(sparse.scalar_at(2).unwrap(), test_array_fill_value());
    }

    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let prim = take(&sparse, &buffer![0, 47, 47, 0, 99].into_array())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(prim.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
    }

    #[test]
    fn take_only_patched_positions() {
        let sparse = sparse_array();
        let taken = take(&sparse, &buffer![99u64, 0, 37].into_array()).unwrap();

        // Every requested position holds a patch, so the kernel hands back the taken patch
        // values directly instead of wrapping them in a new `SparseArray`.
        assert!(!taken.is::<SparseVTable>());
        assert_eq!(taken.len(), 3);
        let prim = taken.to_primitive().unwrap();
        assert_eq!(prim.as_slice::<f64>(), [3.5f64, 1.23, 0.47]);
    }

    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        let taken = take(&sparse, &buffer![69].into_array()).unwrap();
        assert_eq!(taken.len(), 1);
        assert_eq!(taken.scalar_at(0).unwrap(), test_array_fill_value());
    }

    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        let taken_arr = take(&sparse, &buffer![69, 37].into_array()).unwrap();
        let taken = taken_arr.as_::<SparseVTable>();

        assert_eq!(
            taken
                .patches()
                .indices()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1]
        );
        assert_eq!(
            taken
                .patches()
                .values()
                .to_primitive()
                .unwrap()
                .as_slice::<f64>(),
            [0.47f64]
        );
        assert_eq!(taken.len(), 2);
    }

    #[test]
    fn nullable_take() {
        assert_nullable_take(
            SparseArray::try_new(
                buffer![1u32].into_array(),
                buffer![10].into_array(),
                10,
                Scalar::primitive(1, Nullability::NonNullable),
            )
            .unwrap(),
        );
    }

    #[test]
    fn nullable_take_with_many_patches() {
        assert_nullable_take(
            SparseArray::try_new(
                buffer![1u32, 3, 7, 8, 9].into_array(),
                buffer![10, 8, 3, 2, 1].into_array(),
                10,
                Scalar::primitive(1, Nullability::NonNullable),
            )
            .unwrap(),
        );
    }

    #[rstest]
    #[case(SparseArray::try_new(
        buffer![0u64, 37, 47, 99].into_array(),
        PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
        100,
        Scalar::null_typed::<f64>(),
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        SparseArray::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_typed::<i64>(),
        ).unwrap()
    })]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(sparse.as_ref());
    }
}