vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::IntegerPType;
5use vortex_dtype::Nullability;
6use vortex_dtype::match_each_integer_ptype;
7use vortex_dtype::match_smallest_offset_type;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::ToCanonical;
14use crate::arrays::ListArray;
15use crate::arrays::ListVTable;
16use crate::arrays::PrimitiveArray;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::compute::TakeKernel;
20use crate::compute::TakeKernelAdapter;
21use crate::compute::take;
22use crate::register_kernel;
23use crate::vtable::ValidityHelper;
24
25// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
26// the `ListView::take` compute function once `ListView` is more stable.
27
28/// Take implementation for [`ListArray`].
29///
30/// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
31/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
32/// non-contiguous indices would violate this requirement.
33impl TakeKernel for ListVTable {
34    #[expect(clippy::cognitive_complexity)]
35    fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
36        let indices = indices.to_primitive();
37        // This is an over-approximation of the total number of elements in the resulting array.
38        let total_approx = array.elements().len().saturating_mul(indices.len());
39
40        match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
41            match_each_integer_ptype!(indices.ptype(), |I| {
42                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
43                    _take::<I, O, OutputOffsetType>(array, &indices)
44                })
45            })
46        })
47    }
48}
49
50register_kernel!(TakeKernelAdapter(ListVTable).lift());
51
52fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
53    array: &ListArray,
54    indices_array: &PrimitiveArray,
55) -> VortexResult<ArrayRef> {
56    let data_validity = array.validity_mask();
57    let indices_validity = indices_array.validity_mask();
58
59    if !indices_validity.all_true() || !data_validity.all_true() {
60        return _take_nullable::<I, O, OutputOffsetType>(array, indices_array);
61    }
62
63    let offsets_array = array.offsets().to_primitive();
64    let offsets: &[O] = offsets_array.as_slice();
65    let indices: &[I] = indices_array.as_slice();
66
67    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
68        Nullability::NonNullable,
69        indices.len(),
70    );
71    let mut elements_to_take =
72        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
73
74    let mut current_offset = OutputOffsetType::zero();
75    new_offsets.append_zero();
76
77    for &data_idx in indices {
78        let data_idx: usize = data_idx.as_();
79
80        let start = offsets[data_idx];
81        let stop = offsets[data_idx + 1];
82
83        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
84        //
85        // We could convert start and end to usize, but that would impose a potentially
86        // harder constraint - now we don't care if they fit into usize as long as their
87        // difference does.
88        let additional: usize = (stop - start).as_();
89
90        // TODO(0ax1): optimize this
91        elements_to_take.reserve_exact(additional);
92        for i in 0..additional {
93            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
94        }
95        current_offset +=
96            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
97        new_offsets.append_value(current_offset);
98    }
99
100    let elements_to_take = elements_to_take.finish();
101    let new_offsets = new_offsets.finish();
102
103    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
104
105    Ok(ListArray::try_new(
106        new_elements,
107        new_offsets,
108        array.validity().clone().take(indices_array.as_ref())?,
109    )?
110    .to_array())
111}
112
113fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
114    array: &ListArray,
115    indices_array: &PrimitiveArray,
116) -> VortexResult<ArrayRef> {
117    let offsets_array = array.offsets().to_primitive();
118    let offsets: &[O] = offsets_array.as_slice();
119    let indices: &[I] = indices_array.as_slice();
120    let data_validity = array.validity_mask();
121    let indices_validity = indices_array.validity_mask();
122
123    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
124        Nullability::NonNullable,
125        indices.len(),
126    );
127
128    // This will be the indices we push down to the child array to call `take` with.
129    //
130    // There are 2 things to note here:
131    // - We do not know how many elements we need to take from our child since lists are variable
132    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
133    // - The type of the primitive builder needs to fit the largest offset of the (parent)
134    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
135    let mut elements_to_take =
136        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
137
138    let mut current_offset = OutputOffsetType::zero();
139    new_offsets.append_zero();
140
141    for (idx, data_idx) in indices.iter().enumerate() {
142        if !indices_validity.value(idx) {
143            new_offsets.append_value(current_offset);
144            continue;
145        }
146
147        let data_idx: usize = data_idx.as_();
148
149        if !data_validity.value(data_idx) {
150            new_offsets.append_value(current_offset);
151            continue;
152        }
153
154        let start = offsets[data_idx];
155        let stop = offsets[data_idx + 1];
156
157        // See the note in `_take` on the reasoning.
158        let additional: usize = (stop - start).as_();
159
160        elements_to_take.reserve_exact(additional);
161        for i in 0..additional {
162            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
163        }
164        current_offset +=
165            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
166        new_offsets.append_value(current_offset);
167    }
168
169    let elements_to_take = elements_to_take.finish();
170    let new_offsets = new_offsets.finish();
171    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
172
173    Ok(ListArray::try_new(
174        new_elements,
175        new_offsets,
176        array.validity().clone().take(indices_array.as_ref())?,
177    )?
178    .to_array())
179}
180
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Dtype of a list of non-nullable `i32` elements with the given list-level nullability.
    fn i32_list_dtype(list_nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            list_nullability,
        )
    }

    /// Builds the list scalar holding `values` that a taken slot is compared against.
    fn i32_list_scalar(values: &[i32], nullability: Nullability) -> Scalar {
        let element_dtype: Arc<DType> = Arc::new(I32.into());
        let children: Vec<Scalar> = values.iter().copied().map(Scalar::from).collect();
        Scalar::list(element_dtype, children, nullability)
    }

    /// Nullable indices over a nullable list array: null indices yield null slots and valid
    /// indices yield the referenced lists.
    #[test]
    fn nullable_take() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();
        let indices =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = take(&source, &indices).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[0, 5], Nullability::Nullable)
        );

        assert!(taken.is_invalid(1));

        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[3], Nullability::Nullable)
        );

        assert!(taken.is_valid(3));
        assert_eq!(
            taken.scalar_at(3),
            i32_list_scalar(&[], Nullability::Nullable)
        );
    }

    /// Taking with nullable indices from a non-nullable list array produces a nullable list.
    #[test]
    fn change_validity() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        // The indices are nullable, hence so is the result.
        let indices = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();

        let taken = take(&source, &indices).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Non-nullable indices over a non-nullable list array, including an empty list.
    #[test]
    fn non_nullable_take() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();
        let indices = buffer![1, 0, 2].into_array();

        let taken = take(&source, &indices).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[3], Nullability::NonNullable)
        );

        assert!(taken.is_valid(1));
        assert_eq!(
            taken.scalar_at(1),
            i32_list_scalar(&[0, 5], Nullability::NonNullable)
        );

        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[], Nullability::NonNullable)
        );
    }

    /// Taking zero (nullable) indices from a list array without lists gives an empty,
    /// nullable result.
    #[test]
    fn test_take_empty_array() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();
        let indices = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = take(&source, &indices).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    /// Runs the shared take conformance suite against a variety of list arrays.
    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // The first and the third list are empty.
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        // 50 lists whose lengths cycle through 1, 2, 3, 4, 0.
        let mut offsets = vec![0u64];
        for i in 1..=50u64 {
            let previous = *offsets.last().unwrap();
            offsets.push(previous + i % 5);
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }

    /// The source offsets are `u8`, but taking the 200-element list twice accumulates to an
    /// output offset of 400, which no longer fits into `u8`.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let source = ListArray::try_new(
            buffer![0i32; 200].into_array(),
            buffer![0u8, 200].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        // The same large list, twice.
        let indices = buffer![0u8, 0].into_array();
        let taken = take(&source, &indices).unwrap();

        assert_eq!(taken.len(), 2);

        let taken_view = taken.to_listview();
        assert_eq!(taken_view.len(), 2);
        assert!(taken_view.is_valid(0));
        assert!(taken_view.is_valid(1));
    }

    /// Nullable variant: the source offsets are `u8`, but taking the 150-element list twice
    /// accumulates to an output offset of 300, which no longer fits into `u8`.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let source = ListArray::try_new(
            buffer![0i32; 150].into_array(),
            buffer![0u8, 150, 150].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, false]).to_array()),
        )
        .unwrap()
        .to_array();

        // The same large list, twice, separated by a null index.
        let indices =
            PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
        let taken = take(&source, &indices).unwrap();

        assert_eq!(taken.len(), 3);

        let taken_view = taken.to_listview();
        assert_eq!(taken_view.len(), 3);
        assert!(taken_view.is_valid(0));
        assert!(taken_view.is_invalid(1));
        assert!(taken_view.is_valid(2));
    }

    /// Regression test for a validity length mismatch.
    ///
    /// With a source carrying `Validity::Array(...)` and non-nullable indices, the validity of
    /// the result must be as long as the indices, not as long as the source.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        // Two lists with an explicit validity array of length 2.
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).to_array()),
        )
        .unwrap()
        .to_array();

        // More non-nullable indices (4) than there are source lists (2).
        let indices = buffer![0u32, 1, 0, 1].into_array();

        // Must not panic, and the result must have one slot per index.
        let taken = take(&source, &indices).unwrap();
        assert_eq!(taken.len(), 4);
    }
}