vortex_sparse/compute/take.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::TakeFn;
3use vortex_array::{Array, ArrayRef};
4use vortex_error::VortexResult;
5
6use crate::{SparseArray, SparseEncoding};
7
8impl TakeFn<&SparseArray> for SparseEncoding {
9    fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
10        let Some(new_patches) = array.patches().take(take_indices)? else {
11            let result_nullability =
12                array.dtype().nullability() | take_indices.dtype().nullability();
13            let result_fill_scalar = array
14                .fill_scalar()
15                .cast(&array.dtype().with_nullability(result_nullability))?;
16            return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
17        };
18
19        Ok(
20            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
21                .into_array(),
22        )
23    }
24}
25
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::{scalar_at, slice, take};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Fill value used by the default fixture: a null `f64`.
    fn test_array_fill_value() -> Scalar {
        // making this const is annoying
        Scalar::null_typed::<f64>()
    }

    /// A length-100 nullable `f64` sparse array with patches at positions 0, 37, 47 and 99
    /// (values 1.23, 0.47, 9.99, 3.5), using `fill` for every other position.
    fn sparse_array_with_fill(fill: Scalar) -> ArrayRef {
        SparseArray::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            fill,
        )
        .unwrap()
        .into_array()
    }

    /// The default fixture: the patches above over a null fill value.
    fn sparse_array() -> ArrayRef {
        sparse_array_with_fill(test_array_fill_value())
    }

    #[test]
    fn take_with_non_zero_offset() {
        let sparse = sparse_array();
        let sparse = slice(&sparse, 30, 40).unwrap();
        let sparse = take(&sparse, &buffer![6, 7, 8].into_array()).unwrap();
        assert_eq!(scalar_at(&sparse, 0).unwrap(), test_array_fill_value());
        assert_eq!(scalar_at(&sparse, 1).unwrap(), Scalar::from(Some(0.47)));
        assert_eq!(scalar_at(&sparse, 2).unwrap(), test_array_fill_value());
    }

    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let taken =
            SparseArray::try_from(take(&sparse, &buffer![0, 47, 47, 0, 99].into_array()).unwrap())
                .unwrap();
        assert_eq!(taken.len(), 5);
        assert_eq!(
            taken
                .patches()
                .indices()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0, 1, 2, 3, 4]
        );
        assert_eq!(
            taken
                .patches()
                .values()
                .to_primitive()
                .unwrap()
                .as_slice::<f64>(),
            [1.23f64, 9.99, 9.99, 1.23, 3.5]
        );
    }

    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        let taken = take(&sparse, &buffer![69].into_array()).unwrap();
        assert_eq!(taken.len(), 1);
        assert_eq!(scalar_at(&taken, 0).unwrap(), test_array_fill_value());
    }

    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        let taken =
            SparseArray::try_from(take(&sparse, &buffer![69, 37].into_array()).unwrap()).unwrap();
        assert_eq!(
            taken
                .patches()
                .indices()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1]
        );
        assert_eq!(
            taken
                .patches()
                .values()
                .to_primitive()
                .unwrap()
                .as_slice::<f64>(),
            [0.47f64]
        );
        assert_eq!(taken.len(), 2);
    }

    /// Every other fixture uses a null fill, so this pins that a non-null fill value survives
    /// `take` on both the "no patch hit" path and the "some patches hit" path.
    #[test]
    fn take_with_non_null_fill() {
        let fill = Scalar::from(Some(5.0f64));
        let sparse = sparse_array_with_fill(fill.clone());

        // No requested position lands on a patch.
        let missed = take(&sparse, &buffer![1, 69].into_array()).unwrap();
        assert_eq!(missed.len(), 2);
        assert_eq!(scalar_at(&missed, 0).unwrap(), fill);
        assert_eq!(scalar_at(&missed, 1).unwrap(), fill);

        // One miss followed by one hit.
        let mixed = take(&sparse, &buffer![69, 37].into_array()).unwrap();
        assert_eq!(mixed.len(), 2);
        assert_eq!(scalar_at(&mixed, 0).unwrap(), fill);
        assert_eq!(scalar_at(&mixed, 1).unwrap(), Scalar::from(Some(0.47)));
    }
}