1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::Primitive;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::list::ListArrayExt;
16use crate::arrays::primitive::PrimitiveArrayExt;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::dtype::IntegerPType;
20use crate::dtype::Nullability;
21use crate::executor::ExecutionCtx;
22use crate::match_each_unsigned_integer_ptype;
23use crate::match_smallest_offset_type;
24
25impl TakeExecute for List {
29 #[expect(clippy::cognitive_complexity)]
35 fn take(
36 array: ArrayView<'_, List>,
37 indices: &ArrayRef,
38 ctx: &mut ExecutionCtx,
39 ) -> VortexResult<Option<ArrayRef>> {
40 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
41 let indices = indices.reinterpret_cast(indices.ptype().to_unsigned());
42 let offsets = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
43 let offsets = offsets.reinterpret_cast(offsets.ptype().to_unsigned());
44 let total_approx = array.elements().len().saturating_mul(indices.len());
46
47 match_each_unsigned_integer_ptype!(offsets.ptype(), |O| {
48 match_each_unsigned_integer_ptype!(indices.ptype(), |I| {
49 match_smallest_offset_type!(total_approx, |OutputOffsetType| {
50 _take::<I, O, OutputOffsetType>(
51 array,
52 offsets.as_view(),
53 indices.as_view(),
54 ctx,
55 )
56 .map(Some)
57 })
58 })
59 })
60 }
61}
62
63fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
64 array: ArrayView<'_, List>,
65 offsets_array: ArrayView<'_, Primitive>,
66 indices_array: ArrayView<'_, Primitive>,
67 ctx: &mut ExecutionCtx,
68) -> VortexResult<ArrayRef> {
69 let data_validity = array
70 .list_validity()
71 .execute_mask(array.as_ref().len(), ctx)?;
72 let indices_validity = indices_array
73 .validity()
74 .vortex_expect("Failed to compute validity mask")
75 .execute_mask(indices_array.as_ref().len(), ctx)?;
76
77 if !indices_validity.all_true() || !data_validity.all_true() {
78 return _take_nullable::<I, O, OutputOffsetType>(array, offsets_array, indices_array, ctx);
79 }
80
81 let offsets: &[O] = offsets_array.as_slice();
82 let indices: &[I] = indices_array.as_slice();
83
84 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
85 Nullability::NonNullable,
86 indices.len(),
87 );
88 let mut elements_to_take =
89 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
90
91 let mut current_offset = OutputOffsetType::zero();
92 new_offsets.append_zero();
93
94 for &data_idx in indices {
95 let data_idx: usize = data_idx.as_();
96
97 let start = offsets[data_idx];
98 let stop = offsets[data_idx + 1];
99
100 let additional: usize = (stop - start).as_();
106
107 elements_to_take.reserve_exact(additional);
109 for i in 0..additional {
110 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
111 }
112 current_offset +=
113 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
114 new_offsets.append_value(current_offset);
115 }
116
117 let elements_to_take = elements_to_take.finish();
118 let new_offsets = new_offsets.finish();
119
120 let new_elements = array.elements().take(elements_to_take)?;
121
122 Ok(ListArray::try_new(
123 new_elements,
124 new_offsets,
125 array.validity()?.take(indices_array.array())?,
126 )?
127 .into_array())
128}
129
130#[inline(never)]
133fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
134 array: ArrayView<'_, List>,
135 offsets_array: ArrayView<'_, Primitive>,
136 indices_array: ArrayView<'_, Primitive>,
137 ctx: &mut ExecutionCtx,
138) -> VortexResult<ArrayRef> {
139 let offsets: &[O] = offsets_array.as_slice();
140 let indices: &[I] = indices_array.as_slice();
141 let data_validity = array
142 .list_validity()
143 .execute_mask(array.as_ref().len(), ctx)?;
144 let indices_validity = indices_array
145 .validity()
146 .vortex_expect("Failed to compute validity mask")
147 .execute_mask(indices_array.as_ref().len(), ctx)?;
148
149 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
150 Nullability::NonNullable,
151 indices.len(),
152 );
153
154 let mut elements_to_take =
162 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
163
164 let mut current_offset = OutputOffsetType::zero();
165 new_offsets.append_zero();
166
167 for (idx, data_idx) in indices.iter().enumerate() {
168 if !indices_validity.value(idx) {
169 new_offsets.append_value(current_offset);
170 continue;
171 }
172
173 let data_idx: usize = data_idx.as_();
174
175 if !data_validity.value(data_idx) {
176 new_offsets.append_value(current_offset);
177 continue;
178 }
179
180 let start = offsets[data_idx];
181 let stop = offsets[data_idx + 1];
182
183 let additional: usize = (stop - start).as_();
185
186 elements_to_take.reserve_exact(additional);
187 for i in 0..additional {
188 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
189 }
190 current_offset +=
191 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
192 new_offsets.append_value(current_offset);
193 }
194
195 let elements_to_take = elements_to_take.finish();
196 let new_offsets = new_offsets.finish();
197 let new_elements = array.elements().take(elements_to_take)?;
198
199 Ok(ListArray::try_new(
200 new_elements,
201 new_offsets,
202 array.validity()?.take(indices_array.array())?,
203 )?
204 .into_array())
205}
206
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn nullable_take() {
        // Lists: [0, 5], [3], null, [].
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(
            result
                .is_invalid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );

        assert!(
            result
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(
            result
                .is_valid(3, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    #[test]
    fn change_validity() {
        // Non-nullable lists: [0, 5], [3].
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
        let result = list.take(idx).unwrap();

        // A null index makes the result nullable even though the source lists are not.
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();
        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert_eq!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );
        assert_eq!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(element_dtype, vec![3i32.into()], Nullability::Nullable)
        );
        assert!(
            result
                .is_invalid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
    }

    #[test]
    fn take_null_list_rows() {
        // Lists: [1, 2], null (still spanning element 3), [4].
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        // All indices are valid, so only the referenced list's validity selects the null path.
        let idx = buffer![1u32, 2, 0].into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();
        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_invalid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![4i32.into()],
                Nullability::Nullable
            )
        );
        assert_eq!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                element_dtype,
                vec![1i32.into(), 2.into()],
                Nullability::Nullable
            )
        );
    }

    #[test]
    fn non_nullable_take() {
        // Lists: [0, 5], [3], [], [4].
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(
            result
                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(
            result
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_empty_array() {
        // Zero lists (a single offset); the nullable index type makes the result nullable.
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let result = list.take(idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0,);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // Lists 0 and 2 are empty.
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            // List lengths cycle through 1, 2, 3, 4, 0.
            offsets.push(offsets[i - 1] + (i as u64 % 5));
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        // Nullable elements inside non-nullable lists.
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        // Input offsets are u8, but taking the 200-element list twice yields an end offset of
        // 400, which does not fit in u8; the output offset type must be wider.
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        let idx = buffer![0u8, 0].into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 2);

        #[expect(deprecated)]
        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 2);
        assert!(
            result_view
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
    }

    #[test]
    fn test_u64_offset_accumulation_nullable() {
        // Same as above through the nullable path: two copies of a 150-element list need an end
        // offset of 300, which does not fit in the u8 input offsets.
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        let validity = BoolArray::from_iter(vec![true, false]).into_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .into_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 3);

        #[expect(deprecated)]
        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 3);
        assert!(
            result_view
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_invalid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
    }

    /// Regression test: take more indices (4) than there are list rows (2) from a list backed by
    /// an explicit validity array; the result must have one row per index.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let list = ListArray::try_new(
            // Lists: [1, 2], [3, 4], both valid.
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx = buffer![0u32, 1, 0, 1].into_array();

        let result = list.take(idx).unwrap();
        assert_eq!(result.len(), 4);
    }
}