vortex_array/arrays/list/compute/
take.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use arrow_buffer::BooleanBufferBuilder;
use num_traits::PrimInt;
use vortex_dtype::{NativePType, Nullability, match_each_integer_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
use vortex_mask::Mask;

use crate::arrays::{ListArray, ListVTable, OffsetPType, PrimitiveArray};
use crate::builders::{ArrayBuilder, PrimitiveBuilder};
use crate::compute::{TakeKernel, TakeKernelAdapter, take};
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
use crate::{Array, ArrayRef, ToCanonical, register_kernel};
16
17impl TakeKernel for ListVTable {
18    fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19        let indices = indices.to_primitive()?;
20        let offsets = array.offsets().to_primitive()?;
21
22        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
23            match_each_integer_ptype!(indices.ptype(), |I| {
24                _take::<I, O>(
25                    array,
26                    offsets.as_slice::<O>(),
27                    &indices,
28                    array.validity_mask()?,
29                    indices.validity_mask()?,
30                )
31            })
32        })
33    }
34}
35
36register_kernel!(TakeKernelAdapter(ListVTable).lift());
37
38fn _take<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
39    array: &ListArray,
40    offsets: &[O],
41    indices_array: &PrimitiveArray,
42    data_validity: Mask,
43    indices_validity_mask: Mask,
44) -> VortexResult<ArrayRef> {
45    let indices: &[I] = indices_array.as_slice::<I>();
46
47    if !indices_validity_mask.all_true() || !data_validity.all_true() {
48        return _take_nullable::<I, O>(
49            array,
50            offsets,
51            indices,
52            data_validity,
53            indices_validity_mask,
54        );
55    }
56
57    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
58    let mut elements_to_take =
59        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
60
61    let mut current_offset = O::zero();
62    new_offsets.append_zero();
63
64    for &data_idx in indices {
65        let data_idx = data_idx
66            .to_usize()
67            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
68
69        let start = offsets[data_idx];
70        let stop = offsets[data_idx + 1];
71
72        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
73        //
74        // We could convert start and end to usize, but that would impose a potentially
75        // harder constraint - now we don't care if they fit into usize as long as their
76        // difference does.
77        let additional = (stop - start).to_usize().unwrap_or_else(|| {
78            vortex_panic!("Failed to convert range length to usize: {}", stop - start)
79        });
80
81        elements_to_take.ensure_capacity(elements_to_take.len() + additional);
82        for i in 0..additional {
83            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
84        }
85        current_offset = current_offset + (stop - start);
86        new_offsets.append_value(current_offset);
87    }
88
89    let elements_to_take = elements_to_take.finish();
90    let new_offsets = new_offsets.finish();
91
92    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
93
94    Ok(ListArray::try_new(
95        new_elements,
96        new_offsets,
97        indices_array
98            .validity()
99            .clone()
100            .and(array.validity().clone())?,
101    )?
102    .to_array())
103}
104
105fn _take_nullable<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
106    array: &ListArray,
107    offsets: &[O],
108    indices: &[I],
109    data_validity: Mask,
110    indices_validity: Mask,
111) -> VortexResult<ArrayRef> {
112    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
113    let mut elements_to_take =
114        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
115
116    let mut current_offset = O::zero();
117    new_offsets.append_zero();
118    let mut new_validity = BooleanBufferBuilder::new(2 * indices.len());
119
120    for (idx, data_idx) in indices.iter().enumerate() {
121        if !indices_validity.value(idx) {
122            new_offsets.append_value(current_offset);
123            new_validity.append(false);
124            continue;
125        }
126
127        let data_idx = data_idx
128            .to_usize()
129            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
130
131        if data_validity.value(data_idx) {
132            let start = offsets[data_idx];
133            let stop = offsets[data_idx + 1];
134
135            // See the note it the `take` on the reasoning
136            let additional = (stop - start).to_usize().unwrap_or_else(|| {
137                vortex_panic!("Failed to convert range length to usize: {}", stop - start)
138            });
139
140            elements_to_take.ensure_capacity(elements_to_take.len() + additional);
141            for i in 0..additional {
142                elements_to_take
143                    .append_value(start + O::from_usize(i).vortex_expect("i < additional"));
144            }
145            current_offset = current_offset + (stop - start);
146            new_offsets.append_value(current_offset);
147            new_validity.append(true);
148        } else {
149            new_offsets.append_value(current_offset);
150            new_validity.append(false);
151        }
152    }
153
154    let elements_to_take = elements_to_take.finish();
155    let new_offsets = new_offsets.finish();
156    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
157
158    let new_validity: Validity = Validity::from(new_validity.finish());
159    // data are indexes are nullable, so the final result is also nullable.
160
161    Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
162}
163
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3, 4, 4]).to_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_list().unwrap();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_invalid(1).unwrap());

        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(3).unwrap());
        assert_eq!(
            result.scalar_at(3),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        // Since idx is nullable, the final list will also be nullable.

        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_list().unwrap();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(1).unwrap());
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(element_dtype, vec![3i32.into()], Nullability::Nullable)
        );

        // A null index yields a null list.
        assert!(result.is_invalid(2).unwrap());
    }

    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3, 3, 4]).to_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_iter([1, 0, 2]).to_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_list().unwrap();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(1).unwrap());
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    #[test]
    fn repeated_indices_take() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        // Taking the same list more than once duplicates its elements in the output.
        let idx = PrimitiveArray::from_iter([0, 0, 1]).to_array();

        let result = take(&list, &idx).unwrap().to_list().unwrap();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());
        let first = Scalar::list(
            element_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::NonNullable,
        );

        assert_eq!(result.scalar_at(0), first);
        assert_eq!(result.scalar_at(1), first);
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(element_dtype, vec![3i32.into()], Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0]).to_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0);
    }

    #[rstest]
    #[case(ListArray::try_new(
        PrimitiveArray::from_iter([0i32, 1, 2, 3, 4, 5]).to_array(),
        PrimitiveArray::from_iter([0, 2, 3, 5, 5, 6]).to_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        PrimitiveArray::from_iter([10i32, 20, 30, 40, 50]).to_array(),
        PrimitiveArray::from_iter([0, 2, 3, 4, 5]).to_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        PrimitiveArray::from_iter([1i32, 2, 3]).to_array(),
        PrimitiveArray::from_iter([0, 0, 2, 2, 3]).to_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        PrimitiveArray::from_iter([42i32, 43]).to_array(),
        PrimitiveArray::from_iter([0, 2]).to_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = PrimitiveArray::from_iter(0i32..200).to_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5)); // Variable length lists
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        PrimitiveArray::from_iter([0, 2, 3, 5]).to_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }
}
380}