vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use arrow_buffer::BooleanBufferBuilder;
use num_traits::PrimInt;
use vortex_dtype::{NativePType, Nullability, match_each_integer_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_err, vortex_panic};
use vortex_mask::Mask;

use crate::arrays::{ListArray, ListVTable, PrimitiveArray};
use crate::builders::{ArrayBuilder, PrimitiveBuilder};
use crate::compute::{TakeKernel, TakeKernelAdapter, take};
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
use crate::{Array, ArrayRef, OffsetPType, ToCanonical, register_kernel};
16
17impl TakeKernel for ListVTable {
18    fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19        let indices = indices.to_primitive();
20        let offsets = array.offsets().to_primitive();
21
22        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
23            match_each_integer_ptype!(indices.ptype(), |I| {
24                _take::<I, O>(
25                    array,
26                    offsets.as_slice::<O>(),
27                    &indices,
28                    array.validity_mask(),
29                    indices.validity_mask(),
30                )
31            })
32        })
33    }
34}
35
36register_kernel!(TakeKernelAdapter(ListVTable).lift());
37
38fn _take<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
39    array: &ListArray,
40    offsets: &[O],
41    indices_array: &PrimitiveArray,
42    data_validity: Mask,
43    indices_validity_mask: Mask,
44) -> VortexResult<ArrayRef> {
45    let indices: &[I] = indices_array.as_slice::<I>();
46
47    if !indices_validity_mask.all_true() || !data_validity.all_true() {
48        return _take_nullable::<I, O>(
49            array,
50            offsets,
51            indices,
52            data_validity,
53            indices_validity_mask,
54        );
55    }
56
57    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
58    let mut elements_to_take =
59        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
60
61    let mut current_offset = O::zero();
62    new_offsets.append_zero();
63
64    for &data_idx in indices {
65        let data_idx = data_idx
66            .to_usize()
67            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
68
69        let start = offsets[data_idx];
70        let stop = offsets[data_idx + 1];
71
72        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
73        //
74        // We could convert start and end to usize, but that would impose a potentially
75        // harder constraint - now we don't care if they fit into usize as long as their
76        // difference does.
77        let additional = (stop - start).to_usize().unwrap_or_else(|| {
78            vortex_panic!("Failed to convert range length to usize: {}", stop - start)
79        });
80
81        elements_to_take.ensure_capacity(elements_to_take.len() + additional);
82        for i in 0..additional {
83            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
84        }
85        current_offset = current_offset + (stop - start);
86        new_offsets.append_value(current_offset);
87    }
88
89    let elements_to_take = elements_to_take.finish();
90    let new_offsets = new_offsets.finish();
91
92    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
93
94    Ok(ListArray::try_new(
95        new_elements,
96        new_offsets,
97        indices_array
98            .validity()
99            .clone()
100            .and(array.validity().clone()),
101    )?
102    .to_array())
103}
104
105fn _take_nullable<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
106    array: &ListArray,
107    offsets: &[O],
108    indices: &[I],
109    data_validity: Mask,
110    indices_validity: Mask,
111) -> VortexResult<ArrayRef> {
112    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
113
114    // This will be the indices we push down to the child array to call `take` with.
115    //
116    // There are 2 things to note here:
117    // - We do not know how many elements we need to take from our child since lists are variable
118    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
119    // - The type of the primitive builder needs to fit the largest offset of the (parent)
120    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
121    let mut elements_to_take =
122        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
123
124    let mut current_offset = O::zero();
125    new_offsets.append_zero();
126
127    let mut new_validity = BooleanBufferBuilder::new(indices.len());
128
129    for (idx, data_idx) in indices.iter().enumerate() {
130        if !indices_validity.value(idx) {
131            new_offsets.append_value(current_offset);
132            new_validity.append(false);
133            continue;
134        }
135
136        let data_idx = data_idx
137            .to_usize()
138            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
139
140        if data_validity.value(data_idx) {
141            let start = offsets[data_idx];
142            let stop = offsets[data_idx + 1];
143
144            // See the note it the `take` on the reasoning
145            let additional = (stop - start).to_usize().unwrap_or_else(|| {
146                vortex_panic!("Failed to convert range length to usize: {}", stop - start)
147            });
148
149            elements_to_take.ensure_capacity(elements_to_take.len() + additional);
150            for i in 0..additional {
151                elements_to_take
152                    .append_value(start + O::from_usize(i).vortex_expect("i < additional"));
153            }
154            current_offset = current_offset + (stop - start);
155            new_offsets.append_value(current_offset);
156            new_validity.append(true);
157        } else {
158            new_offsets.append_value(current_offset);
159            new_validity.append(false);
160        }
161    }
162
163    let elements_to_take = elements_to_take.finish();
164    let new_offsets = new_offsets.finish();
165    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
166
167    let new_validity: Validity = Validity::from(new_validity.finish());
168    // data are indexes are nullable, so the final result is also nullable.
169
170    Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
171}
172
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _, ToCanonical};

    /// Dtype of a list of non-nullable `i32` elements with the given list-level nullability.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// Expected list scalar holding `values` as `i32` elements.
    fn i32_list_scalar(values: &[i32], nullability: Nullability) -> Scalar {
        let element_dtype: Arc<DType> = Arc::new(I32.into());
        let children = values
            .iter()
            .map(|&value| Scalar::from(value))
            .collect::<Vec<_>>();
        Scalar::list(element_dtype, children, nullability)
    }

    /// Null indices and null lists both become null slots; valid slots keep their contents.
    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();
        let indices =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = take(&list, &indices).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_list();
        assert_eq!(taken.len(), 4);

        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[0, 5], Nullability::Nullable)
        );

        assert!(taken.is_invalid(1));

        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[3], Nullability::Nullable)
        );

        assert!(taken.is_valid(3));
        assert_eq!(
            taken.scalar_at(3),
            i32_list_scalar(&[], Nullability::Nullable)
        );
    }

    /// A nullable index array turns a non-nullable list into a nullable result.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();
        let indices = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();

        let taken = take(&list, &indices).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Non-nullable list and indices keep a non-nullable result, including empty lists.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();
        let indices = buffer![1, 0, 2].into_array();

        let taken = take(&list, &indices).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_list();
        assert_eq!(taken.len(), 3);

        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[3], Nullability::NonNullable)
        );

        assert!(taken.is_valid(1));
        assert_eq!(
            taken.scalar_at(1),
            i32_list_scalar(&[0, 5], Nullability::NonNullable)
        );

        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[], Nullability::NonNullable)
        );
    }

    /// Taking from an empty list array with empty nullable indices yields an empty nullable list.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();
        let indices = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = take(&list, &indices).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5)); // Variable length lists
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    /// Runs the shared `take` conformance suite against a variety of list shapes.
    fn test_take_list_conformance(#[case] array: ListArray) {
        test_take_conformance(array.as_ref());
    }
}