Skip to main content

vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::Primitive;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::list::ListArrayExt;
16use crate::arrays::primitive::PrimitiveArrayExt;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::dtype::IntegerPType;
20use crate::dtype::Nullability;
21use crate::executor::ExecutionCtx;
22use crate::match_each_unsigned_integer_ptype;
23use crate::match_smallest_offset_type;
24
25// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
26// the `ListView::take` compute function once `ListView` is more stable.
27
28impl TakeExecute for List {
29    /// Take implementation for [`ListArray`].
30    ///
31    /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
32    /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
33    /// non-contiguous indices would violate this requirement.
34    #[expect(clippy::cognitive_complexity)]
35    fn take(
36        array: ArrayView<'_, List>,
37        indices: &ArrayRef,
38        ctx: &mut ExecutionCtx,
39    ) -> VortexResult<Option<ArrayRef>> {
40        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
41        let indices = indices.reinterpret_cast(indices.ptype().to_unsigned());
42        let offsets = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
43        let offsets = offsets.reinterpret_cast(offsets.ptype().to_unsigned());
44        // This is an over-approximation of the total number of elements in the resulting array.
45        let total_approx = array.elements().len().saturating_mul(indices.len());
46
47        match_each_unsigned_integer_ptype!(offsets.ptype(), |O| {
48            match_each_unsigned_integer_ptype!(indices.ptype(), |I| {
49                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
50                    _take::<I, O, OutputOffsetType>(
51                        array,
52                        offsets.as_view(),
53                        indices.as_view(),
54                        ctx,
55                    )
56                    .map(Some)
57                })
58            })
59        })
60    }
61}
62
63fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
64    array: ArrayView<'_, List>,
65    offsets_array: ArrayView<'_, Primitive>,
66    indices_array: ArrayView<'_, Primitive>,
67    ctx: &mut ExecutionCtx,
68) -> VortexResult<ArrayRef> {
69    let data_validity = array
70        .list_validity()
71        .execute_mask(array.as_ref().len(), ctx)?;
72    let indices_validity = indices_array
73        .validity()
74        .vortex_expect("Failed to compute validity mask")
75        .execute_mask(indices_array.as_ref().len(), ctx)?;
76
77    if !indices_validity.all_true() || !data_validity.all_true() {
78        return _take_nullable::<I, O, OutputOffsetType>(array, offsets_array, indices_array, ctx);
79    }
80
81    let offsets: &[O] = offsets_array.as_slice();
82    let indices: &[I] = indices_array.as_slice();
83
84    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
85        Nullability::NonNullable,
86        indices.len(),
87    );
88    let mut elements_to_take =
89        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
90
91    let mut current_offset = OutputOffsetType::zero();
92    new_offsets.append_zero();
93
94    for &data_idx in indices {
95        let data_idx: usize = data_idx.as_();
96
97        let start = offsets[data_idx];
98        let stop = offsets[data_idx + 1];
99
100        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
101        //
102        // We could convert start and end to usize, but that would impose a potentially
103        // harder constraint - now we don't care if they fit into usize as long as their
104        // difference does.
105        let additional: usize = (stop - start).as_();
106
107        // TODO(0ax1): optimize this
108        elements_to_take.reserve_exact(additional);
109        for i in 0..additional {
110            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
111        }
112        current_offset +=
113            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
114        new_offsets.append_value(current_offset);
115    }
116
117    let elements_to_take = elements_to_take.finish();
118    let new_offsets = new_offsets.finish();
119
120    let new_elements = array.elements().take(elements_to_take)?;
121
122    Ok(ListArray::try_new(
123        new_elements,
124        new_offsets,
125        array.validity()?.take(indices_array.array())?,
126    )?
127    .into_array())
128}
129
130// Kept out-of-line: as a single-callsite generic helper it would otherwise be inlined into every
131// monomorphization of `_take`, duplicating the entire nullable path across all specializations.
132#[inline(never)]
133fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
134    array: ArrayView<'_, List>,
135    offsets_array: ArrayView<'_, Primitive>,
136    indices_array: ArrayView<'_, Primitive>,
137    ctx: &mut ExecutionCtx,
138) -> VortexResult<ArrayRef> {
139    let offsets: &[O] = offsets_array.as_slice();
140    let indices: &[I] = indices_array.as_slice();
141    let data_validity = array
142        .list_validity()
143        .execute_mask(array.as_ref().len(), ctx)?;
144    let indices_validity = indices_array
145        .validity()
146        .vortex_expect("Failed to compute validity mask")
147        .execute_mask(indices_array.as_ref().len(), ctx)?;
148
149    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
150        Nullability::NonNullable,
151        indices.len(),
152    );
153
154    // This will be the indices we push down to the child array to call `take` with.
155    //
156    // There are 2 things to note here:
157    // - We do not know how many elements we need to take from our child since lists are variable
158    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
159    // - The type of the primitive builder needs to fit the largest offset of the (parent)
160    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
161    let mut elements_to_take =
162        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
163
164    let mut current_offset = OutputOffsetType::zero();
165    new_offsets.append_zero();
166
167    for (idx, data_idx) in indices.iter().enumerate() {
168        if !indices_validity.value(idx) {
169            new_offsets.append_value(current_offset);
170            continue;
171        }
172
173        let data_idx: usize = data_idx.as_();
174
175        if !data_validity.value(data_idx) {
176            new_offsets.append_value(current_offset);
177            continue;
178        }
179
180        let start = offsets[data_idx];
181        let stop = offsets[data_idx + 1];
182
183        // See the note in `_take` on the reasoning.
184        let additional: usize = (stop - start).as_();
185
186        elements_to_take.reserve_exact(additional);
187        for i in 0..additional {
188            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
189        }
190        current_offset +=
191            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
192        new_offsets.append_value(current_offset);
193    }
194
195    let elements_to_take = elements_to_take.finish();
196    let new_offsets = new_offsets.finish();
197    let new_elements = array.elements().take(elements_to_take)?;
198
199    Ok(ListArray::try_new(
200        new_elements,
201        new_offsets,
202        array.validity()?.take(indices_array.array())?,
203    )?
204    .into_array())
205}
206
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(
            result
                .is_invalid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );

        assert!(
            result
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(
            result
                .is_valid(3, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
        // since idx is nullable, the final list will also be nullable

        let result = list.take(idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
    }

    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(
            result
                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(
            result
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let result = list.take(idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0,);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5)); // Variable length lists
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Take the same large list twice - would overflow u8 but works with u64.
        let idx = buffer![0u8, 0].into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 2);

        #[expect(deprecated)]
        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 2);
        assert!(
            result_view
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
    }

    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        let validity = BoolArray::from_iter(vec![true, false]).into_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .into_array();

        // Take the same large list twice - would overflow u8 but works with u64.
        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(result.len(), 3);

        #[expect(deprecated)]
        let result_view = result.to_listview();
        assert_eq!(result_view.len(), 3);
        assert!(
            result_view
                .is_valid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_invalid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result_view
                .is_valid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
    }

    /// Regression test for validity length mismatch bug.
    ///
    /// When source array has `Validity::Array(...)` and indices are non-nullable,
    /// the result validity must have length equal to indices.len(), not source.len().
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        // Source array with explicit validity array (length 2).
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();

        // Take more indices than source length (4 vs 2) with non-nullable indices.
        let idx = buffer![0u32, 1, 0, 1].into_array();

        // This should not panic - result should have length 4.
        let result = list.take(idx).unwrap();
        assert_eq!(result.len(), 4);
    }

    /// A null list reached through a *non-null* index must come out null, and its (non-empty)
    /// backing elements must not leak into the other taken lists. Indices are repeated and
    /// out of order to exercise offset accumulation across skipped slots.
    #[test]
    fn take_null_list_with_valid_index() {
        // Lists: [1, 2], null (backed by [3, 4]), [5].
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5].into_array(),
            buffer![0, 2, 4, 5].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();

        let idx = buffer![1u32, 2, 1, 0].into_array();
        let result = list.take(idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        #[expect(deprecated)]
        let result = result.to_listview();
        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(
            result
                .is_invalid(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert!(
            result
                .is_invalid(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );

        assert!(
            result
                .is_valid(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![5i32.into()],
                Nullability::Nullable
            )
        );

        assert!(
            result
                .is_valid(3, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
        );
        assert_eq!(
            result
                .execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::list(
                element_dtype,
                vec![1i32.into(), 2.into()],
                Nullability::Nullable
            )
        );
    }
}