1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::DynArray;
9use crate::IntoArray;
10use crate::arrays::ListArray;
11use crate::arrays::ListVTable;
12use crate::arrays::PrimitiveArray;
13use crate::arrays::dict::TakeExecute;
14use crate::builders::ArrayBuilder;
15use crate::builders::PrimitiveBuilder;
16use crate::dtype::IntegerPType;
17use crate::dtype::Nullability;
18use crate::executor::ExecutionCtx;
19use crate::match_each_integer_ptype;
20use crate::match_smallest_offset_type;
21use crate::vtable::ValidityHelper;
22
23impl TakeExecute for ListVTable {
27 #[expect(clippy::cognitive_complexity)]
33 fn take(
34 array: &ListArray,
35 indices: &ArrayRef,
36 ctx: &mut ExecutionCtx,
37 ) -> VortexResult<Option<ArrayRef>> {
38 let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;
39 let total_approx = array.elements().len().saturating_mul(indices.len());
41
42 match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
43 match_each_integer_ptype!(indices.ptype(), |I| {
44 match_smallest_offset_type!(total_approx, |OutputOffsetType| {
45 _take::<I, O, OutputOffsetType>(array, &indices, ctx).map(Some)
46 })
47 })
48 })
49 }
50}
51
52fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
53 array: &ListArray,
54 indices_array: &PrimitiveArray,
55 ctx: &mut ExecutionCtx,
56) -> VortexResult<ArrayRef> {
57 let data_validity = array.validity_mask()?;
58 let indices_validity = indices_array.validity_mask()?;
59
60 if !indices_validity.all_true() || !data_validity.all_true() {
61 return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
62 }
63
64 let offsets_array = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
65 let offsets: &[O] = offsets_array.as_slice();
66 let indices: &[I] = indices_array.as_slice();
67
68 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
69 Nullability::NonNullable,
70 indices.len(),
71 );
72 let mut elements_to_take =
73 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
74
75 let mut current_offset = OutputOffsetType::zero();
76 new_offsets.append_zero();
77
78 for &data_idx in indices {
79 let data_idx: usize = data_idx.as_();
80
81 let start = offsets[data_idx];
82 let stop = offsets[data_idx + 1];
83
84 let additional: usize = (stop - start).as_();
90
91 elements_to_take.reserve_exact(additional);
93 for i in 0..additional {
94 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
95 }
96 current_offset +=
97 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
98 new_offsets.append_value(current_offset);
99 }
100
101 let elements_to_take = elements_to_take.finish();
102 let new_offsets = new_offsets.finish();
103
104 let new_elements = array.elements().take(elements_to_take.to_array())?;
105
106 Ok(ListArray::try_new(
107 new_elements,
108 new_offsets,
109 array
110 .validity()
111 .clone()
112 .take(&indices_array.clone().into_array())?,
113 )?
114 .into_array())
115}
116
117fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
118 array: &ListArray,
119 indices_array: &PrimitiveArray,
120 ctx: &mut ExecutionCtx,
121) -> VortexResult<ArrayRef> {
122 let offsets_array = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
123 let offsets: &[O] = offsets_array.as_slice();
124 let indices: &[I] = indices_array.as_slice();
125 let data_validity = array.validity_mask()?;
126 let indices_validity = indices_array.validity_mask()?;
127
128 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
129 Nullability::NonNullable,
130 indices.len(),
131 );
132
133 let mut elements_to_take =
141 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
142
143 let mut current_offset = OutputOffsetType::zero();
144 new_offsets.append_zero();
145
146 for (idx, data_idx) in indices.iter().enumerate() {
147 if !indices_validity.value(idx) {
148 new_offsets.append_value(current_offset);
149 continue;
150 }
151
152 let data_idx: usize = data_idx.as_();
153
154 if !data_validity.value(data_idx) {
155 new_offsets.append_value(current_offset);
156 continue;
157 }
158
159 let start = offsets[data_idx];
160 let stop = offsets[data_idx + 1];
161
162 let additional: usize = (stop - start).as_();
164
165 elements_to_take.reserve_exact(additional);
166 for i in 0..additional {
167 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
168 }
169 current_offset +=
170 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
171 new_offsets.append_value(current_offset);
172 }
173
174 let elements_to_take = elements_to_take.finish();
175 let new_offsets = new_offsets.finish();
176 let new_elements = array.elements().take(elements_to_take.to_array())?;
177
178 Ok(ListArray::try_new(
179 new_elements,
180 new_offsets,
181 array
182 .validity()
183 .clone()
184 .take(&indices_array.clone().into_array())?,
185 )?
186 .into_array())
187}
188
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::DynArray;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Dtype of a list of non-nullable `i32` elements with the given list-level nullability.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// List scalar over `i32` elements holding `values`, with the given list-level nullability.
    fn i32_list_scalar(values: Vec<Scalar>, nullability: Nullability) -> Scalar {
        let element_dtype: Arc<DType> = Arc::new(I32.into());
        Scalar::list(element_dtype, values, nullability)
    }

    /// Taking with nullable indices from a nullable list keeps valid lists and nulls the rest.
    #[test]
    fn nullable_take() {
        let source_validity = BoolArray::from_iter(vec![true, true, false, true]).into_array();
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(source_validity),
        )
        .unwrap()
        .into_array();

        let positions =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let taken = source.take(positions.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        // Slot 0 <- list 0.
        assert!(taken.is_valid(0).unwrap());
        assert_eq!(
            taken.scalar_at(0).unwrap(),
            i32_list_scalar(vec![0i32.into(), 5.into()], Nullability::Nullable)
        );

        // Slot 1 <- null index.
        assert!(taken.is_invalid(1).unwrap());

        // Slot 2 <- list 1.
        assert!(taken.is_valid(2).unwrap());
        assert_eq!(
            taken.scalar_at(2).unwrap(),
            i32_list_scalar(vec![3i32.into()], Nullability::Nullable)
        );

        // Slot 3 <- list 3, which is valid but empty.
        assert!(taken.is_valid(3).unwrap());
        assert_eq!(
            taken.scalar_at(3).unwrap(),
            i32_list_scalar(vec![], Nullability::Nullable)
        );
    }

    /// Nullable indices turn a non-nullable list into a nullable result.
    #[test]
    fn change_validity() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let positions = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
        let taken = source.take(positions.to_array()).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Non-nullable indices into a non-nullable list yield the selected lists in index order.
    #[test]
    fn non_nullable_take() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let positions = buffer![1, 0, 2].into_array();

        let taken = source.take(positions.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        // Slot 0 <- list 1.
        assert!(taken.is_valid(0).unwrap());
        assert_eq!(
            taken.scalar_at(0).unwrap(),
            i32_list_scalar(vec![3i32.into()], Nullability::NonNullable)
        );

        // Slot 1 <- list 0.
        assert!(taken.is_valid(1).unwrap());
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            i32_list_scalar(vec![0i32.into(), 5.into()], Nullability::NonNullable)
        );

        // Slot 2 <- list 2, which is empty.
        assert!(taken.is_valid(2).unwrap());
        assert_eq!(
            taken.scalar_at(2).unwrap(),
            i32_list_scalar(vec![], Nullability::NonNullable)
        );
    }

    /// Taking with an empty nullable index array from a list array with zero lists.
    #[test]
    fn test_take_empty_array() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let positions = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let taken = source.take(positions.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0,);
    }

    /// Runs the shared take conformance suite over a variety of list arrays.
    #[rstest]
    // Non-nullable lists of varying length, including an empty one.
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Lists with a null entry.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    // Several empty lists.
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Fifty lists whose lengths cycle through 1, 2, 3, 4, 0, with `u64` offsets.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        let mut list_end = 0u64;
        for list_number in 1..=50u64 {
            list_end += list_number % 5;
            offsets.push(list_end);
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists over nullable elements.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    /// Taking a 200-element list twice from `u8` offsets: 400 accumulated elements exceed the
    /// `u8` range of the source offsets, and the take must still succeed.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let source = ListArray::try_new(
            buffer![0i32; 200].into_array(),
            buffer![0u8, 200].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let positions = buffer![0u8, 0].into_array();
        let taken = source.take(positions.to_array()).unwrap();

        assert_eq!(taken.len(), 2);

        let taken_view = taken.to_listview();
        assert_eq!(taken_view.len(), 2);
        assert!(taken_view.is_valid(0).unwrap());
        assert!(taken_view.is_valid(1).unwrap());
    }

    /// Nullable variant: a 150-element list taken twice (300 accumulated elements) from `u8`
    /// offsets, with a null index in between.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let source_validity = BoolArray::from_iter(vec![true, false]).into_array();
        let source = ListArray::try_new(
            buffer![0i32; 150].into_array(),
            buffer![0u8, 150, 150].into_array(),
            Validity::Array(source_validity),
        )
        .unwrap()
        .into_array();

        let positions =
            PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
        let taken = source.take(positions.to_array()).unwrap();

        assert_eq!(taken.len(), 3);

        let taken_view = taken.to_listview();
        assert_eq!(taken_view.len(), 3);
        assert!(taken_view.is_valid(0).unwrap());
        assert!(taken_view.is_invalid(1).unwrap());
        assert!(taken_view.is_valid(2).unwrap());
    }

    /// Regression test: taking more indices (4) than the source has lists (2) from an array
    /// with array-backed validity must produce a result of the index length.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let source_validity = BoolArray::from_iter(vec![true, true]).into_array();
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(source_validity),
        )
        .unwrap()
        .into_array();

        let positions = buffer![0u32, 1, 0, 1].into_array();

        let taken = source.take(positions.to_array()).unwrap();
        assert_eq!(taken.len(), 4);
    }
}