Skip to main content

vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::Array;
8use crate::ArrayRef;
9use crate::arrays::ListArray;
10use crate::arrays::ListVTable;
11use crate::arrays::PrimitiveArray;
12use crate::arrays::TakeExecute;
13use crate::builders::ArrayBuilder;
14use crate::builders::PrimitiveBuilder;
15use crate::dtype::IntegerPType;
16use crate::dtype::Nullability;
17use crate::executor::ExecutionCtx;
18use crate::match_each_integer_ptype;
19use crate::match_smallest_offset_type;
20use crate::vtable::ValidityHelper;
21
22// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
23// the `ListView::take` compute function once `ListView` is more stable.
24
25impl TakeExecute for ListVTable {
26    /// Take implementation for [`ListArray`].
27    ///
28    /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
29    /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
30    /// non-contiguous indices would violate this requirement.
31    #[expect(clippy::cognitive_complexity)]
32    fn take(
33        array: &ListArray,
34        indices: &ArrayRef,
35        ctx: &mut ExecutionCtx,
36    ) -> VortexResult<Option<ArrayRef>> {
37        let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;
38        // This is an over-approximation of the total number of elements in the resulting array.
39        let total_approx = array.elements().len().saturating_mul(indices.len());
40
41        match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
42            match_each_integer_ptype!(indices.ptype(), |I| {
43                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
44                    _take::<I, O, OutputOffsetType>(array, &indices, ctx).map(Some)
45                })
46            })
47        })
48    }
49}
50
51fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
52    array: &ListArray,
53    indices_array: &PrimitiveArray,
54    ctx: &mut ExecutionCtx,
55) -> VortexResult<ArrayRef> {
56    let data_validity = array.validity_mask()?;
57    let indices_validity = indices_array.validity_mask()?;
58
59    if !indices_validity.all_true() || !data_validity.all_true() {
60        return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
61    }
62
63    let offsets_array = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
64    let offsets: &[O] = offsets_array.as_slice();
65    let indices: &[I] = indices_array.as_slice();
66
67    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
68        Nullability::NonNullable,
69        indices.len(),
70    );
71    let mut elements_to_take =
72        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
73
74    let mut current_offset = OutputOffsetType::zero();
75    new_offsets.append_zero();
76
77    for &data_idx in indices {
78        let data_idx: usize = data_idx.as_();
79
80        let start = offsets[data_idx];
81        let stop = offsets[data_idx + 1];
82
83        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
84        //
85        // We could convert start and end to usize, but that would impose a potentially
86        // harder constraint - now we don't care if they fit into usize as long as their
87        // difference does.
88        let additional: usize = (stop - start).as_();
89
90        // TODO(0ax1): optimize this
91        elements_to_take.reserve_exact(additional);
92        for i in 0..additional {
93            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
94        }
95        current_offset +=
96            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
97        new_offsets.append_value(current_offset);
98    }
99
100    let elements_to_take = elements_to_take.finish();
101    let new_offsets = new_offsets.finish();
102
103    let new_elements = array.elements().take(elements_to_take.to_array())?;
104
105    Ok(ListArray::try_new(
106        new_elements,
107        new_offsets,
108        array.validity().clone().take(&indices_array.to_array())?,
109    )?
110    .to_array())
111}
112
113fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
114    array: &ListArray,
115    indices_array: &PrimitiveArray,
116    ctx: &mut ExecutionCtx,
117) -> VortexResult<ArrayRef> {
118    let offsets_array = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
119    let offsets: &[O] = offsets_array.as_slice();
120    let indices: &[I] = indices_array.as_slice();
121    let data_validity = array.validity_mask()?;
122    let indices_validity = indices_array.validity_mask()?;
123
124    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
125        Nullability::NonNullable,
126        indices.len(),
127    );
128
129    // This will be the indices we push down to the child array to call `take` with.
130    //
131    // There are 2 things to note here:
132    // - We do not know how many elements we need to take from our child since lists are variable
133    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
134    // - The type of the primitive builder needs to fit the largest offset of the (parent)
135    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
136    let mut elements_to_take =
137        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
138
139    let mut current_offset = OutputOffsetType::zero();
140    new_offsets.append_zero();
141
142    for (idx, data_idx) in indices.iter().enumerate() {
143        if !indices_validity.value(idx) {
144            new_offsets.append_value(current_offset);
145            continue;
146        }
147
148        let data_idx: usize = data_idx.as_();
149
150        if !data_validity.value(data_idx) {
151            new_offsets.append_value(current_offset);
152            continue;
153        }
154
155        let start = offsets[data_idx];
156        let stop = offsets[data_idx + 1];
157
158        // See the note in `_take` on the reasoning.
159        let additional: usize = (stop - start).as_();
160
161        elements_to_take.reserve_exact(additional);
162        for i in 0..additional {
163            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
164        }
165        current_offset +=
166            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
167        new_offsets.append_value(current_offset);
168    }
169
170    let elements_to_take = elements_to_take.finish();
171    let new_offsets = new_offsets.finish();
172    let new_elements = array.elements().take(elements_to_take.to_array())?;
173
174    Ok(ListArray::try_new(
175        new_elements,
176        new_offsets,
177        array.validity().clone().take(&indices_array.to_array())?,
178    )?
179    .to_array())
180}
181
182#[cfg(test)]
183mod test {
184    use std::sync::Arc;
185
186    use rstest::rstest;
187    use vortex_buffer::buffer;
188
189    use crate::Array;
190    use crate::IntoArray as _;
191    use crate::ToCanonical;
192    use crate::arrays::BoolArray;
193    use crate::arrays::PrimitiveArray;
194    use crate::arrays::list::ListArray;
195    use crate::compute::conformance::take::test_take_conformance;
196    use crate::dtype::DType;
197    use crate::dtype::Nullability;
198    use crate::dtype::PType::I32;
199    use crate::scalar::Scalar;
200    use crate::validity::Validity;
201
202    #[test]
203    fn nullable_take() {
204        let list = ListArray::try_new(
205            buffer![0i32, 5, 3, 4].into_array(),
206            buffer![0, 2, 3, 4, 4].into_array(),
207            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
208        )
209        .unwrap()
210        .to_array();
211
212        let idx =
213            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();
214
215        let result = list.take(idx.to_array()).unwrap();
216
217        assert_eq!(
218            result.dtype(),
219            &DType::List(
220                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
221                Nullability::Nullable
222            )
223        );
224
225        let result = result.to_listview();
226
227        assert_eq!(result.len(), 4);
228
229        let element_dtype: Arc<DType> = Arc::new(I32.into());
230
231        assert!(result.is_valid(0).unwrap());
232        assert_eq!(
233            result.scalar_at(0).unwrap(),
234            Scalar::list(
235                element_dtype.clone(),
236                vec![0i32.into(), 5.into()],
237                Nullability::Nullable
238            )
239        );
240
241        assert!(result.is_invalid(1).unwrap());
242
243        assert!(result.is_valid(2).unwrap());
244        assert_eq!(
245            result.scalar_at(2).unwrap(),
246            Scalar::list(
247                element_dtype.clone(),
248                vec![3i32.into()],
249                Nullability::Nullable
250            )
251        );
252
253        assert!(result.is_valid(3).unwrap());
254        assert_eq!(
255            result.scalar_at(3).unwrap(),
256            Scalar::list(element_dtype, vec![], Nullability::Nullable)
257        );
258    }
259
260    #[test]
261    fn change_validity() {
262        let list = ListArray::try_new(
263            buffer![0i32, 5, 3, 4].into_array(),
264            buffer![0, 2, 3].into_array(),
265            Validity::NonNullable,
266        )
267        .unwrap()
268        .to_array();
269
270        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
271        // since idx is nullable, the final list will also be nullable
272
273        let result = list.take(idx.to_array()).unwrap();
274        assert_eq!(
275            result.dtype(),
276            &DType::List(
277                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
278                Nullability::Nullable
279            )
280        );
281    }
282
283    #[test]
284    fn non_nullable_take() {
285        let list = ListArray::try_new(
286            buffer![0i32, 5, 3, 4].into_array(),
287            buffer![0, 2, 3, 3, 4].into_array(),
288            Validity::NonNullable,
289        )
290        .unwrap()
291        .to_array();
292
293        let idx = buffer![1, 0, 2].into_array();
294
295        let result = list.take(idx.to_array()).unwrap();
296
297        assert_eq!(
298            result.dtype(),
299            &DType::List(
300                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
301                Nullability::NonNullable
302            )
303        );
304
305        let result = result.to_listview();
306
307        assert_eq!(result.len(), 3);
308
309        let element_dtype: Arc<DType> = Arc::new(I32.into());
310
311        assert!(result.is_valid(0).unwrap());
312        assert_eq!(
313            result.scalar_at(0).unwrap(),
314            Scalar::list(
315                element_dtype.clone(),
316                vec![3i32.into()],
317                Nullability::NonNullable
318            )
319        );
320
321        assert!(result.is_valid(1).unwrap());
322        assert_eq!(
323            result.scalar_at(1).unwrap(),
324            Scalar::list(
325                element_dtype.clone(),
326                vec![0i32.into(), 5.into()],
327                Nullability::NonNullable
328            )
329        );
330
331        assert!(result.is_valid(2).unwrap());
332        assert_eq!(
333            result.scalar_at(2).unwrap(),
334            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
335        );
336    }
337
338    #[test]
339    fn test_take_empty_array() {
340        let list = ListArray::try_new(
341            buffer![0i32, 5, 3, 4].into_array(),
342            buffer![0].into_array(),
343            Validity::NonNullable,
344        )
345        .unwrap()
346        .to_array();
347
348        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();
349
350        let result = list.take(idx.to_array()).unwrap();
351        assert_eq!(
352            result.dtype(),
353            &DType::List(
354                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
355                Nullability::Nullable
356            )
357        );
358        assert_eq!(result.len(), 0,);
359    }
360
361    #[rstest]
362    #[case(ListArray::try_new(
363        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
364        buffer![0, 2, 3, 5, 5, 6].into_array(),
365        Validity::NonNullable,
366    ).unwrap())]
367    #[case(ListArray::try_new(
368        buffer![10i32, 20, 30, 40, 50].into_array(),
369        buffer![0, 2, 3, 4, 5].into_array(),
370        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
371    ).unwrap())]
372    #[case(ListArray::try_new(
373        buffer![1i32, 2, 3].into_array(),
374        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
375        Validity::NonNullable,
376    ).unwrap())]
377    #[case(ListArray::try_new(
378        buffer![42i32, 43].into_array(),
379        buffer![0, 2].into_array(),
380        Validity::NonNullable,
381    ).unwrap())]
382    #[case({
383        let elements = buffer![0i32..200].into_array();
384        let mut offsets = vec![0u64];
385        for i in 1..=50 {
386            offsets.push(offsets[i - 1] + (i as u64 % 5)); // Variable length lists
387        }
388        ListArray::try_new(
389            elements,
390            PrimitiveArray::from_iter(offsets).to_array(),
391            Validity::NonNullable,
392        ).unwrap()
393    })]
394    #[case(ListArray::try_new(
395        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
396        buffer![0, 2, 3, 5].into_array(),
397        Validity::NonNullable,
398    ).unwrap())]
399    fn test_take_list_conformance(#[case] list: ListArray) {
400        test_take_conformance(&list.to_array());
401    }
402
403    #[test]
404    fn test_u64_offset_accumulation_non_nullable() {
405        let elements = buffer![0i32; 200].into_array();
406        let offsets = buffer![0u8, 200].into_array();
407        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
408            .unwrap()
409            .to_array();
410
411        // Take the same large list twice - would overflow u8 but works with u64.
412        let idx = buffer![0u8, 0].into_array();
413        let result = list.take(idx.to_array()).unwrap();
414
415        assert_eq!(result.len(), 2);
416
417        let result_view = result.to_listview();
418        assert_eq!(result_view.len(), 2);
419        assert!(result_view.is_valid(0).unwrap());
420        assert!(result_view.is_valid(1).unwrap());
421    }
422
423    #[test]
424    fn test_u64_offset_accumulation_nullable() {
425        let elements = buffer![0i32; 150].into_array();
426        let offsets = buffer![0u8, 150, 150].into_array();
427        let validity = BoolArray::from_iter(vec![true, false]).to_array();
428        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
429            .unwrap()
430            .to_array();
431
432        // Take the same large list twice - would overflow u8 but works with u64.
433        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
434        let result = list.take(idx.to_array()).unwrap();
435
436        assert_eq!(result.len(), 3);
437
438        let result_view = result.to_listview();
439        assert_eq!(result_view.len(), 3);
440        assert!(result_view.is_valid(0).unwrap());
441        assert!(result_view.is_invalid(1).unwrap());
442        assert!(result_view.is_valid(2).unwrap());
443    }
444
445    /// Regression test for validity length mismatch bug.
446    ///
447    /// When source array has `Validity::Array(...)` and indices are non-nullable,
448    /// the result validity must have length equal to indices.len(), not source.len().
449    #[test]
450    fn test_take_validity_length_mismatch_regression() {
451        // Source array with explicit validity array (length 2).
452        let list = ListArray::try_new(
453            buffer![1i32, 2, 3, 4].into_array(),
454            buffer![0, 2, 4].into_array(),
455            Validity::Array(BoolArray::from_iter(vec![true, true]).to_array()),
456        )
457        .unwrap()
458        .to_array();
459
460        // Take more indices than source length (4 vs 2) with non-nullable indices.
461        let idx = buffer![0u32, 1, 0, 1].into_array();
462
463        // This should not panic - result should have length 4.
464        let result = list.take(idx.to_array()).unwrap();
465        assert_eq!(result.len(), 4);
466    }
467}