1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::Primitive;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::list::ListArrayExt;
16use crate::arrays::primitive::PrimitiveArrayExt;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::dtype::IntegerPType;
20use crate::dtype::Nullability;
21use crate::executor::ExecutionCtx;
22use crate::match_each_integer_ptype;
23use crate::match_smallest_offset_type;
24
25impl TakeExecute for List {
29 #[expect(clippy::cognitive_complexity)]
35 fn take(
36 array: ArrayView<'_, List>,
37 indices: &ArrayRef,
38 ctx: &mut ExecutionCtx,
39 ) -> VortexResult<Option<ArrayRef>> {
40 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
41 let total_approx = array.elements().len().saturating_mul(indices.len());
43
44 match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
45 match_each_integer_ptype!(indices.ptype(), |I| {
46 match_smallest_offset_type!(total_approx, |OutputOffsetType| {
47 {
48 let indices = indices.as_view();
49 _take::<I, O, OutputOffsetType>(array, indices, ctx).map(Some)
50 }
51 })
52 })
53 })
54 }
55}
56
57fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
58 array: ArrayView<'_, List>,
59 indices_array: ArrayView<'_, Primitive>,
60 ctx: &mut ExecutionCtx,
61) -> VortexResult<ArrayRef> {
62 let data_validity = array
63 .list_validity()
64 .execute_mask(array.as_ref().len(), ctx)?;
65 let indices_validity = indices_array
66 .validity()
67 .vortex_expect("Failed to compute validity mask")
68 .execute_mask(indices_array.as_ref().len(), ctx)?;
69
70 if !indices_validity.all_true() || !data_validity.all_true() {
71 return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
72 }
73
74 let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
75 let offsets: &[O] = offsets_array.as_slice();
76 let indices: &[I] = indices_array.as_slice();
77
78 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
79 Nullability::NonNullable,
80 indices.len(),
81 );
82 let mut elements_to_take =
83 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
84
85 let mut current_offset = OutputOffsetType::zero();
86 new_offsets.append_zero();
87
88 for &data_idx in indices {
89 let data_idx: usize = data_idx.as_();
90
91 let start = offsets[data_idx];
92 let stop = offsets[data_idx + 1];
93
94 let additional: usize = (stop - start).as_();
100
101 elements_to_take.reserve_exact(additional);
103 for i in 0..additional {
104 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
105 }
106 current_offset +=
107 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
108 new_offsets.append_value(current_offset);
109 }
110
111 let elements_to_take = elements_to_take.finish();
112 let new_offsets = new_offsets.finish();
113
114 let new_elements = array.elements().take(elements_to_take)?;
115
116 Ok(ListArray::try_new(
117 new_elements,
118 new_offsets,
119 array.validity()?.take(indices_array.array())?,
120 )?
121 .into_array())
122}
123
124fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
125 array: ArrayView<'_, List>,
126 indices_array: ArrayView<'_, Primitive>,
127 ctx: &mut ExecutionCtx,
128) -> VortexResult<ArrayRef> {
129 let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
130 let offsets: &[O] = offsets_array.as_slice();
131 let indices: &[I] = indices_array.as_slice();
132 let data_validity = array
133 .list_validity()
134 .execute_mask(array.as_ref().len(), ctx)?;
135 let indices_validity = indices_array
136 .validity()
137 .vortex_expect("Failed to compute validity mask")
138 .execute_mask(indices_array.as_ref().len(), ctx)?;
139
140 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
141 Nullability::NonNullable,
142 indices.len(),
143 );
144
145 let mut elements_to_take =
153 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
154
155 let mut current_offset = OutputOffsetType::zero();
156 new_offsets.append_zero();
157
158 for (idx, data_idx) in indices.iter().enumerate() {
159 if !indices_validity.value(idx) {
160 new_offsets.append_value(current_offset);
161 continue;
162 }
163
164 let data_idx: usize = data_idx.as_();
165
166 if !data_validity.value(data_idx) {
167 new_offsets.append_value(current_offset);
168 continue;
169 }
170
171 let start = offsets[data_idx];
172 let stop = offsets[data_idx + 1];
173
174 let additional: usize = (stop - start).as_();
176
177 elements_to_take.reserve_exact(additional);
178 for i in 0..additional {
179 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
180 }
181 current_offset +=
182 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
183 new_offsets.append_value(current_offset);
184 }
185
186 let elements_to_take = elements_to_take.finish();
187 let new_offsets = new_offsets.finish();
188 let new_elements = array.elements().take(elements_to_take)?;
189
190 Ok(ListArray::try_new(
191 new_elements,
192 new_offsets,
193 array.validity()?.take(indices_array.array())?,
194 )?
195 .into_array())
196}
197
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(
            result
                .is_invalid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );

        assert!(
            result
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(
            result
                .is_valid(3, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// Taking from a non-nullable list with nullable indices must produce a nullable result
    /// where the null index becomes a null row and the other rows keep their contents.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 3);

        #[expect(deprecated)]
        let result = result.to_listview();

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert_eq!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );
        assert_eq!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(element_dtype, vec![3i32.into()], Nullability::Nullable)
        );
        assert!(
            result
                .is_invalid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
    }

    /// A null list can still span elements in the backing buffer. Taking it must yield a null
    /// row, and its elements must not leak into the neighbouring rows.
    #[test]
    fn take_null_list_with_backing_elements() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5].into_array(),
            buffer![0, 2, 4, 5].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx = buffer![1u32, 2, 0].into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 3);

        #[expect(deprecated)]
        let result = result.to_listview();

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_invalid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![5i32.into()],
                Nullability::Nullable
            )
        );
        assert_eq!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                element_dtype,
                vec![1i32.into(), 2.into()],
                Nullability::Nullable
            )
        );
    }

    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(
            result
                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(
            result
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let result = list.take(idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0,);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5));
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        let idx = buffer![0u8, 0].into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 2);

        #[expect(deprecated)]
        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 2);
        assert!(
            result_view
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
    }

    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        let validity = BoolArray::from_iter(vec![true, false]).into_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .into_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 3);

        #[expect(deprecated)]
        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 3);
        assert!(
            result_view
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_invalid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
    }

    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx = buffer![0u32, 1, 0, 1].into_array();

        let result = list.take(idx).unwrap();
        assert_eq!(result.len(), 4);
    }
}