Skip to main content

vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::Primitive;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::list::ListArrayExt;
16use crate::arrays::primitive::PrimitiveArrayExt;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::dtype::IntegerPType;
20use crate::dtype::Nullability;
21use crate::executor::ExecutionCtx;
22use crate::match_each_integer_ptype;
23use crate::match_smallest_offset_type;
24
// TODO(connor)[ListView]: Once `ListView` is more stable, go back (re-revert) to the simpler
// version that converts to a `ListView` and calls the `ListView::take` compute function.
27
28impl TakeExecute for List {
29    /// Take implementation for [`ListArray`].
30    ///
31    /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
32    /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
33    /// non-contiguous indices would violate this requirement.
34    #[expect(clippy::cognitive_complexity)]
35    fn take(
36        array: ArrayView<'_, List>,
37        indices: &ArrayRef,
38        ctx: &mut ExecutionCtx,
39    ) -> VortexResult<Option<ArrayRef>> {
40        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
41        // This is an over-approximation of the total number of elements in the resulting array.
42        let total_approx = array.elements().len().saturating_mul(indices.len());
43
44        match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
45            match_each_integer_ptype!(indices.ptype(), |I| {
46                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
47                    {
48                        let indices = indices.as_view();
49                        _take::<I, O, OutputOffsetType>(array, indices, ctx).map(Some)
50                    }
51                })
52            })
53        })
54    }
55}
56
57fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
58    array: ArrayView<'_, List>,
59    indices_array: ArrayView<'_, Primitive>,
60    ctx: &mut ExecutionCtx,
61) -> VortexResult<ArrayRef> {
62    let data_validity = array.list_validity_mask();
63    let indices_validity = indices_array.validity_mask();
64
65    if !indices_validity.all_true() || !data_validity.all_true() {
66        return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
67    }
68
69    let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
70    let offsets: &[O] = offsets_array.as_slice();
71    let indices: &[I] = indices_array.as_slice();
72
73    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
74        Nullability::NonNullable,
75        indices.len(),
76    );
77    let mut elements_to_take =
78        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
79
80    let mut current_offset = OutputOffsetType::zero();
81    new_offsets.append_zero();
82
83    for &data_idx in indices {
84        let data_idx: usize = data_idx.as_();
85
86        let start = offsets[data_idx];
87        let stop = offsets[data_idx + 1];
88
89        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
90        //
91        // We could convert start and end to usize, but that would impose a potentially
92        // harder constraint - now we don't care if they fit into usize as long as their
93        // difference does.
94        let additional: usize = (stop - start).as_();
95
96        // TODO(0ax1): optimize this
97        elements_to_take.reserve_exact(additional);
98        for i in 0..additional {
99            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
100        }
101        current_offset +=
102            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
103        new_offsets.append_value(current_offset);
104    }
105
106    let elements_to_take = elements_to_take.finish();
107    let new_offsets = new_offsets.finish();
108
109    let new_elements = array.elements().take(elements_to_take)?;
110
111    Ok(ListArray::try_new(
112        new_elements,
113        new_offsets,
114        array.validity()?.take(indices_array.array())?,
115    )?
116    .into_array())
117}
118
119fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
120    array: ArrayView<'_, List>,
121    indices_array: ArrayView<'_, Primitive>,
122    ctx: &mut ExecutionCtx,
123) -> VortexResult<ArrayRef> {
124    let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
125    let offsets: &[O] = offsets_array.as_slice();
126    let indices: &[I] = indices_array.as_slice();
127    let data_validity = array.list_validity_mask();
128    let indices_validity = indices_array.validity_mask();
129
130    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
131        Nullability::NonNullable,
132        indices.len(),
133    );
134
135    // This will be the indices we push down to the child array to call `take` with.
136    //
137    // There are 2 things to note here:
138    // - We do not know how many elements we need to take from our child since lists are variable
139    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
140    // - The type of the primitive builder needs to fit the largest offset of the (parent)
141    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
142    let mut elements_to_take =
143        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
144
145    let mut current_offset = OutputOffsetType::zero();
146    new_offsets.append_zero();
147
148    for (idx, data_idx) in indices.iter().enumerate() {
149        if !indices_validity.value(idx) {
150            new_offsets.append_value(current_offset);
151            continue;
152        }
153
154        let data_idx: usize = data_idx.as_();
155
156        if !data_validity.value(data_idx) {
157            new_offsets.append_value(current_offset);
158            continue;
159        }
160
161        let start = offsets[data_idx];
162        let stop = offsets[data_idx + 1];
163
164        // See the note in `_take` on the reasoning.
165        let additional: usize = (stop - start).as_();
166
167        elements_to_take.reserve_exact(additional);
168        for i in 0..additional {
169            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
170        }
171        current_offset +=
172            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
173        new_offsets.append_value(current_offset);
174    }
175
176    let elements_to_take = elements_to_take.finish();
177    let new_offsets = new_offsets.finish();
178    let new_elements = array.elements().take(elements_to_take)?;
179
180    Ok(ListArray::try_new(
181        new_elements,
182        new_offsets,
183        array.validity()?.take(indices_array.array())?,
184    )?
185    .into_array())
186}
187
// Unit tests for `take` on `ListArray`. Covered:
// - null indices and null lists;
// - nullability of the result dtype;
// - empty inputs;
// - conformance checks over several list shapes;
// - regressions for output-offset widening and result validity length.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Nullable list array taken with nullable indices.
    ///
    /// Source lists: `[0, 5]`, `[3]`, null, `[]`. Taking `[0, null, 1, 3]` must yield
    /// `[0, 5]`, null, `[3]`, `[]` with a nullable list dtype.
    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        // Inspect individual entries through the canonical `ListView` form.
        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        // Null index -> null output entry.
        assert!(result.is_invalid(1).unwrap());

        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        // Valid but empty list.
        assert!(result.is_valid(3).unwrap());
        assert_eq!(
            result.scalar_at(3).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// Taking from a non-nullable list array with nullable indices yields a nullable list dtype.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
        // since idx is nullable, the final list will also be nullable

        let result = list.take(idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
    }

    /// Non-nullable lists and indices (the null-free fast path).
    ///
    /// Source lists: `[0, 5]`, `[3]`, `[]`, `[4]`. Taking `[1, 0, 2]` must yield `[3]`,
    /// `[0, 5]`, `[]` with a non-nullable list dtype.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(1).unwrap());
        assert_eq!(
            result.scalar_at(1).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// Taking an empty, nullable index array from a zero-length list array (offsets `[0]`)
    /// yields an empty result whose list dtype is nullable.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let result = list.take(idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0,);
    }

    /// Runs the shared `take` conformance suite over a range of list shapes.
    #[rstest]
    // Non-nullable lists, including an empty one (offsets 5..5).
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Nullable lists with one null entry.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // 50 variable-length lists (lengths cycle through `i % 5`) with `u64` offsets.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5)); // Variable length lists
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists whose elements contain nulls.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    /// Output offsets must be wider than the input offset type when the taken lists hold more
    /// elements in total than the input type can represent (null-free path).
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Take the same 200-element list twice. The 400 total elements overflow `u8`, so the
        // output offsets need a wider type (chosen via `match_smallest_offset_type!`).
        let idx = buffer![0u8, 0].into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 2);

        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 2);
        assert!(result_view.is_valid(0).unwrap());
        assert!(result_view.is_valid(1).unwrap());
    }

    /// Same as above but through the nullable path (null lists and a null index).
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        let validity = BoolArray::from_iter(vec![true, false]).into_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .into_array();

        // Take the same 150-element list twice. The 300 total elements overflow `u8`, so the
        // output offsets need a wider type (chosen via `match_smallest_offset_type!`).
        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 3);

        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 3);
        assert!(result_view.is_valid(0).unwrap());
        assert!(result_view.is_invalid(1).unwrap());
        assert!(result_view.is_valid(2).unwrap());
    }

    /// Regression test for validity length mismatch bug.
    ///
    /// When source array has `Validity::Array(...)` and indices are non-nullable,
    /// the result validity must have length equal to indices.len(), not source.len().
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        // Source array with explicit validity array (length 2).
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();

        // Take more indices than source length (4 vs 2) with non-nullable indices.
        let idx = buffer![0u32, 1, 0, 1].into_array();

        // This should not panic - result should have length 4.
        let result = list.take(idx).unwrap();
        assert_eq!(result.len(), 4);
    }
}
468        // This should not panic - result should have length 4.
469        let result = list.take(idx).unwrap();
470        assert_eq!(result.len(), 4);
471    }
472}