// vortex_sparse/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8
9use crate::{SparseArray, SparseVTable};
10
11impl TakeKernel for SparseVTable {
12    fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
13        let patches_take = if array.fill_scalar().is_null() {
14            array.patches().take(take_indices)?
15        } else {
16            array.patches().take_with_nulls(take_indices)?
17        };
18
19        let Some(new_patches) = patches_take else {
20            let result_fill_scalar = array.fill_scalar().cast(
21                &array
22                    .dtype()
23                    .union_nullability(take_indices.dtype().nullability()),
24            )?;
25            return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
26        };
27
28        // See `SparseEncoding::slice`.
29        if new_patches.array_len() == new_patches.values().len() {
30            return Ok(new_patches.into_values());
31        }
32
33        Ok(SparseArray::try_new_from_patches(
34            new_patches,
35            array.fill_scalar().cast(
36                &array
37                    .dtype()
38                    .union_nullability(take_indices.dtype().nullability()),
39            )?,
40        )?
41        .into_array())
42    }
43}
44
// Register the `TakeKernel` implementation for `SparseVTable`, wrapped in
// `TakeKernelAdapter`, with the compute kernel registry. The tests below call
// `vortex_array::compute::take` on sparse arrays and expect sparse-specific
// results (e.g. a `SparseVTable` output in `ordered_take`).
register_kernel!(TakeKernelAdapter(SparseVTable).lift());
46
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Fill value of the fixture returned by [`sparse_array`]: a null `f64`.
    fn test_array_fill_value() -> Scalar {
        // A function rather than a const: building a `Scalar` is not const-friendly.
        Scalar::null_typed::<f64>()
    }

    /// A 100-element sparse `f64` array patched at positions 0, 37, 47 and 99;
    /// every other position holds [`test_array_fill_value`].
    fn sparse_array() -> ArrayRef {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid)
                .into_array();
        SparseArray::try_new(patch_indices, patch_values, 100, test_array_fill_value())
            .unwrap()
            .into_array()
    }

    /// Takes `[2, 1, null]` from `arr`, which must have a fill of `1` and a patch of
    /// `10` at position 1 (and none at position 2). Checks that the result is
    /// `[1, 10, null]` with a nullable `i32` dtype.
    fn assert_takes_fill_patch_null(arr: &SparseArray) {
        let indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None]);
        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();

        let expected = [
            Scalar::primitive(1, Nullability::Nullable),
            Scalar::primitive(10, Nullability::Nullable),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable)),
        ];
        for (position, want) in expected.into_iter().enumerate() {
            assert_eq!(
                taken.scalar_at(position).unwrap(),
                want,
                "mismatch at position {position}"
            );
        }
    }

    #[test]
    fn take_with_non_zero_offset() {
        // Slicing [30, 40) then taking 6..=8 reads original positions 36..=38,
        // of which only 37 is patched.
        let sliced = sparse_array().slice(30, 40).unwrap();
        let taken = take(&sliced, &buffer![6, 7, 8].into_array()).unwrap();

        let fill = test_array_fill_value();
        assert_eq!(taken.scalar_at(0).unwrap(), fill);
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::from(Some(0.47)));
        assert_eq!(taken.scalar_at(2).unwrap(), fill);
    }

    #[test]
    fn sparse_take() {
        let indices = buffer![0, 47, 47, 0, 99].into_array();
        let taken = take(&sparse_array(), &indices)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(taken.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
    }

    #[test]
    fn nonexistent_take() {
        let taken = take(&sparse_array(), &buffer![69].into_array()).unwrap();

        assert_eq!(taken.len(), 1);
        assert_eq!(taken.scalar_at(0).unwrap(), test_array_fill_value());
    }

    #[test]
    fn ordered_take() {
        let taken_arr = take(&sparse_array(), &buffer![69, 37].into_array()).unwrap();
        let taken = taken_arr.as_::<SparseVTable>();

        let patch_positions = taken.patches().indices().to_primitive().unwrap();
        assert_eq!(patch_positions.as_slice::<u64>(), [1]);

        let patch_values = taken.patches().values().to_primitive().unwrap();
        assert_eq!(patch_values.as_slice::<f64>(), [0.47f64]);

        assert_eq!(taken.len(), 2);
    }

    #[test]
    fn nullable_take() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        assert_takes_fill_patch_null(&arr);
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = SparseArray::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        assert_takes_fill_patch_null(&arr);
    }
}