// vortex_sparse/compute/take.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
3use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
4use vortex_error::VortexResult;
5
6use crate::{SparseArray, SparseVTable};
7
8impl TakeKernel for SparseVTable {
9    fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
10        let patches_take = if array.fill_scalar().is_null() {
11            array.patches().take(take_indices)?
12        } else {
13            array.patches().take_with_nulls(take_indices)?
14        };
15
16        let Some(new_patches) = patches_take else {
17            let result_fill_scalar = array.fill_scalar().cast(
18                &array
19                    .dtype()
20                    .union_nullability(take_indices.dtype().nullability()),
21            )?;
22            return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
23        };
24
25        // See `SparseEncoding::slice`.
26        if new_patches.array_len() == new_patches.values().len() {
27            return Ok(new_patches.into_values());
28        }
29
30        Ok(SparseArray::try_new_from_patches(
31            new_patches,
32            array.fill_scalar().cast(
33                &array
34                    .dtype()
35                    .union_nullability(take_indices.dtype().nullability()),
36            )?,
37        )?
38        .into_array())
39    }
40}
41
42register_kernel!(TakeKernelAdapter(SparseVTable).lift());
43
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Fill value of the array built by [`sparse_array`]: a null `f64`.
    fn test_array_fill_value() -> Scalar {
        // making this const is annoying
        Scalar::null_typed::<f64>()
    }

    /// A length-100 sparse `f64` array with patches at indices 0, 37, 47 and 99 and a null fill.
    fn sparse_array() -> ArrayRef {
        SparseArray::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            test_array_fill_value(),
        )
        .unwrap()
        .into_array()
    }

    #[test]
    fn take_with_non_zero_offset() {
        let sparse = sparse_array();
        let sparse = sparse.slice(30, 40).unwrap();
        let sparse = take(&sparse, &buffer![6, 7, 8].into_array()).unwrap();
        assert_eq!(sparse.scalar_at(0).unwrap(), test_array_fill_value());
        assert_eq!(sparse.scalar_at(1).unwrap(), Scalar::from(Some(0.47)));
        assert_eq!(sparse.scalar_at(2).unwrap(), test_array_fill_value());
    }

    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let prim = take(&sparse, &buffer![0, 47, 47, 0, 99].into_array())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(prim.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
    }

    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        let taken = take(&sparse, &buffer![69].into_array()).unwrap();
        assert_eq!(taken.len(), 1);
        assert_eq!(taken.scalar_at(0).unwrap(), test_array_fill_value());
    }

    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        let taken_arr = take(&sparse, &buffer![69, 37].into_array()).unwrap();
        let taken = taken_arr.as_::<SparseVTable>();

        assert_eq!(
            taken
                .patches()
                .indices()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1]
        );
        assert_eq!(
            taken
                .patches()
                .values()
                .to_primitive()
                .unwrap()
                .as_slice::<f64>(),
            [0.47f64]
        );
        assert_eq!(taken.len(), 2);
    }

    #[test]
    fn nullable_take() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                .as_ref(),
        )
        .unwrap();

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::primitive(1, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::primitive(10, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(2).unwrap(),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable))
        );
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = SparseArray::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                .as_ref(),
        )
        .unwrap();

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::primitive(1, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::primitive(10, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(2).unwrap(),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable))
        );
    }

    /// Non-null fill, nullable (but all-valid) indices, and no index lands on a patch.
    ///
    /// Every output position must be the fill value. Its dtype is widened to nullable because
    /// the index array's dtype is nullable.
    #[test]
    fn nullable_take_without_patch_hits() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter([Some(2u32), Some(5u32)]).as_ref(),
        )
        .unwrap();

        assert_eq!(taken.len(), 2);
        assert_eq!(
            taken.dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::primitive(1, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::primitive(1, Nullability::Nullable)
        );
    }
}