vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBufferMut;
5use vortex_dtype::{IntegerPType, Nullability, match_each_integer_ptype};
6use vortex_error::{VortexExpect, VortexResult, vortex_panic};
7use vortex_mask::Mask;
8
9use crate::arrays::{ListArray, ListVTable, PrimitiveArray};
10use crate::builders::{ArrayBuilder, PrimitiveBuilder};
11use crate::compute::{TakeKernel, TakeKernelAdapter, take};
12use crate::validity::Validity;
13use crate::vtable::ValidityHelper;
14use crate::{Array, ArrayRef, ToCanonical, register_kernel};
15
16// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
17// the `ListView::take` compute function once `ListView` is more stable.
18
19/// Take implementation for [`ListArray`].
20///
21/// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
22/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
23/// non-contiguous indices would violate this requirement.
24impl TakeKernel for ListVTable {
25    fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
26        let indices = indices.to_primitive();
27        let offsets = array.offsets().to_primitive();
28
29        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
30            match_each_integer_ptype!(indices.ptype(), |I| {
31                _take::<I, O>(
32                    array,
33                    offsets.as_slice::<O>(),
34                    &indices,
35                    array.validity_mask(),
36                    indices.validity_mask(),
37                )
38            })
39        })
40    }
41}
42
43register_kernel!(TakeKernelAdapter(ListVTable).lift());
44
45fn _take<I: IntegerPType, O: IntegerPType>(
46    array: &ListArray,
47    offsets: &[O],
48    indices_array: &PrimitiveArray,
49    data_validity: Mask,
50    indices_validity_mask: Mask,
51) -> VortexResult<ArrayRef> {
52    let indices: &[I] = indices_array.as_slice::<I>();
53
54    if !indices_validity_mask.all_true() || !data_validity.all_true() {
55        return _take_nullable::<I, O>(
56            array,
57            offsets,
58            indices,
59            data_validity,
60            indices_validity_mask,
61        );
62    }
63
64    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
65    let mut elements_to_take =
66        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
67
68    let mut current_offset = O::zero();
69    new_offsets.append_zero();
70
71    for &data_idx in indices {
72        let data_idx = data_idx
73            .to_usize()
74            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
75
76        let start = offsets[data_idx];
77        let stop = offsets[data_idx + 1];
78
79        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
80        //
81        // We could convert start and end to usize, but that would impose a potentially
82        // harder constraint - now we don't care if they fit into usize as long as their
83        // difference does.
84        let additional = (stop - start).to_usize().unwrap_or_else(|| {
85            vortex_panic!("Failed to convert range length to usize: {}", stop - start)
86        });
87
88        elements_to_take.reserve_exact(additional);
89        for i in 0..additional {
90            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
91        }
92        current_offset += stop - start;
93        new_offsets.append_value(current_offset);
94    }
95
96    let elements_to_take = elements_to_take.finish();
97    let new_offsets = new_offsets.finish();
98
99    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
100
101    Ok(ListArray::try_new(
102        new_elements,
103        new_offsets,
104        indices_array
105            .validity()
106            .clone()
107            .and(array.validity().clone()),
108    )?
109    .to_array())
110}
111
112fn _take_nullable<I: IntegerPType, O: IntegerPType>(
113    array: &ListArray,
114    offsets: &[O],
115    indices: &[I],
116    data_validity: Mask,
117    indices_validity: Mask,
118) -> VortexResult<ArrayRef> {
119    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
120
121    // This will be the indices we push down to the child array to call `take` with.
122    //
123    // There are 2 things to note here:
124    // - We do not know how many elements we need to take from our child since lists are variable
125    //   size: thus we arbitrarily choose a capacity of `2 * # of indices`.
126    // - The type of the primitive builder needs to fit the largest offset of the (parent)
127    //   `ListArray`, so we make this `PrimitiveBuilder` generic over `O` (instead of `I`).
128    let mut elements_to_take =
129        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
130
131    let mut current_offset = O::zero();
132    new_offsets.append_zero();
133
134    // Set all bits to invalid and selectively set which values are valid.
135    let mut new_validity = BitBufferMut::new_unset(indices.len());
136
137    for (idx, data_idx) in indices.iter().enumerate() {
138        if !indices_validity.value(idx) {
139            new_offsets.append_value(current_offset);
140            // Bit buffer already has this set to invalid.
141            continue;
142        }
143
144        let data_idx = data_idx
145            .to_usize()
146            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
147
148        if !data_validity.value(data_idx) {
149            new_offsets.append_value(current_offset);
150            // Bit buffer already has this set to invalid.
151            continue;
152        }
153
154        let start = offsets[data_idx];
155        let stop = offsets[data_idx + 1];
156
157        // See the note it the `take` on the reasoning
158        let additional = (stop - start).to_usize().unwrap_or_else(|| {
159            vortex_panic!("Failed to convert range length to usize: {}", stop - start)
160        });
161
162        elements_to_take.reserve_exact(additional);
163        for i in 0..additional {
164            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
165        }
166        current_offset += stop - start;
167        new_offsets.append_value(current_offset);
168        new_validity.set(idx);
169    }
170
171    let elements_to_take = elements_to_take.finish();
172    let new_offsets = new_offsets.finish();
173    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
174
175    let new_validity = Validity::from(new_validity.freeze());
176    // data are indexes are nullable, so the final result is also nullable.
177
178    Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
179}
180
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _, ToCanonical};

    /// The dtype of a list of non-nullable `i32` elements whose list nullability is `nullability`.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// A list scalar over `i32` elements holding `values`, with the given list nullability.
    fn i32_list_scalar(values: &[i32], nullability: Nullability) -> Scalar {
        let element_dtype: Arc<DType> = Arc::new(I32.into());
        let children: Vec<Scalar> = values.iter().map(|&value| value.into()).collect();
        Scalar::list(element_dtype, children, nullability)
    }

    /// Taking with null indices from a list array that itself has a null entry.
    #[test]
    fn nullable_take() {
        let list_validity =
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array());
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            list_validity,
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        // Index 0 -> the first list `[0, 5]`.
        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[0, 5], Nullability::Nullable)
        );

        // Null index -> null list.
        assert!(taken.is_invalid(1));

        // Index 1 -> the second list `[3]`.
        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[3], Nullability::Nullable)
        );

        // Index 3 -> the last list, which is valid but empty.
        assert!(taken.is_valid(3));
        assert_eq!(
            taken.scalar_at(3),
            i32_list_scalar(&[], Nullability::Nullable)
        );
    }

    /// Nullable indices turn a non-nullable list array into a nullable result.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        // The indices are nullable, so the taken list array must be nullable as well.
        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Taking with non-nullable indices from a non-nullable list array, out of order.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = buffer![1, 0, 2].into_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        // Index 1 -> `[3]`.
        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[3], Nullability::NonNullable)
        );

        // Index 0 -> `[0, 5]`.
        assert!(taken.is_valid(1));
        assert_eq!(
            taken.scalar_at(1),
            i32_list_scalar(&[0, 5], Nullability::NonNullable)
        );

        // Index 2 -> the empty list.
        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[], Nullability::NonNullable)
        );
    }

    /// Taking zero (nullable) indices yields an empty, nullable list array.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0,);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        // Fifty lists whose lengths cycle through 1, 2, 3, 4, 0.
        let mut offsets = vec![0u64];
        let mut end = 0u64;
        for i in 1..=50u64 {
            end += i % 5;
            offsets.push(end);
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }
}