vortex_array/arrays/list/compute/
take.rs1use arrow_buffer::BooleanBufferBuilder;
5use num_traits::PrimInt;
6use vortex_dtype::{NativePType, Nullability, match_each_integer_ptype};
7use vortex_error::{VortexExpect, VortexResult, vortex_panic};
8use vortex_mask::Mask;
9
10use crate::arrays::{ListArray, ListVTable, OffsetPType, PrimitiveArray};
11use crate::builders::{ArrayBuilder, PrimitiveBuilder};
12use crate::compute::{TakeKernel, TakeKernelAdapter, take};
13use crate::validity::Validity;
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl TakeKernel for ListVTable {
18 fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19 let indices = indices.to_primitive()?;
20 let offsets = array.offsets().to_primitive()?;
21
22 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
23 match_each_integer_ptype!(indices.ptype(), |I| {
24 Ok(_take::<I, O>(
25 array,
26 offsets.as_slice::<O>(),
27 &indices,
28 array.validity_mask()?,
29 indices.validity_mask()?,
30 )?
31 .into_array())
32 })
33 })
34 }
35}
36
37register_kernel!(TakeKernelAdapter(ListVTable).lift());
38
39fn _take<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
40 array: &ListArray,
41 offsets: &[O],
42 indices_array: &PrimitiveArray,
43 data_validity: Mask,
44 indices_validity_mask: Mask,
45) -> VortexResult<ArrayRef> {
46 let indices: &[I] = indices_array.as_slice::<I>();
47
48 if !indices_validity_mask.all_true() || !data_validity.all_true() {
49 return _take_nullable::<I, O>(
50 array,
51 offsets,
52 indices,
53 data_validity,
54 indices_validity_mask,
55 );
56 }
57
58 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
59 let mut elements_to_take =
60 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
61
62 let mut current_offset = O::zero();
63 new_offsets.append_zero();
64
65 for &data_idx in indices {
66 let data_idx = data_idx
67 .to_usize()
68 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
69
70 let start = offsets[data_idx];
71 let stop = offsets[data_idx + 1];
72
73 let additional = (stop - start).to_usize().unwrap_or_else(|| {
79 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
80 });
81
82 elements_to_take.ensure_capacity(elements_to_take.len() + additional);
83 for i in 0..additional {
84 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
85 }
86 current_offset = current_offset + (stop - start);
87 new_offsets.append_value(current_offset);
88 }
89
90 let elements_to_take = elements_to_take.finish();
91 let new_offsets = new_offsets.finish();
92
93 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
94
95 Ok(ListArray::try_new(
96 new_elements,
97 new_offsets,
98 indices_array
99 .validity()
100 .clone()
101 .and(array.validity().clone())?,
102 )?
103 .to_array())
104}
105
106fn _take_nullable<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
107 array: &ListArray,
108 offsets: &[O],
109 indices: &[I],
110 data_validity: Mask,
111 indices_validity: Mask,
112) -> VortexResult<ArrayRef> {
113 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
114 let mut elements_to_take =
115 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
116
117 let mut current_offset = O::zero();
118 new_offsets.append_zero();
119 let mut new_validity = BooleanBufferBuilder::new(2 * indices.len());
120
121 for (idx, data_idx) in indices.iter().enumerate() {
122 if !indices_validity.value(idx) {
123 new_offsets.append_value(current_offset);
124 new_validity.append(false);
125 continue;
126 }
127
128 let data_idx = data_idx
129 .to_usize()
130 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
131
132 if data_validity.value(data_idx) {
133 let start = offsets[data_idx];
134 let stop = offsets[data_idx + 1];
135
136 let additional = (stop - start).to_usize().unwrap_or_else(|| {
138 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
139 });
140
141 elements_to_take.ensure_capacity(elements_to_take.len() + additional);
142 for i in 0..additional {
143 elements_to_take
144 .append_value(start + O::from_usize(i).vortex_expect("i < additional"));
145 }
146 current_offset = current_offset + (stop - start);
147 new_offsets.append_value(current_offset);
148 new_validity.append(true);
149 } else {
150 new_offsets.append_value(current_offset);
151 new_validity.append(false);
152 }
153 }
154
155 let elements_to_take = elements_to_take.finish();
156 let new_offsets = new_offsets.finish();
157 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
158
159 let new_validity: Validity = Validity::from(new_validity.finish());
160 Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
163}
164
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Null indices and null source lists both produce null output slots.
    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3, 4, 4]).to_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_list().unwrap();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_invalid(1).unwrap());

        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(3).unwrap());
        assert_eq!(
            result.scalar_at(3).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// Nullable indices make a non-nullable list nullable, and a null index yields a null list.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_list().unwrap();
        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(1).unwrap());
        assert_eq!(
            result.scalar_at(1).unwrap(),
            Scalar::list(element_dtype, vec![3i32.into()], Nullability::Nullable)
        );

        assert!(result.is_invalid(2).unwrap());
    }

    /// Taking from a non-nullable list with non-nullable indices stays non-nullable.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3, 3, 4]).to_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_iter([1, 0, 2]).to_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_list().unwrap();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(1).unwrap());
        assert_eq!(
            result.scalar_at(1).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// Repeating an index duplicates the list, so the output holds more elements than the input.
    #[test]
    fn take_repeated_indices() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3, 3, 4]).to_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_iter([0, 0, 3]).to_array();

        let result = take(&list, &idx).unwrap().to_list().unwrap();
        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        for i in 0..2 {
            assert_eq!(
                result.scalar_at(i).unwrap(),
                Scalar::list(
                    element_dtype.clone(),
                    vec![0i32.into(), 5.into()],
                    Nullability::NonNullable
                )
            );
        }
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(element_dtype, vec![4i32.into()], Nullability::NonNullable)
        );
    }

    /// Taking zero indices yields an empty list array whose nullability follows the indices.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0]).to_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0);
    }
}