Skip to main content

vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::IntegerPType;
5use vortex_dtype::Nullability;
6use vortex_dtype::match_each_integer_ptype;
7use vortex_dtype::match_smallest_offset_type;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::ToCanonical;
14use crate::arrays::ListArray;
15use crate::arrays::ListVTable;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::TakeExecute;
18use crate::builders::ArrayBuilder;
19use crate::builders::PrimitiveBuilder;
20use crate::executor::ExecutionCtx;
21use crate::vtable::ValidityHelper;
22
23// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
24// the `ListView::take` compute function once `ListView` is more stable.
25
26impl TakeExecute for ListVTable {
27    /// Take implementation for [`ListArray`].
28    ///
29    /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
30    /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
31    /// non-contiguous indices would violate this requirement.
32    #[expect(clippy::cognitive_complexity)]
33    fn take(
34        array: &ListArray,
35        indices: &dyn Array,
36        _ctx: &mut ExecutionCtx,
37    ) -> VortexResult<Option<ArrayRef>> {
38        let indices = indices.to_primitive();
39        // This is an over-approximation of the total number of elements in the resulting array.
40        let total_approx = array.elements().len().saturating_mul(indices.len());
41
42        match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
43            match_each_integer_ptype!(indices.ptype(), |I| {
44                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
45                    _take::<I, O, OutputOffsetType>(array, &indices).map(Some)
46                })
47            })
48        })
49    }
50}
51
52fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
53    array: &ListArray,
54    indices_array: &PrimitiveArray,
55) -> VortexResult<ArrayRef> {
56    let data_validity = array.validity_mask()?;
57    let indices_validity = indices_array.validity_mask()?;
58
59    if !indices_validity.all_true() || !data_validity.all_true() {
60        return _take_nullable::<I, O, OutputOffsetType>(array, indices_array);
61    }
62
63    let offsets_array = array.offsets().to_primitive();
64    let offsets: &[O] = offsets_array.as_slice();
65    let indices: &[I] = indices_array.as_slice();
66
67    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
68        Nullability::NonNullable,
69        indices.len(),
70    );
71    let mut elements_to_take =
72        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
73
74    let mut current_offset = OutputOffsetType::zero();
75    new_offsets.append_zero();
76
77    for &data_idx in indices {
78        let data_idx: usize = data_idx.as_();
79
80        let start = offsets[data_idx];
81        let stop = offsets[data_idx + 1];
82
83        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
84        //
85        // We could convert start and end to usize, but that would impose a potentially
86        // harder constraint - now we don't care if they fit into usize as long as their
87        // difference does.
88        let additional: usize = (stop - start).as_();
89
90        // TODO(0ax1): optimize this
91        elements_to_take.reserve_exact(additional);
92        for i in 0..additional {
93            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
94        }
95        current_offset +=
96            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
97        new_offsets.append_value(current_offset);
98    }
99
100    let elements_to_take = elements_to_take.finish();
101    let new_offsets = new_offsets.finish();
102
103    let new_elements = array.elements().take(elements_to_take.to_array())?;
104
105    Ok(ListArray::try_new(
106        new_elements,
107        new_offsets,
108        array.validity().clone().take(indices_array.as_ref())?,
109    )?
110    .to_array())
111}
112
113fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
114    array: &ListArray,
115    indices_array: &PrimitiveArray,
116) -> VortexResult<ArrayRef> {
117    let offsets_array = array.offsets().to_primitive();
118    let offsets: &[O] = offsets_array.as_slice();
119    let indices: &[I] = indices_array.as_slice();
120    let data_validity = array.validity_mask()?;
121    let indices_validity = indices_array.validity_mask()?;
122
123    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
124        Nullability::NonNullable,
125        indices.len(),
126    );
127
128    // This will be the indices we push down to the child array to call `take` with.
129    //
130    // There are 2 things to note here:
131    // - We do not know how many elements we need to take from our child since lists are variable
132    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
133    // - The type of the primitive builder needs to fit the largest offset of the (parent)
134    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
135    let mut elements_to_take =
136        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
137
138    let mut current_offset = OutputOffsetType::zero();
139    new_offsets.append_zero();
140
141    for (idx, data_idx) in indices.iter().enumerate() {
142        if !indices_validity.value(idx) {
143            new_offsets.append_value(current_offset);
144            continue;
145        }
146
147        let data_idx: usize = data_idx.as_();
148
149        if !data_validity.value(data_idx) {
150            new_offsets.append_value(current_offset);
151            continue;
152        }
153
154        let start = offsets[data_idx];
155        let stop = offsets[data_idx + 1];
156
157        // See the note in `_take` on the reasoning.
158        let additional: usize = (stop - start).as_();
159
160        elements_to_take.reserve_exact(additional);
161        for i in 0..additional {
162            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
163        }
164        current_offset +=
165            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
166        new_offsets.append_value(current_offset);
167    }
168
169    let elements_to_take = elements_to_take.finish();
170    let new_offsets = new_offsets.finish();
171    let new_elements = array.elements().take(elements_to_take.to_array())?;
172
173    Ok(ListArray::try_new(
174        new_elements,
175        new_offsets,
176        array.validity().clone().take(indices_array.as_ref())?,
177    )?
178    .to_array())
179}
180
// Unit tests for `ListArray` take: nullable/non-nullable sources and indices, empty inputs,
// output offset widening, and the shared take conformance suite.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Source lists are `[0, 5]`, `[3]`, null, `[]`; indices are `[0, null, 1, 3]`.
    /// Expects `[0, 5]`, null, `[3]`, `[]` with a nullable list dtype.
    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let result = list.take(idx.to_array()).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_invalid(1).unwrap());

        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(3).unwrap());
        assert_eq!(
            result.scalar_at(3).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// A non-nullable source taken with nullable indices must produce a nullable list dtype.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        // since idx is nullable, the final list will also be nullable

        let result = list.take(idx.to_array()).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
    }

    /// Source lists are `[0, 5]`, `[3]`, `[]`, `[4]`; non-nullable indices `[1, 0, 2]`
    /// must yield `[3]`, `[0, 5]`, `[]` with a non-nullable list dtype.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = list.take(idx.to_array()).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(1).unwrap());
        assert_eq!(
            result.scalar_at(1).unwrap(),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// A list array with zero rows (offsets `[0]`) taken with empty nullable indices yields an
    /// empty, nullable list array.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let result = list.take(idx.to_array()).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0,);
    }

    /// Runs the shared take conformance suite over several list shapes: plain lists, lists with
    /// a null row, lists containing empty rows, a single list, many variable-length lists, and
    /// lists whose elements are nullable.
    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5)); // Variable length lists
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }

    /// Output offsets must be wider than the source's `u8` offsets when the taken lists hold
    /// more elements than `u8` can index (non-nullable path).
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        // Taking the 200-element list twice gathers 400 elements, which would overflow a `u8`
        // offset; the rebuilt offsets must use a wider integer type.
        let idx = buffer![0u8, 0].into_array();
        let result = list.take(idx.to_array()).unwrap();

        assert_eq!(result.len(), 2);

        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 2);
        assert!(result_view.is_valid(0).unwrap());
        assert!(result_view.is_valid(1).unwrap());
    }

    /// Output offsets must be wider than the source's `u8` offsets when the taken lists hold
    /// more elements than `u8` can index (nullable path).
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        let validity = BoolArray::from_iter(vec![true, false]).to_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .to_array();

        // Taking the 150-element list twice gathers 300 elements, which would overflow a `u8`
        // offset; the rebuilt offsets must use a wider integer type.
        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).to_array();
        let result = list.take(idx.to_array()).unwrap();

        assert_eq!(result.len(), 3);

        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 3);
        assert!(result_view.is_valid(0).unwrap());
        assert!(result_view.is_invalid(1).unwrap());
        assert!(result_view.is_valid(2).unwrap());
    }

    /// Regression test for validity length mismatch bug.
    ///
    /// When source array has `Validity::Array(...)` and indices are non-nullable,
    /// the result validity must have length equal to indices.len(), not source.len().
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        // Source array with explicit validity array (length 2).
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).to_array()),
        )
        .unwrap()
        .to_array();

        // Take more indices than source length (4 vs 2) with non-nullable indices.
        let idx = buffer![0u32, 1, 0, 1].into_array();

        // This should not panic - result should have length 4.
        let result = list.take(idx.to_array()).unwrap();
        assert_eq!(result.len(), 4);
    }
}