vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBufferBuilder;
5use num_traits::PrimInt;
6use vortex_dtype::{NativePType, Nullability, match_each_integer_ptype};
7use vortex_error::{VortexExpect, VortexResult, vortex_panic};
8use vortex_mask::Mask;
9
10use crate::arrays::{ListArray, ListVTable, OffsetPType, PrimitiveArray};
11use crate::builders::{ArrayBuilder, PrimitiveBuilder};
12use crate::compute::{TakeKernel, TakeKernelAdapter, take};
13use crate::validity::Validity;
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl TakeKernel for ListVTable {
18    fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19        let indices = indices.to_primitive()?;
20        let offsets = array.offsets().to_primitive()?;
21
22        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
23            match_each_integer_ptype!(indices.ptype(), |I| {
24                Ok(_take::<I, O>(
25                    array,
26                    offsets.as_slice::<O>(),
27                    &indices,
28                    array.validity_mask()?,
29                    indices.validity_mask()?,
30                )?
31                .into_array())
32            })
33        })
34    }
35}
36
37register_kernel!(TakeKernelAdapter(ListVTable).lift());
38
39fn _take<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
40    array: &ListArray,
41    offsets: &[O],
42    indices_array: &PrimitiveArray,
43    data_validity: Mask,
44    indices_validity_mask: Mask,
45) -> VortexResult<ArrayRef> {
46    let indices: &[I] = indices_array.as_slice::<I>();
47
48    if !indices_validity_mask.all_true() || !data_validity.all_true() {
49        return _take_nullable::<I, O>(
50            array,
51            offsets,
52            indices,
53            data_validity,
54            indices_validity_mask,
55        );
56    }
57
58    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
59    let mut elements_to_take =
60        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
61
62    let mut current_offset = O::zero();
63    new_offsets.append_zero();
64
65    for &data_idx in indices {
66        let data_idx = data_idx
67            .to_usize()
68            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
69
70        let start = offsets[data_idx];
71        let stop = offsets[data_idx + 1];
72
73        // Annoyingly, we can't turn (start..end) into a range, so we're doing that manually.
74        //
75        // We could convert start and end to usize, but that would impose a potentially
76        // harder constraint - now we don't care if they fit into usize as long as their
77        // difference does.
78        let additional = (stop - start).to_usize().unwrap_or_else(|| {
79            vortex_panic!("Failed to convert range length to usize: {}", stop - start)
80        });
81
82        elements_to_take.ensure_capacity(elements_to_take.len() + additional);
83        for i in 0..additional {
84            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
85        }
86        current_offset = current_offset + (stop - start);
87        new_offsets.append_value(current_offset);
88    }
89
90    let elements_to_take = elements_to_take.finish();
91    let new_offsets = new_offsets.finish();
92
93    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
94
95    Ok(ListArray::try_new(
96        new_elements,
97        new_offsets,
98        indices_array
99            .validity()
100            .clone()
101            .and(array.validity().clone())?,
102    )?
103    .to_array())
104}
105
106fn _take_nullable<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
107    array: &ListArray,
108    offsets: &[O],
109    indices: &[I],
110    data_validity: Mask,
111    indices_validity: Mask,
112) -> VortexResult<ArrayRef> {
113    let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
114    let mut elements_to_take =
115        PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
116
117    let mut current_offset = O::zero();
118    new_offsets.append_zero();
119    let mut new_validity = BooleanBufferBuilder::new(2 * indices.len());
120
121    for (idx, data_idx) in indices.iter().enumerate() {
122        if !indices_validity.value(idx) {
123            new_offsets.append_value(current_offset);
124            new_validity.append(false);
125            continue;
126        }
127
128        let data_idx = data_idx
129            .to_usize()
130            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
131
132        if data_validity.value(data_idx) {
133            let start = offsets[data_idx];
134            let stop = offsets[data_idx + 1];
135
136            // See the note it the `take` on the reasoning
137            let additional = (stop - start).to_usize().unwrap_or_else(|| {
138                vortex_panic!("Failed to convert range length to usize: {}", stop - start)
139            });
140
141            elements_to_take.ensure_capacity(elements_to_take.len() + additional);
142            for i in 0..additional {
143                elements_to_take
144                    .append_value(start + O::from_usize(i).vortex_expect("i < additional"));
145            }
146            current_offset = current_offset + (stop - start);
147            new_offsets.append_value(current_offset);
148            new_validity.append(true);
149        } else {
150            new_offsets.append_value(current_offset);
151            new_validity.append(false);
152        }
153    }
154
155    let elements_to_take = elements_to_take.finish();
156    let new_offsets = new_offsets.finish();
157    let new_elements = take(array.elements(), elements_to_take.as_ref())?;
158
159    let new_validity: Validity = Validity::from(new_validity.finish());
160    // data are indexes are nullable, so the final result is also nullable.
161
162    Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
163}
164
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Dtype of a list of non-nullable `i32` whose list-level nullability is `nullability`.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// List scalar holding `values` as `i32` elements, with the given list-level nullability.
    fn i32_list_scalar(values: &[i32], nullability: Nullability) -> Scalar {
        let element_dtype: Arc<DType> = Arc::new(I32.into());
        let children = values.iter().map(|&v| Scalar::from(v)).collect::<Vec<_>>();
        Scalar::list(element_dtype, children, nullability)
    }

    #[test]
    fn nullable_take() {
        let elements = PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 3, 4, 4]).to_array();
        let validity =
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array());
        let list = ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(result.dtype(), &i32_list_dtype(Nullability::Nullable));

        let result = result.to_list().unwrap();
        assert_eq!(result.len(), 4);

        // Slot 0 <- list 0 == [0, 5]
        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            i32_list_scalar(&[0, 5], Nullability::Nullable)
        );

        // Slot 1 <- null index
        assert!(result.is_invalid(1).unwrap());

        // Slot 2 <- list 1 == [3]
        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            i32_list_scalar(&[3], Nullability::Nullable)
        );

        // Slot 3 <- list 3 == []
        assert!(result.is_valid(3).unwrap());
        assert_eq!(
            result.scalar_at(3).unwrap(),
            i32_list_scalar(&[], Nullability::Nullable)
        );
    }

    #[test]
    fn change_validity() {
        let elements = PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 3]).to_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        // The indices are nullable, so the taken list must be nullable as well.
        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(result.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    #[test]
    fn non_nullable_take() {
        let elements = PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 3, 3, 4]).to_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        let idx = PrimitiveArray::from_iter([1, 0, 2]).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(result.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let result = result.to_list().unwrap();
        assert_eq!(result.len(), 3);

        // Slot 0 <- list 1 == [3]
        assert!(result.is_valid(0).unwrap());
        assert_eq!(
            result.scalar_at(0).unwrap(),
            i32_list_scalar(&[3], Nullability::NonNullable)
        );

        // Slot 1 <- list 0 == [0, 5]
        assert!(result.is_valid(1).unwrap());
        assert_eq!(
            result.scalar_at(1).unwrap(),
            i32_list_scalar(&[0, 5], Nullability::NonNullable)
        );

        // Slot 2 <- list 2 == []
        assert!(result.is_valid(2).unwrap());
        assert_eq!(
            result.scalar_at(2).unwrap(),
            i32_list_scalar(&[], Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_empty_array() {
        let elements = PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array();
        let offsets = PrimitiveArray::from_iter([0]).to_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(result.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(result.len(), 0);
    }
}