1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::Array;
8use crate::ArrayRef;
9use crate::arrays::ListArray;
10use crate::arrays::ListVTable;
11use crate::arrays::PrimitiveArray;
12use crate::arrays::TakeExecute;
13use crate::builders::ArrayBuilder;
14use crate::builders::PrimitiveBuilder;
15use crate::dtype::IntegerPType;
16use crate::dtype::Nullability;
17use crate::executor::ExecutionCtx;
18use crate::match_each_integer_ptype;
19use crate::match_smallest_offset_type;
20use crate::vtable::ValidityHelper;
21
22impl TakeExecute for ListVTable {
26 #[expect(clippy::cognitive_complexity)]
32 fn take(
33 array: &ListArray,
34 indices: &ArrayRef,
35 ctx: &mut ExecutionCtx,
36 ) -> VortexResult<Option<ArrayRef>> {
37 let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;
38 let total_approx = array.elements().len().saturating_mul(indices.len());
40
41 match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
42 match_each_integer_ptype!(indices.ptype(), |I| {
43 match_smallest_offset_type!(total_approx, |OutputOffsetType| {
44 _take::<I, O, OutputOffsetType>(array, &indices, ctx).map(Some)
45 })
46 })
47 })
48 }
49}
50
51fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
52 array: &ListArray,
53 indices_array: &PrimitiveArray,
54 ctx: &mut ExecutionCtx,
55) -> VortexResult<ArrayRef> {
56 let data_validity = array.validity_mask()?;
57 let indices_validity = indices_array.validity_mask()?;
58
59 if !indices_validity.all_true() || !data_validity.all_true() {
60 return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
61 }
62
63 let offsets_array = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
64 let offsets: &[O] = offsets_array.as_slice();
65 let indices: &[I] = indices_array.as_slice();
66
67 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
68 Nullability::NonNullable,
69 indices.len(),
70 );
71 let mut elements_to_take =
72 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
73
74 let mut current_offset = OutputOffsetType::zero();
75 new_offsets.append_zero();
76
77 for &data_idx in indices {
78 let data_idx: usize = data_idx.as_();
79
80 let start = offsets[data_idx];
81 let stop = offsets[data_idx + 1];
82
83 let additional: usize = (stop - start).as_();
89
90 elements_to_take.reserve_exact(additional);
92 for i in 0..additional {
93 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
94 }
95 current_offset +=
96 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
97 new_offsets.append_value(current_offset);
98 }
99
100 let elements_to_take = elements_to_take.finish();
101 let new_offsets = new_offsets.finish();
102
103 let new_elements = array.elements().take(elements_to_take.to_array())?;
104
105 Ok(ListArray::try_new(
106 new_elements,
107 new_offsets,
108 array.validity().clone().take(&indices_array.to_array())?,
109 )?
110 .to_array())
111}
112
113fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
114 array: &ListArray,
115 indices_array: &PrimitiveArray,
116 ctx: &mut ExecutionCtx,
117) -> VortexResult<ArrayRef> {
118 let offsets_array = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
119 let offsets: &[O] = offsets_array.as_slice();
120 let indices: &[I] = indices_array.as_slice();
121 let data_validity = array.validity_mask()?;
122 let indices_validity = indices_array.validity_mask()?;
123
124 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
125 Nullability::NonNullable,
126 indices.len(),
127 );
128
129 let mut elements_to_take =
137 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
138
139 let mut current_offset = OutputOffsetType::zero();
140 new_offsets.append_zero();
141
142 for (idx, data_idx) in indices.iter().enumerate() {
143 if !indices_validity.value(idx) {
144 new_offsets.append_value(current_offset);
145 continue;
146 }
147
148 let data_idx: usize = data_idx.as_();
149
150 if !data_validity.value(data_idx) {
151 new_offsets.append_value(current_offset);
152 continue;
153 }
154
155 let start = offsets[data_idx];
156 let stop = offsets[data_idx + 1];
157
158 let additional: usize = (stop - start).as_();
160
161 elements_to_take.reserve_exact(additional);
162 for i in 0..additional {
163 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
164 }
165 current_offset +=
166 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
167 new_offsets.append_value(current_offset);
168 }
169
170 let elements_to_take = elements_to_take.finish();
171 let new_offsets = new_offsets.finish();
172 let new_elements = array.elements().take(elements_to_take.to_array())?;
173
174 Ok(ListArray::try_new(
175 new_elements,
176 new_offsets,
177 array.validity().clone().take(&indices_array.to_array())?,
178 )?
179 .to_array())
180}
181
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Dtype of a list of non-nullable `i32` whose list-level nullability is `nullability`.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// Taking with a null index from a list array that itself has a null entry.
    #[test]
    fn nullable_take() {
        let list_validity = BoolArray::from_iter(vec![true, true, false, true]).to_array();
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(list_validity),
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = list.take(idx.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Index 0 -> first list, [0, 5].
        assert!(taken.is_valid(0).unwrap());
        let expected_first = Scalar::list(
            element_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::Nullable,
        );
        assert_eq!(taken.scalar_at(0).unwrap(), expected_first);

        // Null index -> null list.
        assert!(taken.is_invalid(1).unwrap());

        // Index 1 -> second list, [3].
        assert!(taken.is_valid(2).unwrap());
        let expected_third = Scalar::list(
            element_dtype.clone(),
            vec![3i32.into()],
            Nullability::Nullable,
        );
        assert_eq!(taken.scalar_at(2).unwrap(), expected_third);

        // Index 3 -> fourth list, which is valid but empty (offsets 4..4).
        assert!(taken.is_valid(3).unwrap());
        let expected_fourth = Scalar::list(element_dtype, vec![], Nullability::Nullable);
        assert_eq!(taken.scalar_at(3).unwrap(), expected_fourth);
    }

    /// Nullable indices turn a non-nullable list array into a nullable result.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        let taken = list.take(idx.to_array()).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Taking with non-nullable indices from a non-nullable list array.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = buffer![1, 0, 2].into_array();

        let taken = list.take(idx.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Index 1 -> second list, [3].
        assert!(taken.is_valid(0).unwrap());
        let expected_first = Scalar::list(
            element_dtype.clone(),
            vec![3i32.into()],
            Nullability::NonNullable,
        );
        assert_eq!(taken.scalar_at(0).unwrap(), expected_first);

        // Index 0 -> first list, [0, 5].
        assert!(taken.is_valid(1).unwrap());
        let expected_second = Scalar::list(
            element_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::NonNullable,
        );
        assert_eq!(taken.scalar_at(1).unwrap(), expected_second);

        // Index 2 -> third list, which is empty (offsets 3..3).
        assert!(taken.is_valid(2).unwrap());
        let expected_third = Scalar::list(element_dtype, vec![], Nullability::NonNullable);
        assert_eq!(taken.scalar_at(2).unwrap(), expected_third);
    }

    /// Taking zero indices from a list array that contains zero lists.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = list.take(idx.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0,);
    }

    /// Runs the shared take conformance suite over a variety of list arrays.
    #[rstest]
    // Non-nullable lists of mixed lengths, including an empty one.
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // List array with a null entry.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    // First and third lists are empty.
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Fifty lists with u64 offsets whose lengths cycle through 1, 2, 3, 4, 0.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut running = 0u64;
        let mut offsets = vec![running];
        for i in 1..=50u64 {
            running += i % 5;
            offsets.push(running);
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists over nullable elements.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.to_array());
    }

    /// A 200-element list with `u8` offsets is taken twice, so the output holds 400 elements
    /// and its final offset no longer fits in `u8`.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        let idx = buffer![0u8, 0].into_array();
        let taken = list.take(idx.to_array()).unwrap();

        assert_eq!(taken.len(), 2);

        let taken_view = taken.to_listview();
        assert_eq!(taken_view.len(), 2);
        assert!(taken_view.is_valid(0).unwrap());
        assert!(taken_view.is_valid(1).unwrap());
    }

    /// Nullable variant: a 150-element list with `u8` offsets is taken twice around a null
    /// index, so the output holds 300 elements and its final offset no longer fits in `u8`.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        let validity = BoolArray::from_iter(vec![true, false]).to_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
        let taken = list.take(idx.to_array()).unwrap();

        assert_eq!(taken.len(), 3);

        let taken_view = taken.to_listview();
        assert_eq!(taken_view.len(), 3);
        assert!(taken_view.is_valid(0).unwrap());
        assert!(taken_view.is_invalid(1).unwrap());
        assert!(taken_view.is_valid(2).unwrap());
    }

    /// Takes four indices from a two-list array that carries array-backed validity, so the
    /// result is longer than the source; the result must have one entry per index.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let idx = buffer![0u32, 1, 0, 1].into_array();

        let taken = list.take(idx.to_array()).unwrap();
        assert_eq!(taken.len(), 4);
    }
}