vortex_array/arrays/list/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::arrays::{ListArray, ListVTable, list_view_from_list};
7use crate::compute::{self, TakeKernel, TakeKernelAdapter};
8use crate::{Array, ArrayRef, IntoArray, register_kernel};
9
10// TODO(connor): For very short arrays it is probably more efficient to build the list from scratch.
11/// Take implementation for [`ListArray`].
12///
13/// This implementation converts the [`ListArray`] to a [`ListViewArray`] and then delegates to its
14/// `take` implementation. This approach avoids the need to rebuild the `elements` array.
15///
16/// The resulting [`ListViewArray`] can represent non-contiguous and out-of-order lists, which would
17/// violate [`ListArray`]'s invariants (but not [`ListViewArray`]'s).
18///
19/// [`ListViewArray`]: crate::arrays::ListViewArray
20impl TakeKernel for ListVTable {
21    fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22        let list_view = list_view_from_list(array.clone());
23        compute::take(&list_view.into_array(), indices)
24    }
25}
26
27register_kernel!(TakeKernelAdapter(ListVTable).lift());
28
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _, ToCanonical};

    /// Taking from a nullable list keeps the selected valid lists, turns a null index into a null
    /// list, and preserves empty lists.
    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_invalid(1));

        assert!(result.is_valid(2));
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(3));
        assert_eq!(
            result.scalar_at(3),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// A nullable `indices` array makes the result nullable even when the source list is
    /// non-nullable. The null index must yield a null list, and the other indices must still
    /// select the right lists.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        // since idx is nullable, the final list will also be nullable

        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(1));
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(element_dtype, vec![3i32.into()], Nullability::Nullable)
        );

        // The null index must produce a null list, not an arbitrary valid one.
        assert!(result.is_invalid(2));
    }

    /// Taking with non-nullable indices reorders lists (including an empty one) and keeps the
    /// result non-nullable.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(1));
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(2));
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// Taking zero (nullable) indices from a zero-length list array yields an empty, nullable
    /// result.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        // Build the 200 elements explicitly from the range. The offsets below grow by
        // `i % 5` for i in 1..=50 and end at 100, so every list stays within `elements`.
        let elements = PrimitiveArray::from_iter(0i32..200).into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5)); // Variable length lists
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }
}