vortex_sparse/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8
9use crate::{SparseArray, SparseVTable};
10
11impl TakeKernel for SparseVTable {
12    fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
13        let patches_take = if array.fill_scalar().is_null() {
14            array.patches().take(take_indices)?
15        } else {
16            array.patches().take_with_nulls(take_indices)?
17        };
18
19        let Some(new_patches) = patches_take else {
20            let result_fill_scalar = array.fill_scalar().cast(
21                &array
22                    .dtype()
23                    .union_nullability(take_indices.dtype().nullability()),
24            )?;
25            return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
26        };
27
28        // See `SparseEncoding::slice`.
29        if new_patches.array_len() == new_patches.values().len() {
30            return Ok(new_patches.into_values());
31        }
32
33        Ok(SparseArray::try_new_from_patches(
34            new_patches,
35            array.fill_scalar().cast(
36                &array
37                    .dtype()
38                    .union_nullability(take_indices.dtype().nullability()),
39            )?,
40        )?
41        .into_array())
42    }
43}
44
45register_kernel!(TakeKernelAdapter(SparseVTable).lift());
46
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, assert_arrays_eq};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Null `f64` fill value used by the [`sparse_array`] fixture.
    ///
    /// A function rather than a `const`, because building this `Scalar` as a const is awkward.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_typed::<f64>()
    }

    /// Fixture: a length-100 `f64` sparse array with patches at 0, 37, 47 and 99 and a null fill.
    fn sparse_array() -> ArrayRef {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array();
        SparseArray::try_new(patch_indices, patch_values, 100, test_array_fill_value())
            .unwrap()
            .into_array()
    }

    /// Takes `[2, 1, null]` from `arr` and asserts the result is `[1, 10, null]`.
    ///
    /// Callers build arrays with a non-nullable fill of `1` and a patch value of `10` at index 1,
    /// so index 2 reads the fill, index 1 reads the patch, and the null index yields null.
    fn assert_takes_fill_patch_and_null(arr: SparseArray) {
        let indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None]);
        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn take_with_non_zero_offset() {
        // After slicing 30..40 the patch originally at 37 lives at local position 7.
        let source = sparse_array();
        let sliced = source.slice(30..40);
        let result = take(&sliced, &buffer![6, 7, 8].into_array()).unwrap();
        let want = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47), None]);
        assert_arrays_eq!(result, want.to_array());
    }

    #[test]
    fn sparse_take() {
        // Repeated and out-of-order patched positions must come back in request order.
        let source = sparse_array();
        let result = take(&source, &buffer![0, 47, 47, 0, 99].into_array()).unwrap();
        let want = PrimitiveArray::from_option_iter([
            Some(1.23f64),
            Some(9.99),
            Some(9.99),
            Some(1.23),
            Some(3.5),
        ]);
        assert_arrays_eq!(result, want.to_array());
    }

    #[test]
    fn nonexistent_take() {
        // Position 69 has no patch, so the single output is the fill value.
        let source = sparse_array();
        let result = take(&source, &buffer![69].into_array()).unwrap();
        let want = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(result, want);
    }

    #[test]
    fn ordered_take() {
        // 69 is unpatched and 37 is patched: the result stays sparse with one patch at 1.
        let source = sparse_array();
        let result_ref = take(&source, &buffer![69, 37].into_array()).unwrap();
        let result = result_ref.as_::<SparseVTable>();

        let patch_indices = result.patches().indices().to_primitive();
        assert_eq!(patch_indices.as_slice::<u64>(), [1]);

        let patch_values = result.patches().values().to_primitive();
        assert_eq!(patch_values.as_slice::<f64>(), [0.47f64]);

        assert_eq!(result.len(), 2);
    }

    #[test]
    fn nullable_take() {
        let single_patch = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();
        assert_takes_fill_patch_and_null(single_patch);
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let many_patches = SparseArray::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();
        assert_takes_fill_patch_and_null(many_patches);
    }

    /// Runs the shared `take` conformance suite over a spread of sparse layouts: null and
    /// non-null fills, nullable patch values, and a single patch in a longer array.
    #[rstest]
    #[case(SparseArray::try_new(
        buffer![0u64, 37, 47, 99].into_array(),
        PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
        100,
        Scalar::null_typed::<f64>(),
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        SparseArray::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_typed::<i64>(),
        ).unwrap()
    })]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] array: SparseArray) {
        test_take_conformance(array.as_ref());
    }
}
198}