Skip to main content

vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::Primitive;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::list::ListArrayExt;
16use crate::arrays::primitive::PrimitiveArrayExt;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::dtype::IntegerPType;
20use crate::dtype::Nullability;
21use crate::executor::ExecutionCtx;
22use crate::match_each_integer_ptype;
23use crate::match_smallest_offset_type;
24
25// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
26// the `ListView::take` compute function once `ListView` is more stable.
27
28impl TakeExecute for List {
29    /// Take implementation for [`ListArray`].
30    ///
31    /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
32    /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
33    /// non-contiguous indices would violate this requirement.
34    #[expect(clippy::cognitive_complexity)]
35    fn take(
36        array: ArrayView<'_, List>,
37        indices: &ArrayRef,
38        ctx: &mut ExecutionCtx,
39    ) -> VortexResult<Option<ArrayRef>> {
40        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
41        // This is an over-approximation of the total number of elements in the resulting array.
42        let total_approx = array.elements().len().saturating_mul(indices.len());
43
44        match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
45            match_each_integer_ptype!(indices.ptype(), |I| {
46                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
47                    {
48                        let indices = indices.as_view();
49                        _take::<I, O, OutputOffsetType>(array, indices, ctx).map(Some)
50                    }
51                })
52            })
53        })
54    }
55}
56
57fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
58    array: ArrayView<'_, List>,
59    indices_array: ArrayView<'_, Primitive>,
60    ctx: &mut ExecutionCtx,
61) -> VortexResult<ArrayRef> {
62    let data_validity = array
63        .list_validity()
64        .execute_mask(array.as_ref().len(), ctx)?;
65    let indices_validity = indices_array
66        .validity()
67        .vortex_expect("Failed to compute validity mask")
68        .execute_mask(indices_array.as_ref().len(), ctx)?;
69
70    if !indices_validity.all_true() || !data_validity.all_true() {
71        return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
72    }
73
74    let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
75    let offsets: &[O] = offsets_array.as_slice();
76    let indices: &[I] = indices_array.as_slice();
77
78    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
79        Nullability::NonNullable,
80        indices.len(),
81    );
82    let mut elements_to_take =
83        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
84
85    let mut current_offset = OutputOffsetType::zero();
86    new_offsets.append_zero();
87
88    for &data_idx in indices {
89        let data_idx: usize = data_idx.as_();
90
91        let start = offsets[data_idx];
92        let stop = offsets[data_idx + 1];
93
94        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
95        //
96        // We could convert start and end to usize, but that would impose a potentially
97        // harder constraint - now we don't care if they fit into usize as long as their
98        // difference does.
99        let additional: usize = (stop - start).as_();
100
101        // TODO(0ax1): optimize this
102        elements_to_take.reserve_exact(additional);
103        for i in 0..additional {
104            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
105        }
106        current_offset +=
107            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
108        new_offsets.append_value(current_offset);
109    }
110
111    let elements_to_take = elements_to_take.finish();
112    let new_offsets = new_offsets.finish();
113
114    let new_elements = array.elements().take(elements_to_take)?;
115
116    Ok(ListArray::try_new(
117        new_elements,
118        new_offsets,
119        array.validity()?.take(indices_array.array())?,
120    )?
121    .into_array())
122}
123
124fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
125    array: ArrayView<'_, List>,
126    indices_array: ArrayView<'_, Primitive>,
127    ctx: &mut ExecutionCtx,
128) -> VortexResult<ArrayRef> {
129    let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
130    let offsets: &[O] = offsets_array.as_slice();
131    let indices: &[I] = indices_array.as_slice();
132    let data_validity = array
133        .list_validity()
134        .execute_mask(array.as_ref().len(), ctx)?;
135    let indices_validity = indices_array
136        .validity()
137        .vortex_expect("Failed to compute validity mask")
138        .execute_mask(indices_array.as_ref().len(), ctx)?;
139
140    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
141        Nullability::NonNullable,
142        indices.len(),
143    );
144
145    // This will be the indices we push down to the child array to call `take` with.
146    //
147    // There are 2 things to note here:
148    // - We do not know how many elements we need to take from our child since lists are variable
149    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
150    // - The type of the primitive builder needs to fit the largest offset of the (parent)
151    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
152    let mut elements_to_take =
153        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
154
155    let mut current_offset = OutputOffsetType::zero();
156    new_offsets.append_zero();
157
158    for (idx, data_idx) in indices.iter().enumerate() {
159        if !indices_validity.value(idx) {
160            new_offsets.append_value(current_offset);
161            continue;
162        }
163
164        let data_idx: usize = data_idx.as_();
165
166        if !data_validity.value(data_idx) {
167            new_offsets.append_value(current_offset);
168            continue;
169        }
170
171        let start = offsets[data_idx];
172        let stop = offsets[data_idx + 1];
173
174        // See the note in `_take` on the reasoning.
175        let additional: usize = (stop - start).as_();
176
177        elements_to_take.reserve_exact(additional);
178        for i in 0..additional {
179            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
180        }
181        current_offset +=
182            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
183        new_offsets.append_value(current_offset);
184    }
185
186    let elements_to_take = elements_to_take.finish();
187    let new_offsets = new_offsets.finish();
188    let new_elements = array.elements().take(elements_to_take)?;
189
190    Ok(ListArray::try_new(
191        new_elements,
192        new_offsets,
193        array.validity()?.take(indices_array.array())?,
194    )?
195    .into_array())
196}
197
198#[cfg(test)]
199mod test {
200    use std::sync::Arc;
201
202    use rstest::rstest;
203    use vortex_buffer::buffer;
204
205    use crate::IntoArray as _;
206    use crate::LEGACY_SESSION;
207    #[expect(deprecated)]
208    use crate::ToCanonical as _;
209    use crate::VortexSessionExecute;
210    use crate::arrays::BoolArray;
211    use crate::arrays::ListArray;
212    use crate::arrays::PrimitiveArray;
213    use crate::compute::conformance::take::test_take_conformance;
214    use crate::dtype::DType;
215    use crate::dtype::Nullability;
216    use crate::dtype::PType::I32;
217    use crate::scalar::Scalar;
218    use crate::validity::Validity;
219
220    #[test]
221    fn nullable_take() {
222        let list = ListArray::try_new(
223            buffer![0i32, 5, 3, 4].into_array(),
224            buffer![0, 2, 3, 4, 4].into_array(),
225            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
226        )
227        .unwrap()
228        .into_array();
229
230        let idx =
231            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();
232
233        let result = list.take(idx).unwrap();
234
235        assert_eq!(
236            result.dtype(),
237            &DType::List(
238                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
239                Nullability::Nullable
240            )
241        );
242
243        #[expect(deprecated)]
244        let result = result.to_listview();
245
246        assert_eq!(result.len(), 4);
247
248        let element_dtype: Arc<DType> = Arc::new(I32.into());
249
250        assert!(
251            result
252                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
253                .unwrap()
254        );
255        assert_eq!(
256            result
257                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
258                .unwrap(),
259            Scalar::list(
260                Arc::clone(&element_dtype),
261                vec![0i32.into(), 5.into()],
262                Nullability::Nullable
263            )
264        );
265
266        assert!(
267            result
268                .is_invalid(1, &mut LEGACY_SESSION.create_execution_ctx())
269                .unwrap()
270        );
271
272        assert!(
273            result
274                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
275                .unwrap()
276        );
277        assert_eq!(
278            result
279                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
280                .unwrap(),
281            Scalar::list(
282                Arc::clone(&element_dtype),
283                vec![3i32.into()],
284                Nullability::Nullable
285            )
286        );
287
288        assert!(
289            result
290                .is_valid(3, &mut LEGACY_SESSION.create_execution_ctx())
291                .unwrap()
292        );
293        assert_eq!(
294            result
295                .execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())
296                .unwrap(),
297            Scalar::list(element_dtype, vec![], Nullability::Nullable)
298        );
299    }
300
301    #[test]
302    fn change_validity() {
303        let list = ListArray::try_new(
304            buffer![0i32, 5, 3, 4].into_array(),
305            buffer![0, 2, 3].into_array(),
306            Validity::NonNullable,
307        )
308        .unwrap()
309        .into_array();
310
311        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
312        // since idx is nullable, the final list will also be nullable
313
314        let result = list.take(idx).unwrap();
315        assert_eq!(
316            result.dtype(),
317            &DType::List(
318                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
319                Nullability::Nullable
320            )
321        );
322    }
323
324    #[test]
325    fn non_nullable_take() {
326        let list = ListArray::try_new(
327            buffer![0i32, 5, 3, 4].into_array(),
328            buffer![0, 2, 3, 3, 4].into_array(),
329            Validity::NonNullable,
330        )
331        .unwrap()
332        .into_array();
333
334        let idx = buffer![1, 0, 2].into_array();
335
336        let result = list.take(idx).unwrap();
337
338        assert_eq!(
339            result.dtype(),
340            &DType::List(
341                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
342                Nullability::NonNullable
343            )
344        );
345
346        #[expect(deprecated)]
347        let result = result.to_listview();
348
349        assert_eq!(result.len(), 3);
350
351        let element_dtype: Arc<DType> = Arc::new(I32.into());
352
353        assert!(
354            result
355                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
356                .unwrap()
357        );
358        assert_eq!(
359            result
360                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
361                .unwrap(),
362            Scalar::list(
363                Arc::clone(&element_dtype),
364                vec![3i32.into()],
365                Nullability::NonNullable
366            )
367        );
368
369        assert!(
370            result
371                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
372                .unwrap()
373        );
374        assert_eq!(
375            result
376                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
377                .unwrap(),
378            Scalar::list(
379                Arc::clone(&element_dtype),
380                vec![0i32.into(), 5.into()],
381                Nullability::NonNullable
382            )
383        );
384
385        assert!(
386            result
387                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
388                .unwrap()
389        );
390        assert_eq!(
391            result
392                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
393                .unwrap(),
394            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
395        );
396    }
397
398    #[test]
399    fn test_take_empty_array() {
400        let list = ListArray::try_new(
401            buffer![0i32, 5, 3, 4].into_array(),
402            buffer![0].into_array(),
403            Validity::NonNullable,
404        )
405        .unwrap()
406        .into_array();
407
408        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();
409
410        let result = list.take(idx).unwrap();
411        assert_eq!(
412            result.dtype(),
413            &DType::List(
414                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
415                Nullability::Nullable
416            )
417        );
418        assert_eq!(result.len(), 0,);
419    }
420
    /// Runs the shared take conformance suite (`test_take_conformance`) against several
    /// differently shaped list arrays, one per `#[case]`.
    #[rstest]
    // Non-nullable lists of varying length, including one empty list (offsets 5..5).
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Lists with a validity array in which the second list is marked invalid.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    // Non-nullable lists interleaved with empty lists.
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    // An array holding exactly one list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // 50 lists with `u64` offsets, where list `i` (1-based) has `i % 5` elements.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5)); // Variable length lists
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists whose elements child itself contains nulls.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }
462
463    #[test]
464    fn test_u64_offset_accumulation_non_nullable() {
465        let elements = buffer![0i32; 200].into_array();
466        let offsets = buffer![0u8, 200].into_array();
467        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
468            .unwrap()
469            .into_array();
470
471        // Take the same large list twice - would overflow u8 but works with u64.
472        let idx = buffer![0u8, 0].into_array();
473        let result = list.take(idx).unwrap();
474
475        assert_eq!(result.len(), 2);
476
477        #[expect(deprecated)]
478        let result_view = result.to_listview();
479        assert_eq!(result_view.len(), 2);
480        assert!(
481            result_view
482                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
483                .unwrap()
484        );
485        assert!(
486            result_view
487                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
488                .unwrap()
489        );
490    }
491
492    #[test]
493    fn test_u64_offset_accumulation_nullable() {
494        let elements = buffer![0i32; 150].into_array();
495        let offsets = buffer![0u8, 150, 150].into_array();
496        let validity = BoolArray::from_iter(vec![true, false]).into_array();
497        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
498            .unwrap()
499            .into_array();
500
501        // Take the same large list twice - would overflow u8 but works with u64.
502        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
503        let result = list.take(idx).unwrap();
504
505        assert_eq!(result.len(), 3);
506
507        #[expect(deprecated)]
508        let result_view = result.to_listview();
509        assert_eq!(result_view.len(), 3);
510        assert!(
511            result_view
512                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
513                .unwrap()
514        );
515        assert!(
516            result_view
517                .is_invalid(1, &mut LEGACY_SESSION.create_execution_ctx())
518                .unwrap()
519        );
520        assert!(
521            result_view
522                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
523                .unwrap()
524        );
525    }
526
527    /// Regression test for validity length mismatch bug.
528    ///
529    /// When source array has `Validity::Array(...)` and indices are non-nullable,
530    /// the result validity must have length equal to indices.len(), not source.len().
531    #[test]
532    fn test_take_validity_length_mismatch_regression() {
533        // Source array with explicit validity array (length 2).
534        let list = ListArray::try_new(
535            buffer![1i32, 2, 3, 4].into_array(),
536            buffer![0, 2, 4].into_array(),
537            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
538        )
539        .unwrap()
540        .into_array();
541
542        // Take more indices than source length (4 vs 2) with non-nullable indices.
543        let idx = buffer![0u32, 1, 0, 1].into_array();
544
545        // This should not panic - result should have length 4.
546        let result = list.take(idx).unwrap();
547        assert_eq!(result.len(), 4);
548    }
549}