vortex_sparse/compute/take.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
3use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
4use vortex_error::VortexResult;
5
6use crate::{SparseArray, SparseVTable};
7
8impl TakeKernel for SparseVTable {
9    fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
10        let Some(new_patches) = array.patches().take(take_indices)? else {
11            let result_fill_scalar = array.fill_scalar().cast(
12                &array
13                    .dtype()
14                    .union_nullability(take_indices.dtype().nullability()),
15            )?;
16            return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
17        };
18
19        // See `SparseEncoding::slice`.
20        if new_patches.array_len() == new_patches.values().len() {
21            return Ok(new_patches.into_values());
22        }
23
24        Ok(
25            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
26                .into_array(),
27        )
28    }
29}
30
31register_kernel!(TakeKernelAdapter(SparseVTable).lift());
32
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Fill value shared by the fixture and the assertions: a null `f64` scalar.
    ///
    /// Kept as a function because expressing it as a `const` is awkward.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_typed::<f64>()
    }

    /// A 100-element sparse `f64` array with patches at positions 0, 37, 47 and 99.
    /// Every other position holds [`test_array_fill_value`].
    fn sparse_array() -> ArrayRef {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array();
        let fill = test_array_fill_value();
        SparseArray::try_new(patch_indices, patch_values, 100, fill)
            .unwrap()
            .into_array()
    }

    #[test]
    fn take_with_non_zero_offset() {
        // After slicing to [30, 40), slice position 7 corresponds to original position 37.
        let sliced = sparse_array().slice(30, 40).unwrap();
        let taken = take(&sliced, &buffer![6, 7, 8].into_array()).unwrap();

        let fill = test_array_fill_value();
        assert_eq!(taken.scalar_at(0).unwrap(), fill);
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::from(Some(0.47)));
        assert_eq!(taken.scalar_at(2).unwrap(), fill);
    }

    #[test]
    fn sparse_take() {
        // Every requested position is patched (including repeats).
        let indices = buffer![0, 47, 47, 0, 99].into_array();
        let taken = take(&sparse_array(), &indices).unwrap();
        let prim = taken.to_primitive().unwrap();
        assert_eq!(prim.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
    }

    #[test]
    fn nonexistent_take() {
        // Position 69 is not patched, so only the fill value can come back.
        let taken = take(&sparse_array(), &buffer![69].into_array()).unwrap();
        assert_eq!(taken.len(), 1);
        assert_eq!(taken.scalar_at(0).unwrap(), test_array_fill_value());
    }

    #[test]
    fn ordered_take() {
        // Only the second requested position (original index 37) is patched, so the
        // result stays sparse with a single patch at output position 1.
        let taken_arr = take(&sparse_array(), &buffer![69, 37].into_array()).unwrap();
        let taken = taken_arr.as_::<SparseVTable>();
        let patches = taken.patches();

        let patch_positions = patches.indices().to_primitive().unwrap();
        assert_eq!(patch_positions.as_slice::<u64>(), [1]);

        let patch_values = patches.values().to_primitive().unwrap();
        assert_eq!(patch_values.as_slice::<f64>(), [0.47f64]);

        assert_eq!(taken.len(), 2);
    }
}
114}