1use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8
9use crate::{SparseArray, SparseVTable};
10
11impl TakeKernel for SparseVTable {
12 fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
13 let patches_take = if array.fill_scalar().is_null() {
14 array.patches().take(take_indices)?
15 } else {
16 array.patches().take_with_nulls(take_indices)?
17 };
18
19 let Some(new_patches) = patches_take else {
20 let result_fill_scalar = array.fill_scalar().cast(
21 &array
22 .dtype()
23 .union_nullability(take_indices.dtype().nullability()),
24 )?;
25 return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
26 };
27
28 if new_patches.array_len() == new_patches.values().len() {
30 return Ok(new_patches.into_values());
31 }
32
33 Ok(SparseArray::try_new_from_patches(
34 new_patches,
35 array.fill_scalar().cast(
36 &array
37 .dtype()
38 .union_nullability(take_indices.dtype().nullability()),
39 )?,
40 )?
41 .into_array())
42 }
43}
44
// Register the `TakeKernel` impl above for `SparseVTable` via `TakeKernelAdapter`.
// The tests in this file call `vortex_array::compute::take` on sparse arrays and expect
// sparse-specific results, which presumably relies on this registration.
register_kernel!(TakeKernelAdapter(SparseVTable).lift());
46
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Fill value used by [`sparse_array`]: a null `f64`.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_typed::<f64>()
    }

    /// A 100-element `f64` sparse array with patches at indices 0, 37, 47 and 99, filled with
    /// [`test_array_fill_value`] everywhere else.
    fn sparse_array() -> ArrayRef {
        SparseArray::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            test_array_fill_value(),
        )
        .unwrap()
        .into_array()
    }

    /// Takes `[Some(2), Some(1), None]` from `arr` and checks the nullable result.
    ///
    /// Expects `arr` to be a non-nullable `i32` sparse array with fill value `1`, a patch of `10`
    /// at index 1, and no patch at index 2. A null take index must produce a null even though
    /// the fill value is non-null.
    fn assert_nullable_take(arr: SparseArray) {
        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
                .as_ref(),
        )
        .unwrap();

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::primitive(1, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::primitive(10, Nullability::Nullable)
        );
        assert_eq!(
            taken.scalar_at(2).unwrap(),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable))
        );
    }

    /// Taking from a slice must resolve indices relative to the slice start (30 here).
    #[test]
    fn take_with_non_zero_offset() {
        let sparse = sparse_array();
        let sparse = sparse.slice(30, 40).unwrap();
        let sparse = take(&sparse, &buffer![6, 7, 8].into_array()).unwrap();
        assert_eq!(sparse.scalar_at(0).unwrap(), test_array_fill_value());
        assert_eq!(sparse.scalar_at(1).unwrap(), Scalar::from(Some(0.47)));
        assert_eq!(sparse.scalar_at(2).unwrap(), test_array_fill_value());
    }

    /// Repeated and out-of-order indices that all hit patches.
    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let prim = take(&sparse, &buffer![0, 47, 47, 0, 99].into_array())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(prim.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
    }

    /// An index with no patch yields the fill value.
    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        let taken = take(&sparse, &buffer![69].into_array()).unwrap();
        assert_eq!(taken.len(), 1);
        assert_eq!(taken.scalar_at(0).unwrap(), test_array_fill_value());
    }

    /// A mix of patched and unpatched indices stays sparse; patch indices are re-based onto
    /// output positions.
    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        let taken_arr = take(&sparse, &buffer![69, 37].into_array()).unwrap();
        let taken = taken_arr.as_::<SparseVTable>();

        assert_eq!(
            taken
                .patches()
                .indices()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1]
        );
        assert_eq!(
            taken
                .patches()
                .values()
                .to_primitive()
                .unwrap()
                .as_slice::<f64>(),
            [0.47f64]
        );
        assert_eq!(taken.len(), 2);
    }

    #[test]
    fn nullable_take() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        assert_nullable_take(arr);
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = SparseArray::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        assert_nullable_take(arr);
    }

    /// Runs the shared take conformance suite against several sparse layouts: a null fill,
    /// a non-null fill, nullable patch values, and a single patch.
    #[rstest]
    #[case(SparseArray::try_new(
        buffer![0u64, 37, 47, 99].into_array(),
        PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
        100,
        Scalar::null_typed::<f64>(),
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        SparseArray::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_typed::<i64>(),
        ).unwrap()
    })]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(sparse.as_ref());
    }
}