1use arrow_buffer::BooleanBufferBuilder;
5use num_traits::PrimInt;
6use vortex_dtype::{NativePType, Nullability, match_each_integer_ptype};
7use vortex_error::{VortexExpect, VortexResult, vortex_panic};
8use vortex_mask::Mask;
9
10use crate::arrays::{ListArray, ListVTable, PrimitiveArray};
11use crate::builders::{ArrayBuilder, PrimitiveBuilder};
12use crate::compute::{TakeKernel, TakeKernelAdapter, take};
13use crate::validity::Validity;
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, OffsetPType, ToCanonical, register_kernel};
16
17impl TakeKernel for ListVTable {
18 fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19 let indices = indices.to_primitive();
20 let offsets = array.offsets().to_primitive();
21
22 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
23 match_each_integer_ptype!(indices.ptype(), |I| {
24 _take::<I, O>(
25 array,
26 offsets.as_slice::<O>(),
27 &indices,
28 array.validity_mask(),
29 indices.validity_mask(),
30 )
31 })
32 })
33 }
34}
35
// Hands the lifted `TakeKernelAdapter` wrapping `ListVTable` to `register_kernel!`.
// NOTE(review): presumably this is what makes the `TakeKernel` impl above discoverable by the
// generic `take` compute function — confirm against the macro definition.
register_kernel!(TakeKernelAdapter(ListVTable).lift());
37
38fn _take<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
39 array: &ListArray,
40 offsets: &[O],
41 indices_array: &PrimitiveArray,
42 data_validity: Mask,
43 indices_validity_mask: Mask,
44) -> VortexResult<ArrayRef> {
45 let indices: &[I] = indices_array.as_slice::<I>();
46
47 if !indices_validity_mask.all_true() || !data_validity.all_true() {
48 return _take_nullable::<I, O>(
49 array,
50 offsets,
51 indices,
52 data_validity,
53 indices_validity_mask,
54 );
55 }
56
57 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
58 let mut elements_to_take =
59 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
60
61 let mut current_offset = O::zero();
62 new_offsets.append_zero();
63
64 for &data_idx in indices {
65 let data_idx = data_idx
66 .to_usize()
67 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
68
69 let start = offsets[data_idx];
70 let stop = offsets[data_idx + 1];
71
72 let additional = (stop - start).to_usize().unwrap_or_else(|| {
78 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
79 });
80
81 elements_to_take.ensure_capacity(elements_to_take.len() + additional);
82 for i in 0..additional {
83 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
84 }
85 current_offset = current_offset + (stop - start);
86 new_offsets.append_value(current_offset);
87 }
88
89 let elements_to_take = elements_to_take.finish();
90 let new_offsets = new_offsets.finish();
91
92 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
93
94 Ok(ListArray::try_new(
95 new_elements,
96 new_offsets,
97 indices_array
98 .validity()
99 .clone()
100 .and(array.validity().clone()),
101 )?
102 .to_array())
103}
104
105fn _take_nullable<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
106 array: &ListArray,
107 offsets: &[O],
108 indices: &[I],
109 data_validity: Mask,
110 indices_validity: Mask,
111) -> VortexResult<ArrayRef> {
112 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
113
114 let mut elements_to_take =
122 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
123
124 let mut current_offset = O::zero();
125 new_offsets.append_zero();
126
127 let mut new_validity = BooleanBufferBuilder::new(indices.len());
128
129 for (idx, data_idx) in indices.iter().enumerate() {
130 if !indices_validity.value(idx) {
131 new_offsets.append_value(current_offset);
132 new_validity.append(false);
133 continue;
134 }
135
136 let data_idx = data_idx
137 .to_usize()
138 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
139
140 if data_validity.value(data_idx) {
141 let start = offsets[data_idx];
142 let stop = offsets[data_idx + 1];
143
144 let additional = (stop - start).to_usize().unwrap_or_else(|| {
146 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
147 });
148
149 elements_to_take.ensure_capacity(elements_to_take.len() + additional);
150 for i in 0..additional {
151 elements_to_take
152 .append_value(start + O::from_usize(i).vortex_expect("i < additional"));
153 }
154 current_offset = current_offset + (stop - start);
155 new_offsets.append_value(current_offset);
156 new_validity.append(true);
157 } else {
158 new_offsets.append_value(current_offset);
159 new_validity.append(false);
160 }
161 }
162
163 let elements_to_take = elements_to_take.finish();
164 let new_offsets = new_offsets.finish();
165 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
166
167 let new_validity: Validity = Validity::from(new_validity.finish());
168 Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
171}
172
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _, ToCanonical};

    /// Dtype of a list of non-nullable `i32` elements with the given list-level nullability.
    fn i32_list_dtype(list_nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            list_nullability,
        )
    }

    /// Asserts that `taken` has exactly the rows in `expected`: `None` means the row must be
    /// invalid, `Some(values)` means the row must be valid and equal to a list scalar holding
    /// `values` with the given `nullability`.
    fn assert_rows(taken: &ListArray, expected: Vec<Option<Vec<i32>>>, nullability: Nullability) {
        assert_eq!(taken.len(), expected.len());

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        for (row, want) in expected.into_iter().enumerate() {
            match want {
                None => assert!(taken.is_invalid(row)),
                Some(values) => {
                    assert!(taken.is_valid(row));
                    assert_eq!(
                        taken.scalar_at(row),
                        Scalar::list(
                            element_dtype.clone(),
                            values.into_iter().map(Into::into).collect(),
                            nullability
                        )
                    );
                }
            }
        }
    }

    #[test]
    fn nullable_take() {
        let row_validity = BoolArray::from_iter(vec![true, true, false, true]).to_array();
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(row_validity),
        )
        .unwrap()
        .to_array();

        let positions =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = take(&source, &positions).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        // Row 1 comes from a null index; row 3 is a valid but empty list.
        assert_rows(
            &taken.to_list(),
            vec![Some(vec![0, 5]), None, Some(vec![3]), Some(vec![])],
            Nullability::Nullable,
        );
    }

    #[test]
    fn change_validity() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let positions = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        let taken = take(&source, &positions).unwrap();

        // Nullable indices must make the result nullable even though the source is not.
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    #[test]
    fn non_nullable_take() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let positions = buffer![1, 0, 2].into_array();

        let taken = take(&source, &positions).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        assert_rows(
            &taken.to_list(),
            vec![Some(vec![3]), Some(vec![0, 5]), Some(vec![])],
            Nullability::NonNullable,
        );
    }

    #[test]
    fn test_take_empty_array() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let positions = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = take(&source, &positions).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    #[rstest]
    // Plain non-nullable lists, including an empty one.
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Lists with a null row.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    // Several empty lists.
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Fifty lists with u64 offsets whose lengths cycle through 1, 2, 3, 4, 0.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let offsets: Vec<u64> = (0..=50u64)
            .scan(0u64, |running, i| {
                *running += i % 5;
                Some(*running)
            })
            .collect();
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Nullable elements inside non-nullable lists.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }
}