vortex_sparse/compute/
take.rs1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
3use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
4use vortex_error::VortexResult;
5
6use crate::{SparseArray, SparseVTable};
7
8impl TakeKernel for SparseVTable {
9 fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
10 let Some(new_patches) = array.patches().take(take_indices)? else {
11 let result_fill_scalar = array.fill_scalar().cast(
12 &array
13 .dtype()
14 .union_nullability(take_indices.dtype().nullability()),
15 )?;
16 return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
17 };
18
19 if new_patches.array_len() == new_patches.values().len() {
21 return Ok(new_patches.into_values());
22 }
23
24 Ok(
25 SparseArray::try_new_from_patches(new_patches, array.fill_scalar().clone())?
26 .into_array(),
27 )
28 }
29}
30
31register_kernel!(TakeKernelAdapter(SparseVTable).lift());
32
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Fill value shared by every fixture in this module: a typed `f64` null.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_typed::<f64>()
    }

    /// A length-100 sparse `f64` array with patches at positions 0, 37, 47 and 99.
    /// Every other position holds [`test_array_fill_value`].
    fn sparse_array() -> ArrayRef {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array();
        SparseArray::try_new(patch_indices, patch_values, 100, test_array_fill_value())
            .unwrap()
            .into_array()
    }

    #[test]
    fn take_with_non_zero_offset() {
        // The slice starts at original position 30, so local indices 6..=8 map to 36..=38.
        let window = sparse_array().slice(30, 40).unwrap();
        let taken = take(&window, &buffer![6, 7, 8].into_array()).unwrap();

        let expected = [
            test_array_fill_value(),
            Scalar::from(Some(0.47)),
            test_array_fill_value(),
        ];
        for (position, want) in expected.iter().enumerate() {
            assert_eq!(&taken.scalar_at(position).unwrap(), want);
        }
    }

    #[test]
    fn sparse_take() {
        // Every requested position is patched, including repeats.
        let indices = buffer![0, 47, 47, 0, 99].into_array();
        let taken = take(&sparse_array(), &indices).unwrap();
        let values = taken.to_primitive().unwrap();
        assert_eq!(values.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
    }

    #[test]
    fn nonexistent_take() {
        // Position 69 is not patched, so the fill value comes back.
        let taken = take(&sparse_array(), &buffer![69].into_array()).unwrap();
        assert_eq!(taken.len(), 1);
        assert_eq!(taken.scalar_at(0).unwrap(), test_array_fill_value());
    }

    #[test]
    fn ordered_take() {
        let taken_arr = take(&sparse_array(), &buffer![69, 37].into_array()).unwrap();
        let taken = taken_arr.as_::<SparseVTable>();
        let patches = taken.patches();

        // Only output position 1 (original index 37) lands on a patch.
        let patch_indices = patches.indices().to_primitive().unwrap();
        assert_eq!(patch_indices.as_slice::<u64>(), [1]);

        let patch_values = patches.values().to_primitive().unwrap();
        assert_eq!(patch_values.as_slice::<f64>(), [0.47f64]);

        assert_eq!(taken.len(), 2);
    }
}