vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBufferMut;
5use vortex_dtype::IntegerPType;
6use vortex_dtype::Nullability;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::Array;
14use crate::ArrayRef;
15use crate::ToCanonical;
16use crate::arrays::ListArray;
17use crate::arrays::ListVTable;
18use crate::arrays::PrimitiveArray;
19use crate::builders::ArrayBuilder;
20use crate::builders::PrimitiveBuilder;
21use crate::compute::TakeKernel;
22use crate::compute::TakeKernelAdapter;
23use crate::compute::take;
24use crate::register_kernel;
25use crate::validity::Validity;
26use crate::vtable::ValidityHelper;
27
28// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
29// the `ListView::take` compute function once `ListView` is more stable.
30
31/// Take implementation for [`ListArray`].
32///
33/// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
34/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
35/// non-contiguous indices would violate this requirement.
36impl TakeKernel for ListVTable {
37    fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
38        let indices = indices.to_primitive();
39        let offsets = array.offsets().to_primitive();
40
41        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
42            match_each_integer_ptype!(indices.ptype(), |I| {
43                _take::<I, O>(
44                    array,
45                    offsets.as_slice::<O>(),
46                    &indices,
47                    array.validity_mask(),
48                    indices.validity_mask(),
49                )
50            })
51        })
52    }
53}
54
55register_kernel!(TakeKernelAdapter(ListVTable).lift());
56
57fn _take<I: IntegerPType, O: IntegerPType>(
58    array: &ListArray,
59    offsets: &[O],
60    indices_array: &PrimitiveArray,
61    data_validity: Mask,
62    indices_validity_mask: Mask,
63) -> VortexResult<ArrayRef> {
64    let indices: &[I] = indices_array.as_slice::<I>();
65
66    if !indices_validity_mask.all_true() || !data_validity.all_true() {
67        return _take_nullable::<I, O>(
68            array,
69            offsets,
70            indices,
71            data_validity,
72            indices_validity_mask,
73        );
74    }
75
76    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
77    let mut elements_to_take =
78        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
79
80    let mut current_offset = O::zero();
81    new_offsets.append_zero();
82
83    for &data_idx in indices {
84        let data_idx = data_idx
85            .to_usize()
86            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
87
88        let start = offsets[data_idx];
89        let stop = offsets[data_idx + 1];
90
91        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
92        //
93        // We could convert start and end to usize, but that would impose a potentially
94        // harder constraint - now we don't care if they fit into usize as long as their
95        // difference does.
96        let additional = (stop - start).to_usize().unwrap_or_else(|| {
97            vortex_panic!("Failed to convert range length to usize: {}", stop - start)
98        });
99
100        elements_to_take.reserve_exact(additional);
101        for i in 0..additional {
102            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
103        }
104        current_offset += stop - start;
105        new_offsets.append_value(current_offset);
106    }
107
108    let elements_to_take = elements_to_take.finish();
109    let new_offsets = new_offsets.finish();
110
111    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
112
113    Ok(ListArray::try_new(
114        new_elements,
115        new_offsets,
116        indices_array
117            .validity()
118            .clone()
119            .and(array.validity().clone()),
120    )?
121    .to_array())
122}
123
124fn _take_nullable<I: IntegerPType, O: IntegerPType>(
125    array: &ListArray,
126    offsets: &[O],
127    indices: &[I],
128    data_validity: Mask,
129    indices_validity: Mask,
130) -> VortexResult<ArrayRef> {
131    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
132
133    // This will be the indices we push down to the child array to call `take` with.
134    //
135    // There are 2 things to note here:
136    // - We do not know how many elements we need to take from our child since lists are variable
137    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
138    // - The type of the primitive builder needs to fit the largest offset of the (parent)
139    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
140    let mut elements_to_take =
141        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
142
143    let mut current_offset = O::zero();
144    new_offsets.append_zero();
145
146    // Set all bits to invalid and selectively set which values are valid.
147    let mut new_validity = BitBufferMut::new_unset(indices.len());
148
149    for (idx, data_idx) in indices.iter().enumerate() {
150        if !indices_validity.value(idx) {
151            new_offsets.append_value(current_offset);
152            // Bit buffer already has this set to invalid.
153            continue;
154        }
155
156        let data_idx = data_idx
157            .to_usize()
158            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
159
160        if !data_validity.value(data_idx) {
161            new_offsets.append_value(current_offset);
162            // Bit buffer already has this set to invalid.
163            continue;
164        }
165
166        let start = offsets[data_idx];
167        let stop = offsets[data_idx + 1];
168
169        // See the note it the `take` on the reasoning
170        let additional = (stop - start).to_usize().unwrap_or_else(|| {
171            vortex_panic!("Failed to convert range length to usize: {}", stop - start)
172        });
173
174        elements_to_take.reserve_exact(additional);
175        for i in 0..additional {
176            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
177        }
178        current_offset += stop - start;
179        new_offsets.append_value(current_offset);
180        new_validity.set(idx);
181    }
182
183    let elements_to_take = elements_to_take.finish();
184    let new_offsets = new_offsets.finish();
185    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
186
187    let new_validity = Validity::from(new_validity.freeze());
188    // data are indexes are nullable, so the final result is also nullable.
189
190    Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
191}
192
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Dtype of a list of non-nullable `i32` whose outer nullability is `nullability`.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// List scalar holding `values`, with the outer nullability `nullability`.
    fn i32_list_scalar(values: &[i32], nullability: Nullability) -> Scalar {
        let element_dtype: Arc<DType> = Arc::new(I32.into());
        Scalar::list(
            element_dtype,
            values.iter().map(|&value| value.into()).collect(),
            nullability,
        )
    }

    #[test]
    fn nullable_take() {
        let validity =
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array());
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            validity,
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        // Index 0 -> first list.
        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[0, 5], Nullability::Nullable)
        );

        // Null index -> null list.
        assert!(taken.is_invalid(1));

        // Index 1 -> second list.
        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[3], Nullability::Nullable)
        );

        // Index 3 -> last list, which is valid but empty.
        assert!(taken.is_valid(3));
        assert_eq!(
            taken.scalar_at(3),
            i32_list_scalar(&[], Nullability::Nullable)
        );
    }

    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        // The indices are nullable, so the resulting list must be nullable as well.
        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = buffer![1, 0, 2].into_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[3], Nullability::NonNullable)
        );

        assert!(taken.is_valid(1));
        assert_eq!(
            taken.scalar_at(1),
            i32_list_scalar(&[0, 5], Nullability::NonNullable)
        );

        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[], Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0,);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        // 50 variable-length lists: list `i` (1-based) holds `i % 5` elements.
        let offsets: Vec<u64> = std::iter::once(0u64)
            .chain((1..=50u64).scan(0u64, |end, i| {
                *end += i % 5;
                Some(*end)
            }))
            .collect();
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }
}