vortex_sparse/compute/
take.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::TakeFn;
3use vortex_array::{Array, ArrayRef};
4use vortex_error::VortexResult;
5
6use crate::{SparseArray, SparseEncoding};
7
8impl TakeFn<&SparseArray> for SparseEncoding {
9    fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
10        let Some(new_patches) = array.patches().take(take_indices)? else {
11            let result_nullability =
12                array.dtype().nullability() | take_indices.dtype().nullability();
13            let result_fill_scalar = array
14                .fill_scalar()
15                .cast(&array.dtype().with_nullability(result_nullability))?;
16            return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
17        };
18
19        // See `SparseEncoding::slice`.
20        if new_patches.array_len() == new_patches.values().len() {
21            return Ok(new_patches.into_values());
22        }
23
24        Ok(
25            SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
26                .into_array(),
27        )
28    }
29}
30
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::{scalar_at, slice, take};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Fill value used by the fixture: a typed `f64` null.
    fn test_array_fill_value() -> Scalar {
        // A function rather than a `const`, because making this const is annoying.
        Scalar::null_typed::<f64>()
    }

    /// Builds a length-100 sparse array patched at indices 0, 37, 47 and 99.
    fn sparse_array() -> ArrayRef {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array();

        SparseArray::try_new(patch_indices, patch_values, 100, test_array_fill_value())
            .unwrap()
            .into_array()
    }

    /// Taking from a slice must resolve indices relative to the start of the slice.
    #[test]
    fn take_with_non_zero_offset() {
        let full = sparse_array();
        let window = slice(&full, 30, 40).unwrap();
        let taken = take(&window, &buffer![6, 7, 8].into_array()).unwrap();

        // Slice positions 6, 7, 8 map to source indices 36, 37, 38; only 37 is patched.
        let expected = [
            test_array_fill_value(),
            Scalar::from(Some(0.47)),
            test_array_fill_value(),
        ];
        for (position, want) in expected.into_iter().enumerate() {
            assert_eq!(scalar_at(&taken, position).unwrap(), want);
        }
    }

    /// Every requested index is patched, including repeats.
    #[test]
    fn sparse_take() {
        let source = sparse_array();
        let requested = buffer![0, 47, 47, 0, 99].into_array();

        let taken = take(&source, &requested).unwrap();
        let prim = taken.to_primitive().unwrap();

        assert_eq!(prim.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
    }

    /// Taking a single unpatched index yields the fill value.
    #[test]
    fn nonexistent_take() {
        let source = sparse_array();
        let requested = buffer![69].into_array();

        let taken = take(&source, &requested).unwrap();

        assert_eq!(taken.len(), 1);
        assert_eq!(scalar_at(&taken, 0).unwrap(), test_array_fill_value());
    }

    /// A mix of unpatched (69) and patched (37) indices stays sparse, with one patch at row 1.
    #[test]
    fn ordered_take() {
        let source = sparse_array();
        let requested = buffer![69, 37].into_array();
        let taken = SparseArray::try_from(take(&source, &requested).unwrap()).unwrap();

        let patch_indices = taken.patches().indices().to_primitive().unwrap();
        assert_eq!(patch_indices.as_slice::<u64>(), [1]);

        let patch_values = taken.patches().values().to_primitive().unwrap();
        assert_eq!(patch_values.as_slice::<f64>(), [0.47f64]);

        assert_eq!(taken.len(), 2);
    }
}