// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::Primitive;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::list::ListArrayExt;
16use crate::arrays::primitive::PrimitiveArrayExt;
17use crate::builders::ArrayBuilder;
18use crate::builders::PrimitiveBuilder;
19use crate::dtype::IntegerPType;
20use crate::dtype::Nullability;
21use crate::executor::ExecutionCtx;
22use crate::match_each_integer_ptype;
23use crate::match_smallest_offset_type;
24
// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
// the `ListView::take` compute function once `ListView` is more stable.
27
28impl TakeExecute for List {
29    /// Take implementation for [`ListArray`].
30    ///
31    /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
32    /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
33    /// non-contiguous indices would violate this requirement.
34    #[expect(clippy::cognitive_complexity)]
35    fn take(
36        array: ArrayView<'_, List>,
37        indices: &ArrayRef,
38        ctx: &mut ExecutionCtx,
39    ) -> VortexResult<Option<ArrayRef>> {
40        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
41        // This is an over-approximation of the total number of elements in the resulting array.
42        let total_approx = array.elements().len().saturating_mul(indices.len());
43
44        match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
45            match_each_integer_ptype!(indices.ptype(), |I| {
46                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
47                    {
48                        let indices = indices.as_view();
49                        _take::<I, O, OutputOffsetType>(array, indices, ctx).map(Some)
50                    }
51                })
52            })
53        })
54    }
55}
56
57fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
58    array: ArrayView<'_, List>,
59    indices_array: ArrayView<'_, Primitive>,
60    ctx: &mut ExecutionCtx,
61) -> VortexResult<ArrayRef> {
62    let data_validity = array.list_validity().to_mask(array.as_ref().len(), ctx)?;
63    let indices_validity = indices_array
64        .validity()
65        .vortex_expect("Failed to compute validity mask")
66        .to_mask(indices_array.as_ref().len(), ctx)?;
67
68    if !indices_validity.all_true() || !data_validity.all_true() {
69        return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
70    }
71
72    let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
73    let offsets: &[O] = offsets_array.as_slice();
74    let indices: &[I] = indices_array.as_slice();
75
76    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
77        Nullability::NonNullable,
78        indices.len(),
79    );
80    let mut elements_to_take =
81        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
82
83    let mut current_offset = OutputOffsetType::zero();
84    new_offsets.append_zero();
85
86    for &data_idx in indices {
87        let data_idx: usize = data_idx.as_();
88
89        let start = offsets[data_idx];
90        let stop = offsets[data_idx + 1];
91
92        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
93        //
94        // We could convert start and end to usize, but that would impose a potentially
95        // harder constraint - now we don't care if they fit into usize as long as their
96        // difference does.
97        let additional: usize = (stop - start).as_();
98
99        // TODO(0ax1): optimize this
100        elements_to_take.reserve_exact(additional);
101        for i in 0..additional {
102            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
103        }
104        current_offset +=
105            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
106        new_offsets.append_value(current_offset);
107    }
108
109    let elements_to_take = elements_to_take.finish();
110    let new_offsets = new_offsets.finish();
111
112    let new_elements = array.elements().take(elements_to_take)?;
113
114    Ok(ListArray::try_new(
115        new_elements,
116        new_offsets,
117        array.validity()?.take(indices_array.array())?,
118    )?
119    .into_array())
120}
121
122fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
123    array: ArrayView<'_, List>,
124    indices_array: ArrayView<'_, Primitive>,
125    ctx: &mut ExecutionCtx,
126) -> VortexResult<ArrayRef> {
127    let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
128    let offsets: &[O] = offsets_array.as_slice();
129    let indices: &[I] = indices_array.as_slice();
130    let data_validity = array.list_validity().to_mask(array.as_ref().len(), ctx)?;
131    let indices_validity = indices_array
132        .validity()
133        .vortex_expect("Failed to compute validity mask")
134        .to_mask(indices_array.as_ref().len(), ctx)?;
135
136    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
137        Nullability::NonNullable,
138        indices.len(),
139    );
140
141    // This will be the indices we push down to the child array to call `take` with.
142    //
143    // There are 2 things to note here:
144    // - We do not know how many elements we need to take from our child since lists are variable
145    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
146    // - The type of the primitive builder needs to fit the largest offset of the (parent)
147    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
148    let mut elements_to_take =
149        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
150
151    let mut current_offset = OutputOffsetType::zero();
152    new_offsets.append_zero();
153
154    for (idx, data_idx) in indices.iter().enumerate() {
155        if !indices_validity.value(idx) {
156            new_offsets.append_value(current_offset);
157            continue;
158        }
159
160        let data_idx: usize = data_idx.as_();
161
162        if !data_validity.value(data_idx) {
163            new_offsets.append_value(current_offset);
164            continue;
165        }
166
167        let start = offsets[data_idx];
168        let stop = offsets[data_idx + 1];
169
170        // See the note in `_take` on the reasoning.
171        let additional: usize = (stop - start).as_();
172
173        elements_to_take.reserve_exact(additional);
174        for i in 0..additional {
175            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
176        }
177        current_offset +=
178            OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
179        new_offsets.append_value(current_offset);
180    }
181
182    let elements_to_take = elements_to_take.finish();
183    let new_offsets = new_offsets.finish();
184    let new_elements = array.elements().take(elements_to_take)?;
185
186    Ok(ListArray::try_new(
187        new_elements,
188        new_offsets,
189        array.validity()?.take(indices_array.array())?,
190    )?
191    .into_array())
192}
193
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// The dtype of a list of non-nullable `i32` elements with the given outer nullability.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// The list scalar expected for a row holding `values`.
    fn i32_list_scalar(values: &[i32], nullability: Nullability) -> Scalar {
        let element_dtype: Arc<DType> = Arc::new(I32.into());
        let children: Vec<Scalar> = values.iter().map(|&v| v.into()).collect();
        Scalar::list(element_dtype, children, nullability)
    }

    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();
        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let taken = list.take(idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        // `None` marks a row that must be null; `Some` holds the expected list contents.
        let expected: [Option<&[i32]>; 4] = [Some(&[0, 5]), None, Some(&[3]), Some(&[])];
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for (row, want) in expected.into_iter().enumerate() {
            match want {
                Some(values) => {
                    assert!(taken.is_valid(row, &mut ctx).unwrap());
                    assert_eq!(
                        taken.execute_scalar(row, &mut ctx).unwrap(),
                        i32_list_scalar(values, Nullability::Nullable)
                    );
                }
                None => assert!(taken.is_invalid(row, &mut ctx).unwrap()),
            }
        }
    }

    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        // Nullable indices force the taken list to become nullable as well.
        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();
        let taken = list.take(idx).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let idx = buffer![1, 0, 2].into_array();

        let taken = list.take(idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        let expected: [&[i32]; 3] = [&[3], &[0, 5], &[]];
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for (row, values) in expected.into_iter().enumerate() {
            assert!(taken.is_valid(row, &mut ctx).unwrap());
            assert_eq!(
                taken.execute_scalar(row, &mut ctx).unwrap(),
                i32_list_scalar(values, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let taken = list.take(idx).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        // Fifty variable-length lists: list `i` (1-based) holds `i % 5` elements.
        let elements = buffer![0i32..200].into_array();
        let offsets: Vec<u64> = std::iter::once(0u64)
            .chain((1..=50u64).scan(0u64, |end, i| {
                *end += i % 5;
                Some(*end)
            }))
            .collect();
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let elements = buffer![0i32; 200].into_array();
        let offsets = buffer![0u8, 200].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Taking the same 200-element list twice overflows u8 offsets; the output must widen.
        let idx = buffer![0u8, 0].into_array();
        let taken = list.take(idx).unwrap();
        assert_eq!(taken.len(), 2);

        let view = taken.to_listview();
        assert_eq!(view.len(), 2);

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for row in 0..2 {
            assert!(view.is_valid(row, &mut ctx).unwrap());
        }
    }

    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let elements = buffer![0i32; 150].into_array();
        let offsets = buffer![0u8, 150, 150].into_array();
        let validity = BoolArray::from_iter(vec![true, false]).into_array();
        let list = ListArray::try_new(elements, offsets, Validity::Array(validity))
            .unwrap()
            .into_array();

        // Taking the same 150-element list twice overflows u8 offsets; the output must widen.
        let idx = PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();
        let taken = list.take(idx).unwrap();
        assert_eq!(taken.len(), 3);

        let view = taken.to_listview();
        assert_eq!(view.len(), 3);

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(view.is_valid(0, &mut ctx).unwrap());
        assert!(view.is_invalid(1, &mut ctx).unwrap());
        assert!(view.is_valid(2, &mut ctx).unwrap());
    }

    /// Regression test for validity length mismatch bug.
    ///
    /// When the source array has `Validity::Array(...)` and the indices are non-nullable, the
    /// result validity must have length `indices.len()`, not `source.len()`.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        // Source list array of length 2 with an explicit validity array.
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();

        // More non-nullable indices (4) than source rows (2); must not panic.
        let idx = buffer![0u32, 1, 0, 1].into_array();
        let taken = list.take(idx).unwrap();

        assert_eq!(taken.len(), 4);
    }
}