vortex_sparse/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::IntoArray;
7use vortex_array::arrays::ConstantArray;
8use vortex_array::compute::TakeKernel;
9use vortex_array::compute::TakeKernelAdapter;
10use vortex_array::register_kernel;
11use vortex_error::VortexResult;
12
13use crate::SparseArray;
14use crate::SparseVTable;
15
16impl TakeKernel for SparseVTable {
17    fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
18        let patches_take = if array.fill_scalar().is_null() {
19            array.patches().take(take_indices)?
20        } else {
21            array.patches().take_with_nulls(take_indices)?
22        };
23
24        let Some(new_patches) = patches_take else {
25            let result_fill_scalar = array.fill_scalar().cast(
26                &array
27                    .dtype()
28                    .union_nullability(take_indices.dtype().nullability()),
29            )?;
30            return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
31        };
32
33        // See `SparseEncoding::slice`.
34        if new_patches.array_len() == new_patches.values().len() {
35            return Ok(new_patches.into_values());
36        }
37
38        Ok(SparseArray::try_new_from_patches(
39            new_patches,
40            array.fill_scalar().cast(
41                &array
42                    .dtype()
43                    .union_nullability(take_indices.dtype().nullability()),
44            )?,
45        )?
46        .into_array())
47    }
48}
49
50register_kernel!(TakeKernelAdapter(SparseVTable).lift());
51
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::SparseArray;
    use crate::SparseVTable;

    /// Fill value used by [`sparse_array`]: a null `f64`.
    fn test_array_fill_value() -> Scalar {
        // making this const is annoying
        Scalar::null_typed::<f64>()
    }

    /// A length-100 sparse `f64` array with patches at indices 0, 37, 47 and 99 and a null fill.
    fn sparse_array() -> ArrayRef {
        SparseArray::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            test_array_fill_value(),
        )
        .unwrap()
        .into_array()
    }

    #[test]
    fn take_with_non_zero_offset() {
        let sparse = sparse_array();
        let sparse = sparse.slice(30..40);
        let taken = take(&sparse, &buffer![6, 7, 8].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47), None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let taken = take(&sparse, &buffer![0, 47, 47, 0, 99].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([
            Some(1.23f64),
            Some(9.99),
            Some(9.99),
            Some(1.23),
            Some(3.5),
        ]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    /// Taking only patched positions exercises the fast path that returns the taken patch
    /// values directly instead of wrapping them in a new `SparseArray`.
    #[test]
    fn take_only_patched_positions() {
        let sparse = sparse_array();
        let taken = take(&sparse, &buffer![99, 0, 37].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Some(3.5f64), Some(1.23), Some(0.47)]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        let taken = take(&sparse, &buffer![69].into_array()).unwrap();
        let expected = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        let taken_arr = take(&sparse, &buffer![69, 37].into_array()).unwrap();
        let taken = taken_arr.as_::<SparseVTable>();

        assert_eq!(
            taken.patches().indices().to_primitive().as_slice::<u64>(),
            [1]
        );
        assert_eq!(
            taken.patches().values().to_primitive().as_slice::<f64>(),
            [0.47f64]
        );
        assert_eq!(taken.len(), 2);
    }

    /// A null take index must produce null even though the fill value is non-null.
    #[test]
    fn nullable_take() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                .as_ref(),
        )
        .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = SparseArray::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                .as_ref(),
        )
        .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[rstest]
    #[case(SparseArray::try_new(
        buffer![0u64, 37, 47, 99].into_array(),
        PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
        100,
        Scalar::null_typed::<f64>(),
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        SparseArray::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_typed::<i64>(),
        ).unwrap()
    })]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(sparse.as_ref());
    }
}