1use vortex_dtype::IntegerPType;
5use vortex_dtype::Nullability;
6use vortex_dtype::match_each_integer_ptype;
7use vortex_dtype::match_smallest_offset_type;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::ToCanonical;
14use crate::arrays::ListArray;
15use crate::arrays::ListVTable;
16use crate::arrays::PrimitiveArray;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::compute::TakeKernel;
20use crate::compute::TakeKernelAdapter;
21use crate::compute::take;
22use crate::register_kernel;
23use crate::vtable::ValidityHelper;
24
25impl TakeKernel for ListVTable {
34 #[expect(clippy::cognitive_complexity)]
35 fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
36 let indices = indices.to_primitive();
37 let total_approx = array.elements().len().saturating_mul(indices.len());
39
40 match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
41 match_each_integer_ptype!(indices.ptype(), |I| {
42 match_smallest_offset_type!(total_approx, |OutputOffsetType| {
43 _take::<I, O, OutputOffsetType>(array, &indices)
44 })
45 })
46 })
47 }
48}
49
// Registers `ListVTable`'s `TakeKernel` implementation (wrapped in `TakeKernelAdapter`) through
// `register_kernel!`; presumably this is how the generic `take` entry point finds the list
// kernel — confirm against the `register_kernel!` definition.
register_kernel!(TakeKernelAdapter(ListVTable).lift());
51
52fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
53 array: &ListArray,
54 indices_array: &PrimitiveArray,
55) -> VortexResult<ArrayRef> {
56 let data_validity = array.validity_mask();
57 let indices_validity = indices_array.validity_mask();
58
59 if !indices_validity.all_true() || !data_validity.all_true() {
60 return _take_nullable::<I, O, OutputOffsetType>(array, indices_array);
61 }
62
63 let offsets_array = array.offsets().to_primitive();
64 let offsets: &[O] = offsets_array.as_slice();
65 let indices: &[I] = indices_array.as_slice();
66
67 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
68 Nullability::NonNullable,
69 indices.len(),
70 );
71 let mut elements_to_take =
72 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
73
74 let mut current_offset = OutputOffsetType::zero();
75 new_offsets.append_zero();
76
77 for &data_idx in indices {
78 let data_idx: usize = data_idx.as_();
79
80 let start = offsets[data_idx];
81 let stop = offsets[data_idx + 1];
82
83 let additional: usize = (stop - start).as_();
89
90 elements_to_take.reserve_exact(additional);
92 for i in 0..additional {
93 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
94 }
95 current_offset +=
96 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
97 new_offsets.append_value(current_offset);
98 }
99
100 let elements_to_take = elements_to_take.finish();
101 let new_offsets = new_offsets.finish();
102
103 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
104
105 Ok(ListArray::try_new(
106 new_elements,
107 new_offsets,
108 array.validity().clone().take(indices_array.as_ref())?,
109 )?
110 .to_array())
111}
112
113fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
114 array: &ListArray,
115 indices_array: &PrimitiveArray,
116) -> VortexResult<ArrayRef> {
117 let offsets_array = array.offsets().to_primitive();
118 let offsets: &[O] = offsets_array.as_slice();
119 let indices: &[I] = indices_array.as_slice();
120 let data_validity = array.validity_mask();
121 let indices_validity = indices_array.validity_mask();
122
123 let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
124 Nullability::NonNullable,
125 indices.len(),
126 );
127
128 let mut elements_to_take =
136 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
137
138 let mut current_offset = OutputOffsetType::zero();
139 new_offsets.append_zero();
140
141 for (idx, data_idx) in indices.iter().enumerate() {
142 if !indices_validity.value(idx) {
143 new_offsets.append_value(current_offset);
144 continue;
145 }
146
147 let data_idx: usize = data_idx.as_();
148
149 if !data_validity.value(data_idx) {
150 new_offsets.append_value(current_offset);
151 continue;
152 }
153
154 let start = offsets[data_idx];
155 let stop = offsets[data_idx + 1];
156
157 let additional: usize = (stop - start).as_();
159
160 elements_to_take.reserve_exact(additional);
161 for i in 0..additional {
162 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
163 }
164 current_offset +=
165 OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
166 new_offsets.append_value(current_offset);
167 }
168
169 let elements_to_take = elements_to_take.finish();
170 let new_offsets = new_offsets.finish();
171 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
172
173 Ok(ListArray::try_new(
174 new_elements,
175 new_offsets,
176 array.validity().clone().take(indices_array.as_ref())?,
177 )?
178 .to_array())
179}
180
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Takes from a nullable list with nullable indices.
    ///
    /// - The null index yields a null row.
    /// - Valid indices reproduce the source lists.
    /// - The result's dtype is a nullable list of non-nullable i32.
    #[test]
    fn nullable_take() {
        // Lists: [0, 5], [3], null (backed by [4]), [].
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();

        // Expected result: [0, 5], null, [3], [].
        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_invalid(1));

        assert!(result.is_valid(2));
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(3));
        assert_eq!(
            result.scalar_at(3),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// A non-nullable list taken with nullable indices must produce a nullable list dtype.
    #[test]
    fn change_validity() {
        // Two lists: [0, 5] and [3].
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
    }

    /// Non-nullable list with non-nullable indices: the result stays non-nullable.
    ///
    /// Rows are reordered per the indices, and the empty list is preserved.
    #[test]
    fn non_nullable_take() {
        // Lists: [0, 5], [3], [], [4].
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        // Expected result: [3], [0, 5], [].
        let idx = buffer![1, 0, 2].into_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(1));
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(2));
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// Taking zero rows from a zero-length list array.
    ///
    /// With nullable (empty) indices, the result is an empty list whose dtype is nullable.
    #[test]
    fn test_take_empty_array() {
        // Offsets `[0]` describe zero lists.
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0,);
    }

    /// Runs the shared `take` conformance suite over a range of list shapes.
    #[rstest]
    // Mixed list lengths, including an empty list.
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Nullable lists: the second list is null.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    // Leading and interior empty lists.
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // 50 lists with u64 offsets; list i (1-based) has `i % 5` elements.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5));
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists whose elements contain nulls.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }

    /// Output offsets wider than the input's u8 offsets (non-nullable fast path).
    ///
    /// Taking the single 200-element list twice accumulates 400 elements, more than u8 can hold,
    /// so the result must use a wider output offset type. Despite the name, the input offsets
    /// here are u8.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        let idx = buffer![0u8, 0].into_array();
        let result = take(&list, &idx).unwrap();

        assert_eq!(result.len(), 2);

        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 2);
        assert!(result_view.is_valid(0));
        assert!(result_view.is_valid(1));
    }

    /// Same overflow check as above, on the nullable path.
    ///
    /// Two copies of the 150-element list total 300 elements (more than u8 can hold), and a
    /// null index sits between them.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        // List 0 holds 150 elements; list 1 is empty and null.
        let offsets = buffer![0u8, 150, 150].into_array();
        let validity = BoolArray::from_iter(vec![true, false]).to_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
        let result = take(&list, &idx).unwrap();

        assert_eq!(result.len(), 3);

        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 3);
        assert!(result_view.is_valid(0));
        assert!(result_view.is_invalid(1));
        assert!(result_view.is_valid(2));
    }

    /// Regression guard: with more indices than lists, the result still has one row per index.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let idx = buffer![0u32, 1, 0, 1].into_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(result.len(), 4);
    }
}