vortex_array/arrays/list/compute/
take.rs1use vortex_error::VortexResult;
5
6use crate::arrays::{ListArray, ListVTable, list_view_from_list};
7use crate::compute::{self, TakeKernel, TakeKernelAdapter};
8use crate::{Array, ArrayRef, IntoArray, register_kernel};
9
10impl TakeKernel for ListVTable {
21 fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22 let list_view = list_view_from_list(array.clone());
23 compute::take(&list_view.into_array(), indices)
24 }
25}
26
27register_kernel!(TakeKernelAdapter(ListVTable).lift());
28
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _, ToCanonical};

    /// Taking from a nullable list with a nullable index array.
    ///
    /// A `None` index yields a null row. Non-null indices select the corresponding source lists,
    /// including an empty one.
    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_invalid(1));

        assert!(result.is_valid(2));
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(3));
        assert_eq!(
            result.scalar_at(3),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// Taking from a non-nullable list with a nullable index array.
    ///
    /// The result dtype becomes nullable, and the `None` index produces a null row.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        // Check the values as well, not just the dtype.
        let result = result.to_listview();
        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(1));
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(element_dtype, vec![3i32.into()], Nullability::Nullable)
        );

        assert!(result.is_invalid(2));
    }

    /// Taking from a non-nullable list with non-nullable indices.
    ///
    /// The result keeps a non-nullable dtype, and its rows appear in index order.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(1));
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(2));
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// Taking from an empty list array with an empty nullable index array.
    ///
    /// The result is empty, and its dtype is nullable because the indices are.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0);
    }

    /// Runs the shared take conformance suite over lists of varied shapes.
    ///
    /// The cases cover empty lists, list-level nulls, a single list, many variable-length lists,
    /// and nullable elements.
    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        // Build 200 i32 elements 0..200. Writing `buffer![0i32..200]` instead would create a
        // one-element buffer holding the `Range` itself, not the range's values.
        let elements = PrimitiveArray::from_iter(0i32..200).to_array();
        // Each offset grows the previous one by `i % 5`, so list lengths cycle 1, 2, 3, 4, 0.
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5));
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }
}