1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
3use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
4use vortex_error::VortexResult;
5
6use crate::{SparseArray, SparseVTable};
7
8impl TakeKernel for SparseVTable {
9 fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
10 let patches_take = if array.fill_scalar().is_null() {
11 array.patches().take(take_indices)?
12 } else {
13 array.patches().take_with_nulls(take_indices)?
14 };
15
16 let Some(new_patches) = patches_take else {
17 let result_fill_scalar = array.fill_scalar().cast(
18 &array
19 .dtype()
20 .union_nullability(take_indices.dtype().nullability()),
21 )?;
22 return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
23 };
24
25 if new_patches.array_len() == new_patches.values().len() {
27 return Ok(new_patches.into_values());
28 }
29
30 Ok(SparseArray::try_new_from_patches(
31 new_patches,
32 array.fill_scalar().cast(
33 &array
34 .dtype()
35 .union_nullability(take_indices.dtype().nullability()),
36 )?,
37 )?
38 .into_array())
39 }
40}
41
// Registers `SparseVTable`'s `TakeKernel` implementation (wrapped in `TakeKernelAdapter`) with
// the kernel registry. NOTE(review): presumably this is what lets `compute::take` dispatch sparse
// inputs to the kernel above — confirm against `register_kernel!`.
register_kernel!(TakeKernelAdapter(SparseVTable).lift());
43
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Fill value of the `f64` fixture: a typed `f64` null.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_typed::<f64>()
    }

    /// A 100-element sparse `f64` array with non-null patches at indices 0, 37, 47 and 99.
    fn sparse_array() -> ArrayRef {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid)
                .into_array();
        SparseArray::try_new(patch_indices, patch_values, 100, test_array_fill_value())
            .unwrap()
            .into_array()
    }

    /// Takes `[2, 1, null]` from `arr` and checks that the outputs are the fill value, the
    /// patch value, and a null, all with a nullable `i32` dtype.
    ///
    /// `arr` is expected to have a non-null fill of `1`, a patch of `10` at index 1 and no
    /// patch at index 2.
    fn check_nullable_take(arr: SparseArray) {
        let indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None]);
        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();

        let expected = [
            Scalar::primitive(1, Nullability::Nullable),
            Scalar::primitive(10, Nullability::Nullable),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable)),
        ];
        for (position, want) in expected.into_iter().enumerate() {
            assert_eq!(taken.scalar_at(position).unwrap(), want);
        }
    }

    #[test]
    fn take_with_non_zero_offset() {
        let full = sparse_array();
        let sliced = full.slice(30, 40).unwrap();
        let taken = take(&sliced, &buffer![6, 7, 8].into_array()).unwrap();

        assert_eq!(taken.scalar_at(0).unwrap(), test_array_fill_value());
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::from(Some(0.47)));
        assert_eq!(taken.scalar_at(2).unwrap(), test_array_fill_value());
    }

    #[test]
    fn sparse_take() {
        let full = sparse_array();
        let indices = buffer![0, 47, 47, 0, 99].into_array();
        let values = take(&full, &indices).unwrap().to_primitive().unwrap();

        assert_eq!(values.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
    }

    #[test]
    fn nonexistent_take() {
        let full = sparse_array();
        let result = take(&full, &buffer![69].into_array()).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result.scalar_at(0).unwrap(), test_array_fill_value());
    }

    #[test]
    fn ordered_take() {
        let full = sparse_array();
        let result = take(&full, &buffer![69, 37].into_array()).unwrap();
        let sparse = result.as_::<SparseVTable>();
        let patches = sparse.patches();

        let patch_indices = patches.indices().to_primitive().unwrap();
        assert_eq!(patch_indices.as_slice::<u64>(), [1]);

        let patch_values = patches.values().to_primitive().unwrap();
        assert_eq!(patch_values.as_slice::<f64>(), [0.47f64]);

        assert_eq!(sparse.len(), 2);
    }

    #[test]
    fn nullable_take() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        check_nullable_take(arr);
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = SparseArray::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        check_nullable_take(arr);
    }
}