1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::Primitive;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::list::ListArrayExt;
16use crate::arrays::primitive::PrimitiveArrayExt;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::dtype::IntegerPType;
20use crate::dtype::Nullability;
21use crate::executor::ExecutionCtx;
22use crate::match_each_integer_ptype;
23use crate::match_smallest_offset_type;
24
25impl TakeExecute for List {
29 #[expect(clippy::cognitive_complexity)]
35 fn take(
36 array: ArrayView<'_, List>,
37 indices: &ArrayRef,
38 ctx: &mut ExecutionCtx,
39 ) -> VortexResult<Option<ArrayRef>> {
40 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
41 let total_approx = array.elements().len().saturating_mul(indices.len());
43
44 match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
45 match_each_integer_ptype!(indices.ptype(), |I| {
46 match_smallest_offset_type!(total_approx, |OutputOffsetType| {
47 {
48 let indices = indices.as_view();
49 _take::<I, O, OutputOffsetType>(array, indices, ctx).map(Some)
50 }
51 })
52 })
53 })
54 }
55}
56
57fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
58 array: ArrayView<'_, List>,
59 indices_array: ArrayView<'_, Primitive>,
60 ctx: &mut ExecutionCtx,
61) -> VortexResult<ArrayRef> {
62 let data_validity = array.list_validity_mask();
63 let indices_validity = indices_array.validity_mask();
64
65 if !indices_validity.all_true() || !data_validity.all_true() {
66 return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
67 }
68
69 let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
70 let offsets: &[O] = offsets_array.as_slice();
71 let indices: &[I] = indices_array.as_slice();
72
73 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
74 Nullability::NonNullable,
75 indices.len(),
76 );
77 let mut elements_to_take =
78 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
79
80 let mut current_offset = OutputOffsetType::zero();
81 new_offsets.append_zero();
82
83 for &data_idx in indices {
84 let data_idx: usize = data_idx.as_();
85
86 let start = offsets[data_idx];
87 let stop = offsets[data_idx + 1];
88
89 let additional: usize = (stop - start).as_();
95
96 elements_to_take.reserve_exact(additional);
98 for i in 0..additional {
99 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
100 }
101 current_offset +=
102 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
103 new_offsets.append_value(current_offset);
104 }
105
106 let elements_to_take = elements_to_take.finish();
107 let new_offsets = new_offsets.finish();
108
109 let new_elements = array.elements().take(elements_to_take)?;
110
111 Ok(ListArray::try_new(
112 new_elements,
113 new_offsets,
114 array.validity()?.take(indices_array.array())?,
115 )?
116 .into_array())
117}
118
119fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
120 array: ArrayView<'_, List>,
121 indices_array: ArrayView<'_, Primitive>,
122 ctx: &mut ExecutionCtx,
123) -> VortexResult<ArrayRef> {
124 let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
125 let offsets: &[O] = offsets_array.as_slice();
126 let indices: &[I] = indices_array.as_slice();
127 let data_validity = array.list_validity_mask();
128 let indices_validity = indices_array.validity_mask();
129
130 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
131 Nullability::NonNullable,
132 indices.len(),
133 );
134
135 let mut elements_to_take =
143 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
144
145 let mut current_offset = OutputOffsetType::zero();
146 new_offsets.append_zero();
147
148 for (idx, data_idx) in indices.iter().enumerate() {
149 if !indices_validity.value(idx) {
150 new_offsets.append_value(current_offset);
151 continue;
152 }
153
154 let data_idx: usize = data_idx.as_();
155
156 if !data_validity.value(data_idx) {
157 new_offsets.append_value(current_offset);
158 continue;
159 }
160
161 let start = offsets[data_idx];
162 let stop = offsets[data_idx + 1];
163
164 let additional: usize = (stop - start).as_();
166
167 elements_to_take.reserve_exact(additional);
168 for i in 0..additional {
169 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
170 }
171 current_offset +=
172 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
173 new_offsets.append_value(current_offset);
174 }
175
176 let elements_to_take = elements_to_take.finish();
177 let new_offsets = new_offsets.finish();
178 let new_elements = array.elements().take(elements_to_take)?;
179
180 Ok(ListArray::try_new(
181 new_elements,
182 new_offsets,
183 array.validity()?.take(indices_array.array())?,
184 )?
185 .into_array())
186}
187
/// Unit tests for taking from `ListArray`s through the public `take` entry point.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Takes with nullable indices from a list array that itself contains a null list.
    #[test]
    fn nullable_take() {
        // Four lists: [0, 5], [3], null (its slot spans [4]) and [].
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let result = list.take(idx).unwrap();

        // The lists are nullable, the elements are not.
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Index 0 -> [0, 5].
        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        // A null index yields a null list.
        assert!(result.is_invalid(1).unwrap());

        // Index 1 -> [3].
        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        // Index 3 -> the valid, empty list.
        assert!(result.is_valid(3).unwrap());
        assert_eq!(
            result.scalar_at(3).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// Taking a non-nullable list array with nullable indices must produce a nullable result.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
        let result = list.take(idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
    }

    /// Takes with non-nullable indices from a non-nullable list array, out of order.
    #[test]
    fn non_nullable_take() {
        // Four lists: [0, 5], [3], [] and [4].
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Index 1 -> [3].
        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        // Index 0 -> [0, 5].
        assert!(result.is_valid(1).unwrap());
        assert_eq!(
            result.scalar_at(1).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        // Index 2 -> the empty list.
        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// Taking zero indices from a list array with zero lists yields an empty result whose
    /// nullability follows the (nullable) indices.
    #[test]
    fn test_take_empty_array() {
        // A single offset means the array contains no lists, even though elements exist.
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let result = list.take(idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0,);
    }

    /// Runs the shared take conformance suite against a variety of list layouts.
    #[rstest]
    // Five non-nullable lists of lengths 2, 1, 2, 0 and 1.
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Four lists where the second one is null.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    // Four lists where the first and third are empty.
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // A single list with two elements.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Fifty lists with `u64` offsets whose lengths cycle through 1, 2, 3, 4 and 0.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5));
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists over nullable elements.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    /// The source offsets are `u8`, but taking the 200-element list twice produces a running
    /// offset of 400, which no longer fits in `u8`. The take must still succeed.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Select the only list twice.
        let idx = buffer![0u8, 0].into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 2);

        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 2);
        assert!(result_view.is_valid(0).unwrap());
        assert!(result_view.is_valid(1).unwrap());
    }

    /// Same as the non-nullable variant, but through nullable indices and a nullable source:
    /// taking the 150-element list twice produces a running offset of 300, beyond `u8`.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        // Second list is null (and empty).
        let validity = BoolArray::from_iter(vec![true, false]).into_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .into_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 3);

        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 3);
        assert!(result_view.is_valid(0).unwrap());
        assert!(result_view.is_invalid(1).unwrap());
        assert!(result_view.is_valid(2).unwrap());
    }

    /// Regression test: the source has two lists with an explicit validity array, while four
    /// indices are taken. The result length must follow the number of indices, not the
    /// length of the source.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();

        // More indices than there are lists in the source.
        let idx = buffer![0u32, 1, 0, 1].into_array();

        let result = list.take(idx).unwrap();
        assert_eq!(result.len(), 4);
    }
}