1use vortex_dtype::IntegerPType;
5use vortex_dtype::Nullability;
6use vortex_dtype::match_each_integer_ptype;
7use vortex_dtype::match_smallest_offset_type;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::ToCanonical;
14use crate::arrays::ListArray;
15use crate::arrays::ListVTable;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::TakeExecute;
18use crate::builders::ArrayBuilder;
19use crate::builders::PrimitiveBuilder;
20use crate::executor::ExecutionCtx;
21use crate::vtable::ValidityHelper;
22
23impl TakeExecute for ListVTable {
27 #[expect(clippy::cognitive_complexity)]
33 fn take(
34 array: &ListArray,
35 indices: &dyn Array,
36 _ctx: &mut ExecutionCtx,
37 ) -> VortexResult<Option<ArrayRef>> {
38 let indices = indices.to_primitive();
39 let total_approx = array.elements().len().saturating_mul(indices.len());
41
42 match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
43 match_each_integer_ptype!(indices.ptype(), |I| {
44 match_smallest_offset_type!(total_approx, |OutputOffsetType| {
45 _take::<I, O, OutputOffsetType>(array, &indices).map(Some)
46 })
47 })
48 })
49 }
50}
51
52fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
53 array: &ListArray,
54 indices_array: &PrimitiveArray,
55) -> VortexResult<ArrayRef> {
56 let data_validity = array.validity_mask()?;
57 let indices_validity = indices_array.validity_mask()?;
58
59 if !indices_validity.all_true() || !data_validity.all_true() {
60 return _take_nullable::<I, O, OutputOffsetType>(array, indices_array);
61 }
62
63 let offsets_array = array.offsets().to_primitive();
64 let offsets: &[O] = offsets_array.as_slice();
65 let indices: &[I] = indices_array.as_slice();
66
67 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
68 Nullability::NonNullable,
69 indices.len(),
70 );
71 let mut elements_to_take =
72 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
73
74 let mut current_offset = OutputOffsetType::zero();
75 new_offsets.append_zero();
76
77 for &data_idx in indices {
78 let data_idx: usize = data_idx.as_();
79
80 let start = offsets[data_idx];
81 let stop = offsets[data_idx + 1];
82
83 let additional: usize = (stop - start).as_();
89
90 elements_to_take.reserve_exact(additional);
92 for i in 0..additional {
93 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
94 }
95 current_offset +=
96 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
97 new_offsets.append_value(current_offset);
98 }
99
100 let elements_to_take = elements_to_take.finish();
101 let new_offsets = new_offsets.finish();
102
103 let new_elements = array.elements().take(elements_to_take.to_array())?;
104
105 Ok(ListArray::try_new(
106 new_elements,
107 new_offsets,
108 array.validity().clone().take(indices_array.as_ref())?,
109 )?
110 .to_array())
111}
112
113fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
114 array: &ListArray,
115 indices_array: &PrimitiveArray,
116) -> VortexResult<ArrayRef> {
117 let offsets_array = array.offsets().to_primitive();
118 let offsets: &[O] = offsets_array.as_slice();
119 let indices: &[I] = indices_array.as_slice();
120 let data_validity = array.validity_mask()?;
121 let indices_validity = indices_array.validity_mask()?;
122
123 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
124 Nullability::NonNullable,
125 indices.len(),
126 );
127
128 let mut elements_to_take =
136 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
137
138 let mut current_offset = OutputOffsetType::zero();
139 new_offsets.append_zero();
140
141 for (idx, data_idx) in indices.iter().enumerate() {
142 if !indices_validity.value(idx) {
143 new_offsets.append_value(current_offset);
144 continue;
145 }
146
147 let data_idx: usize = data_idx.as_();
148
149 if !data_validity.value(data_idx) {
150 new_offsets.append_value(current_offset);
151 continue;
152 }
153
154 let start = offsets[data_idx];
155 let stop = offsets[data_idx + 1];
156
157 let additional: usize = (stop - start).as_();
159
160 elements_to_take.reserve_exact(additional);
161 for i in 0..additional {
162 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
163 }
164 current_offset +=
165 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
166 new_offsets.append_value(current_offset);
167 }
168
169 let elements_to_take = elements_to_take.finish();
170 let new_offsets = new_offsets.finish();
171 let new_elements = array.elements().take(elements_to_take.to_array())?;
172
173 Ok(ListArray::try_new(
174 new_elements,
175 new_offsets,
176 array.validity().clone().take(indices_array.as_ref())?,
177 )?
178 .to_array())
179}
180
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Builds the dtype of a list of non-nullable `i32` with the given list-level nullability.
    fn i32_list_dtype(list_nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            list_nullability,
        )
    }

    /// Taking with nullable indices from a nullable list keeps valid lists intact and marks
    /// the position of the null index as invalid.
    #[test]
    fn nullable_take() {
        let source_validity =
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array());
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            source_validity,
        )
        .unwrap()
        .to_array();

        let positions =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = source.take(positions.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Position 0 -> source list 0: [0, 5].
        let expected_first = Scalar::list(
            element_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::Nullable,
        );
        assert!(taken.is_valid(0).unwrap());
        assert_eq!(taken.scalar_at(0).unwrap(), expected_first);

        // Position 1 -> null index.
        assert!(taken.is_invalid(1).unwrap());

        // Position 2 -> source list 1: [3].
        let expected_third = Scalar::list(
            element_dtype.clone(),
            vec![3i32.into()],
            Nullability::Nullable,
        );
        assert!(taken.is_valid(2).unwrap());
        assert_eq!(taken.scalar_at(2).unwrap(), expected_third);

        // Position 3 -> source list 3, which is valid but empty.
        let expected_fourth = Scalar::list(element_dtype, vec![], Nullability::Nullable);
        assert!(taken.is_valid(3).unwrap());
        assert_eq!(taken.scalar_at(3).unwrap(), expected_fourth);
    }

    /// Taking a non-nullable list with nullable indices produces a nullable list dtype.
    #[test]
    fn change_validity() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let positions = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        let taken = source.take(positions.to_array()).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Taking a non-nullable list with non-nullable indices reorders the lists and keeps the
    /// dtype non-nullable.
    #[test]
    fn non_nullable_take() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let positions = buffer![1, 0, 2].into_array();

        let taken = source.take(positions.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Position 0 -> source list 1: [3].
        let expected_first = Scalar::list(
            element_dtype.clone(),
            vec![3i32.into()],
            Nullability::NonNullable,
        );
        assert!(taken.is_valid(0).unwrap());
        assert_eq!(taken.scalar_at(0).unwrap(), expected_first);

        // Position 1 -> source list 0: [0, 5].
        let expected_second = Scalar::list(
            element_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::NonNullable,
        );
        assert!(taken.is_valid(1).unwrap());
        assert_eq!(taken.scalar_at(1).unwrap(), expected_second);

        // Position 2 -> source list 2, which is empty.
        let expected_third = Scalar::list(element_dtype, vec![], Nullability::NonNullable);
        assert!(taken.is_valid(2).unwrap());
        assert_eq!(taken.scalar_at(2).unwrap(), expected_third);
    }

    /// Taking with an empty, nullable index array yields an empty, nullable list array.
    #[test]
    fn test_take_empty_array() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let positions = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = source.take(positions.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    /// Runs the shared take conformance suite against a variety of list layouts.
    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        // 50 lists whose lengths cycle through 1, 2, 3, 4, 0.
        let mut offsets = vec![0u64];
        let mut running_end = 0u64;
        for list_number in 1..=50u64 {
            running_end += list_number % 5;
            offsets.push(running_end);
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }

    /// Source offsets are `u8`, but taking the 200-element list twice needs an end offset of
    /// 400, which does not fit in `u8`.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let source_elements = buffer![0i32; 200].into_array();
        let source_offsets = buffer![0u8, 200].into_array();
        let source = ListArray::try_new(source_elements, source_offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        let positions = buffer![0u8, 0].into_array();
        let taken = source.take(positions.to_array()).unwrap();

        assert_eq!(taken.len(), 2);

        let taken_view = taken.to_listview();
        assert_eq!(taken_view.len(), 2);
        assert!(taken_view.is_valid(0).unwrap());
        assert!(taken_view.is_valid(1).unwrap());
    }

    /// Same as the non-nullable variant, with a null list in the source and a null index:
    /// taking the 150-element list twice needs an end offset of 300, which does not fit in `u8`.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let source_elements = buffer![0i32; 150].into_array();
        let source_offsets = buffer![0u8, 150, 150].into_array();
        let source_validity = BoolArray::from_iter(vec![true, false]).to_array();
        let source = ListArray::try_new(
            source_elements,
            source_offsets,
            Validity::Array(source_validity),
        )
        .unwrap()
        .to_array();

        let positions =
            PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
        let taken = source.take(positions.to_array()).unwrap();

        assert_eq!(taken.len(), 3);

        let taken_view = taken.to_listview();
        assert_eq!(taken_view.len(), 3);
        assert!(taken_view.is_valid(0).unwrap());
        assert!(taken_view.is_invalid(1).unwrap());
        assert!(taken_view.is_valid(2).unwrap());
    }

    /// A two-entry list with an all-true validity array is taken with four indices; the take
    /// must succeed and produce four entries.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let positions = buffer![0u32, 1, 0, 1].into_array();

        let taken = source.take(positions.to_array()).unwrap();
        assert_eq!(taken.len(), 4);
    }
}