1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::Primitive;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::list::ListArrayExt;
16use crate::arrays::primitive::PrimitiveArrayExt;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::dtype::IntegerPType;
20use crate::dtype::Nullability;
21use crate::executor::ExecutionCtx;
22use crate::match_each_integer_ptype;
23use crate::match_smallest_offset_type;
24
25impl TakeExecute for List {
29 #[expect(clippy::cognitive_complexity)]
35 fn take(
36 array: ArrayView<'_, List>,
37 indices: &ArrayRef,
38 ctx: &mut ExecutionCtx,
39 ) -> VortexResult<Option<ArrayRef>> {
40 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
41 let total_approx = array.elements().len().saturating_mul(indices.len());
43
44 match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
45 match_each_integer_ptype!(indices.ptype(), |I| {
46 match_smallest_offset_type!(total_approx, |OutputOffsetType| {
47 {
48 let indices = indices.as_view();
49 _take::<I, O, OutputOffsetType>(array, indices, ctx).map(Some)
50 }
51 })
52 })
53 })
54 }
55}
56
57fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
58 array: ArrayView<'_, List>,
59 indices_array: ArrayView<'_, Primitive>,
60 ctx: &mut ExecutionCtx,
61) -> VortexResult<ArrayRef> {
62 let data_validity = array.list_validity().to_mask(array.as_ref().len(), ctx)?;
63 let indices_validity = indices_array
64 .validity()
65 .vortex_expect("Failed to compute validity mask")
66 .to_mask(indices_array.as_ref().len(), ctx)?;
67
68 if !indices_validity.all_true() || !data_validity.all_true() {
69 return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
70 }
71
72 let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
73 let offsets: &[O] = offsets_array.as_slice();
74 let indices: &[I] = indices_array.as_slice();
75
76 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
77 Nullability::NonNullable,
78 indices.len(),
79 );
80 let mut elements_to_take =
81 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
82
83 let mut current_offset = OutputOffsetType::zero();
84 new_offsets.append_zero();
85
86 for &data_idx in indices {
87 let data_idx: usize = data_idx.as_();
88
89 let start = offsets[data_idx];
90 let stop = offsets[data_idx + 1];
91
92 let additional: usize = (stop - start).as_();
98
99 elements_to_take.reserve_exact(additional);
101 for i in 0..additional {
102 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
103 }
104 current_offset +=
105 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
106 new_offsets.append_value(current_offset);
107 }
108
109 let elements_to_take = elements_to_take.finish();
110 let new_offsets = new_offsets.finish();
111
112 let new_elements = array.elements().take(elements_to_take)?;
113
114 Ok(ListArray::try_new(
115 new_elements,
116 new_offsets,
117 array.validity()?.take(indices_array.array())?,
118 )?
119 .into_array())
120}
121
122fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
123 array: ArrayView<'_, List>,
124 indices_array: ArrayView<'_, Primitive>,
125 ctx: &mut ExecutionCtx,
126) -> VortexResult<ArrayRef> {
127 let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
128 let offsets: &[O] = offsets_array.as_slice();
129 let indices: &[I] = indices_array.as_slice();
130 let data_validity = array.list_validity().to_mask(array.as_ref().len(), ctx)?;
131 let indices_validity = indices_array
132 .validity()
133 .vortex_expect("Failed to compute validity mask")
134 .to_mask(indices_array.as_ref().len(), ctx)?;
135
136 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
137 Nullability::NonNullable,
138 indices.len(),
139 );
140
141 let mut elements_to_take =
149 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
150
151 let mut current_offset = OutputOffsetType::zero();
152 new_offsets.append_zero();
153
154 for (idx, data_idx) in indices.iter().enumerate() {
155 if !indices_validity.value(idx) {
156 new_offsets.append_value(current_offset);
157 continue;
158 }
159
160 let data_idx: usize = data_idx.as_();
161
162 if !data_validity.value(data_idx) {
163 new_offsets.append_value(current_offset);
164 continue;
165 }
166
167 let start = offsets[data_idx];
168 let stop = offsets[data_idx + 1];
169
170 let additional: usize = (stop - start).as_();
172
173 elements_to_take.reserve_exact(additional);
174 for i in 0..additional {
175 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
176 }
177 current_offset +=
178 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
179 new_offsets.append_value(current_offset);
180 }
181
182 let elements_to_take = elements_to_take.finish();
183 let new_offsets = new_offsets.finish();
184 let new_elements = array.elements().take(elements_to_take)?;
185
186 Ok(ListArray::try_new(
187 new_elements,
188 new_offsets,
189 array.validity()?.take(indices_array.array())?,
190 )?
191 .into_array())
192}
193
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Dtype used by the dtype assertions below: a list of non-nullable `i32` elements whose
    /// own nullability is `list_nullability`.
    fn i32_list_dtype(list_nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            list_nullability,
        )
    }

    /// Taking with nullable indices from a nullable list array yields the selected lists,
    /// with a null wherever the index is null.
    #[test]
    fn nullable_take() {
        // Source lists: [0, 5], [3], null, [] (empty but valid).
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let picks =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let taken = source.take(picks).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());
        // Every check runs against a freshly created execution context.
        let ctx = || LEGACY_SESSION.create_execution_ctx();

        // `None` marks a row expected to be null; `Some` holds the expected list values.
        let expected_rows: [Option<Vec<i32>>; 4] =
            [Some(vec![0, 5]), None, Some(vec![3]), Some(vec![])];

        for (row, expected) in expected_rows.into_iter().enumerate() {
            match expected {
                None => {
                    assert!(taken.is_invalid(row, &mut ctx()).unwrap());
                }
                Some(values) => {
                    assert!(taken.is_valid(row, &mut ctx()).unwrap());
                    assert_eq!(
                        taken.execute_scalar(row, &mut ctx()).unwrap(),
                        Scalar::list(
                            Arc::clone(&element_dtype),
                            values.into_iter().map(Into::into).collect(),
                            Nullability::Nullable
                        )
                    );
                }
            }
        }
    }

    /// Nullable indices turn a non-nullable list array into a nullable result.
    #[test]
    fn change_validity() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let picks = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
        let taken = source.take(picks).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Taking with non-nullable indices from a non-nullable list array reorders the lists
    /// and keeps the result non-nullable.
    #[test]
    fn non_nullable_take() {
        // Source lists: [0, 5], [3], [], [4].
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let picks = buffer![1, 0, 2].into_array();

        let taken = source.take(picks).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());
        // Every check runs against a freshly created execution context.
        let ctx = || LEGACY_SESSION.create_execution_ctx();

        let expected_rows: [Vec<i32>; 3] = [vec![3], vec![0, 5], vec![]];

        for (row, values) in expected_rows.into_iter().enumerate() {
            assert!(taken.is_valid(row, &mut ctx()).unwrap());
            assert_eq!(
                taken.execute_scalar(row, &mut ctx()).unwrap(),
                Scalar::list(
                    Arc::clone(&element_dtype),
                    values.into_iter().map(Into::into).collect(),
                    Nullability::NonNullable
                )
            );
        }
    }

    /// Taking zero (nullable) indices from a list array with no lists gives an empty,
    /// nullable result.
    #[test]
    fn test_take_empty_array() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let picks = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let taken = source.take(picks).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    #[rstest]
    // Non-nullable lists of mixed lengths, one of them empty.
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Nullable lists; the second list is null.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    // The first and third lists are empty.
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Fifty lists with `u64` offsets; list `i` (1-based) has `i % 5` elements.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let offsets: Vec<u64> = std::iter::once(0u64)
            .chain((1..=50u64).scan(0u64, |running, i| {
                *running += i % 5;
                Some(*running)
            }))
            .collect();
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists over nullable elements.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    /// Taking the same 200-element list twice accumulates to 400 elements, which is more
    /// than the `u8` offsets of the source could express.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let source = ListArray::try_new(
            buffer![0i32; 200].into_array(),
            buffer![0u8, 200].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let picks = buffer![0u8, 0].into_array();
        let taken = source.take(picks).unwrap();

        assert_eq!(taken.len(), 2);

        let view = taken.to_listview();
        assert_eq!(view.len(), 2);

        let ctx = || LEGACY_SESSION.create_execution_ctx();
        for row in 0..2 {
            assert!(view.is_valid(row, &mut ctx()).unwrap());
        }
    }

    /// Same accumulation scenario on the nullable path: 150 elements taken twice (300 in
    /// total) with a null index in between.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let validity = BoolArray::from_iter(vec![true, false]).into_array();
        let source = ListArray::try_new(
            buffer![0i32; 150].into_array(),
            buffer![0u8, 150, 150].into_array(),
            Validity::Array(validity),
        )
        .unwrap()
        .into_array();

        let picks = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
        let taken = source.take(picks).unwrap();

        assert_eq!(taken.len(), 3);

        let view = taken.to_listview();
        assert_eq!(view.len(), 3);

        let ctx = || LEGACY_SESSION.create_execution_ctx();
        for (row, expect_valid) in [true, false, true].into_iter().enumerate() {
            if expect_valid {
                assert!(view.is_valid(row, &mut ctx()).unwrap());
            } else {
                assert!(view.is_invalid(row, &mut ctx()).unwrap());
            }
        }
    }

    /// Taking more indices (4) than the source has lists (2) must succeed and produce one
    /// output list per index.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let picks = buffer![0u32, 1, 0, 1].into_array();

        let taken = source.take(picks).unwrap();
        assert_eq!(taken.len(), 4);
    }
}