Skip to main content

vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::DynArray;
9use crate::IntoArray;
10use crate::arrays::ListArray;
11use crate::arrays::ListVTable;
12use crate::arrays::PrimitiveArray;
13use crate::arrays::dict::TakeExecute;
14use crate::builders::ArrayBuilder;
15use crate::builders::PrimitiveBuilder;
16use crate::dtype::IntegerPType;
17use crate::dtype::Nullability;
18use crate::executor::ExecutionCtx;
19use crate::match_each_integer_ptype;
20use crate::match_smallest_offset_type;
21use crate::vtable::ValidityHelper;
22
23// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
24// the `ListView::take` compute function once `ListView` is more stable.
25
26impl TakeExecute for ListVTable {
27    /// Take implementation for [`ListArray`].
28    ///
29    /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
30    /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
31    /// non-contiguous indices would violate this requirement.
32    #[expect(clippy::cognitive_complexity)]
33    fn take(
34        array: &ListArray,
35        indices: &ArrayRef,
36        ctx: &mut ExecutionCtx,
37    ) -> VortexResult<Option<ArrayRef>> {
38        let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;
39        // This is an over-approximation of the total number of elements in the resulting array.
40        let total_approx = array.elements().len().saturating_mul(indices.len());
41
42        match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
43            match_each_integer_ptype!(indices.ptype(), |I| {
44                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
45                    _take::<I, O, OutputOffsetType>(array, &indices, ctx).map(Some)
46                })
47            })
48        })
49    }
50}
51
52fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
53    array: &ListArray,
54    indices_array: &PrimitiveArray,
55    ctx: &mut ExecutionCtx,
56) -> VortexResult<ArrayRef> {
57    let data_validity = array.validity_mask()?;
58    let indices_validity = indices_array.validity_mask()?;
59
60    if !indices_validity.all_true() || !data_validity.all_true() {
61        return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
62    }
63
64    let offsets_array = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
65    let offsets: &[O] = offsets_array.as_slice();
66    let indices: &[I] = indices_array.as_slice();
67
68    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
69        Nullability::NonNullable,
70        indices.len(),
71    );
72    let mut elements_to_take =
73        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
74
75    let mut current_offset = OutputOffsetType::zero();
76    new_offsets.append_zero();
77
78    for &data_idx in indices {
79        let data_idx: usize = data_idx.as_();
80
81        let start = offsets[data_idx];
82        let stop = offsets[data_idx + 1];
83
84        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
85        //
86        // We could convert start and end to usize, but that would impose a potentially
87        // harder constraint - now we don't care if they fit into usize as long as their
88        // difference does.
89        let additional: usize = (stop - start).as_();
90
91        // TODO(0ax1): optimize this
92        elements_to_take.reserve_exact(additional);
93        for i in 0..additional {
94            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
95        }
96        current_offset +=
97            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
98        new_offsets.append_value(current_offset);
99    }
100
101    let elements_to_take = elements_to_take.finish();
102    let new_offsets = new_offsets.finish();
103
104    let new_elements = array.elements().take(elements_to_take.to_array())?;
105
106    Ok(ListArray::try_new(
107        new_elements,
108        new_offsets,
109        array
110            .validity()
111            .clone()
112            .take(&indices_array.clone().into_array())?,
113    )?
114    .into_array())
115}
116
117fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
118    array: &ListArray,
119    indices_array: &PrimitiveArray,
120    ctx: &mut ExecutionCtx,
121) -> VortexResult<ArrayRef> {
122    let offsets_array = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
123    let offsets: &[O] = offsets_array.as_slice();
124    let indices: &[I] = indices_array.as_slice();
125    let data_validity = array.validity_mask()?;
126    let indices_validity = indices_array.validity_mask()?;
127
128    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
129        Nullability::NonNullable,
130        indices.len(),
131    );
132
133    // This will be the indices we push down to the child array to call `take` with.
134    //
135    // There are 2 things to note here:
136    // - We do not know how many elements we need to take from our child since lists are variable
137    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
138    // - The type of the primitive builder needs to fit the largest offset of the (parent)
139    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
140    let mut elements_to_take =
141        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
142
143    let mut current_offset = OutputOffsetType::zero();
144    new_offsets.append_zero();
145
146    for (idx, data_idx) in indices.iter().enumerate() {
147        if !indices_validity.value(idx) {
148            new_offsets.append_value(current_offset);
149            continue;
150        }
151
152        let data_idx: usize = data_idx.as_();
153
154        if !data_validity.value(data_idx) {
155            new_offsets.append_value(current_offset);
156            continue;
157        }
158
159        let start = offsets[data_idx];
160        let stop = offsets[data_idx + 1];
161
162        // See the note in `_take` on the reasoning.
163        let additional: usize = (stop - start).as_();
164
165        elements_to_take.reserve_exact(additional);
166        for i in 0..additional {
167            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
168        }
169        current_offset +=
170            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
171        new_offsets.append_value(current_offset);
172    }
173
174    let elements_to_take = elements_to_take.finish();
175    let new_offsets = new_offsets.finish();
176    let new_elements = array.elements().take(elements_to_take.to_array())?;
177
178    Ok(ListArray::try_new(
179        new_elements,
180        new_offsets,
181        array
182            .validity()
183            .clone()
184            .take(&indices_array.clone().into_array())?,
185    )?
186    .into_array())
187}
188
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::DynArray;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Dtype of a list of non-nullable `i32` values with the given outer nullability.
    fn i32_list_dtype(outer: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            outer,
        )
    }

    /// Nullable indices over a nullable list array: null indices produce null entries.
    #[test]
    fn nullable_take() {
        // Four lists: [0, 5], [3], a null list, and an empty list.
        let source_validity = BoolArray::from_iter(vec![true, true, false, true]).into_array();
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(source_validity),
        )
        .unwrap()
        .into_array();

        let picks =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let taken = source.take(picks.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let view = taken.to_listview();
        assert_eq!(view.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Index 0 -> [0, 5].
        assert!(view.is_valid(0).unwrap());
        let expected_first = Scalar::list(
            element_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::Nullable,
        );
        assert_eq!(view.scalar_at(0).unwrap(), expected_first);

        // Null index -> null entry.
        assert!(view.is_invalid(1).unwrap());

        // Index 1 -> [3].
        assert!(view.is_valid(2).unwrap());
        let expected_third = Scalar::list(
            element_dtype.clone(),
            vec![3i32.into()],
            Nullability::Nullable,
        );
        assert_eq!(view.scalar_at(2).unwrap(), expected_third);

        // Index 3 -> the empty list.
        assert!(view.is_valid(3).unwrap());
        let expected_fourth = Scalar::list(element_dtype, vec![], Nullability::Nullable);
        assert_eq!(view.scalar_at(3).unwrap(), expected_fourth);
    }

    /// Nullable indices turn a non-nullable list array into a nullable result.
    #[test]
    fn change_validity() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        // The indices are nullable, so the taken list array must be nullable as well.
        let picks = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();

        let taken = source.take(picks.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Non-nullable indices over a non-nullable list array, taken out of order.
    #[test]
    fn non_nullable_take() {
        // Four lists: [0, 5], [3], [], [4].
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let picks = buffer![1, 0, 2].into_array();

        let taken = source.take(picks.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let view = taken.to_listview();
        assert_eq!(view.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Index 1 -> [3].
        assert!(view.is_valid(0).unwrap());
        let expected_first = Scalar::list(
            element_dtype.clone(),
            vec![3i32.into()],
            Nullability::NonNullable,
        );
        assert_eq!(view.scalar_at(0).unwrap(), expected_first);

        // Index 0 -> [0, 5].
        assert!(view.is_valid(1).unwrap());
        let expected_second = Scalar::list(
            element_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::NonNullable,
        );
        assert_eq!(view.scalar_at(1).unwrap(), expected_second);

        // Index 2 -> the empty list.
        assert!(view.is_valid(2).unwrap());
        let expected_third = Scalar::list(element_dtype, vec![], Nullability::NonNullable);
        assert_eq!(view.scalar_at(2).unwrap(), expected_third);
    }

    /// Taking with an empty (nullable) index array yields an empty, nullable result.
    #[test]
    fn test_take_empty_array() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let picks = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let taken = source.take(picks.to_array()).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    /// Runs the shared take conformance suite over a variety of list arrays.
    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // Lists one and three are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        // 50 lists whose lengths cycle through 1, 2, 3, 4, 0.
        let mut offsets = Vec::with_capacity(51);
        let mut running = 0u64;
        offsets.push(running);
        for step in 1..=50u64 {
            running += step % 5;
            offsets.push(running);
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    /// Accumulated output offsets may exceed the range of the input offset type.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let source = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Taking the single 200-element list twice accumulates to an offset of 400, which no
        // longer fits in the `u8` offsets of the source.
        let picks = buffer![0u8, 0].into_array();
        let taken = source.take(picks.to_array()).unwrap();

        assert_eq!(taken.len(), 2);

        let view = taken.to_listview();
        assert_eq!(view.len(), 2);
        for position in [0, 1] {
            assert!(view.is_valid(position).unwrap());
        }
    }

    /// Same as the non-nullable accumulation test, but through the nullable code path.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        let source_validity = BoolArray::from_iter(vec![true, false]).into_array();
        let source = ListArray::try_new(elements, offsets, Validity::Array(source_validity))
            .unwrap()
            .into_array();

        // Taking the 150-element list twice accumulates to an offset of 300, which no longer
        // fits in the `u8` offsets of the source.
        let picks = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
        let taken = source.take(picks.to_array()).unwrap();

        assert_eq!(taken.len(), 3);

        let view = taken.to_listview();
        assert_eq!(view.len(), 3);
        assert!(view.is_valid(0).unwrap());
        assert!(view.is_invalid(1).unwrap());
        assert!(view.is_valid(2).unwrap());
    }

    /// Regression test for a validity length mismatch.
    ///
    /// With a source carrying `Validity::Array(...)` and non-nullable indices, the validity of
    /// the result must have `indices.len()` entries rather than `source.len()`.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        // Two lists backed by an explicit validity array of length 2.
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();

        // More (non-nullable) indices than there are source lists: 4 vs 2.
        let picks = buffer![0u32, 1, 0, 1].into_array();

        // Must not panic, and the result has one entry per index.
        let taken = source.take(picks.to_array()).unwrap();
        assert_eq!(taken.len(), 4);
    }
}
470        // This should not panic - result should have length 4.
471        let result = list.take(idx.to_array()).unwrap();
472        assert_eq!(result.len(), 4);
473    }
474}