1use vortex_buffer::BitBufferMut;
5use vortex_dtype::IntegerPType;
6use vortex_dtype::Nullability;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::Array;
14use crate::ArrayRef;
15use crate::ToCanonical;
16use crate::arrays::ListArray;
17use crate::arrays::ListVTable;
18use crate::arrays::PrimitiveArray;
19use crate::builders::ArrayBuilder;
20use crate::builders::PrimitiveBuilder;
21use crate::compute::TakeKernel;
22use crate::compute::TakeKernelAdapter;
23use crate::compute::take;
24use crate::register_kernel;
25use crate::validity::Validity;
26use crate::vtable::ValidityHelper;
27
28impl TakeKernel for ListVTable {
37 fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
38 let indices = indices.to_primitive();
39 let offsets = array.offsets().to_primitive();
40
41 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
42 match_each_integer_ptype!(indices.ptype(), |I| {
43 _take::<I, O>(
44 array,
45 offsets.as_slice::<O>(),
46 &indices,
47 array.validity_mask(),
48 indices.validity_mask(),
49 )
50 })
51 })
52 }
53}
54
// Hand the `ListVTable` take kernel, wrapped in `TakeKernelAdapter`, to `register_kernel!`.
register_kernel!(TakeKernelAdapter(ListVTable).lift());
56
57fn _take<I: IntegerPType, O: IntegerPType>(
58 array: &ListArray,
59 offsets: &[O],
60 indices_array: &PrimitiveArray,
61 data_validity: Mask,
62 indices_validity_mask: Mask,
63) -> VortexResult<ArrayRef> {
64 let indices: &[I] = indices_array.as_slice::<I>();
65
66 if !indices_validity_mask.all_true() || !data_validity.all_true() {
67 return _take_nullable::<I, O>(
68 array,
69 offsets,
70 indices,
71 data_validity,
72 indices_validity_mask,
73 );
74 }
75
76 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
77 let mut elements_to_take =
78 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
79
80 let mut current_offset = O::zero();
81 new_offsets.append_zero();
82
83 for &data_idx in indices {
84 let data_idx = data_idx
85 .to_usize()
86 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
87
88 let start = offsets[data_idx];
89 let stop = offsets[data_idx + 1];
90
91 let additional = (stop - start).to_usize().unwrap_or_else(|| {
97 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
98 });
99
100 elements_to_take.reserve_exact(additional);
101 for i in 0..additional {
102 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
103 }
104 current_offset += stop - start;
105 new_offsets.append_value(current_offset);
106 }
107
108 let elements_to_take = elements_to_take.finish();
109 let new_offsets = new_offsets.finish();
110
111 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
112
113 Ok(ListArray::try_new(
114 new_elements,
115 new_offsets,
116 indices_array
117 .validity()
118 .clone()
119 .and(array.validity().clone()),
120 )?
121 .to_array())
122}
123
124fn _take_nullable<I: IntegerPType, O: IntegerPType>(
125 array: &ListArray,
126 offsets: &[O],
127 indices: &[I],
128 data_validity: Mask,
129 indices_validity: Mask,
130) -> VortexResult<ArrayRef> {
131 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
132
133 let mut elements_to_take =
141 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
142
143 let mut current_offset = O::zero();
144 new_offsets.append_zero();
145
146 let mut new_validity = BitBufferMut::new_unset(indices.len());
148
149 for (idx, data_idx) in indices.iter().enumerate() {
150 if !indices_validity.value(idx) {
151 new_offsets.append_value(current_offset);
152 continue;
154 }
155
156 let data_idx = data_idx
157 .to_usize()
158 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
159
160 if !data_validity.value(data_idx) {
161 new_offsets.append_value(current_offset);
162 continue;
164 }
165
166 let start = offsets[data_idx];
167 let stop = offsets[data_idx + 1];
168
169 let additional = (stop - start).to_usize().unwrap_or_else(|| {
171 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
172 });
173
174 elements_to_take.reserve_exact(additional);
175 for i in 0..additional {
176 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
177 }
178 current_offset += stop - start;
179 new_offsets.append_value(current_offset);
180 new_validity.set(idx);
181 }
182
183 let elements_to_take = elements_to_take.finish();
184 let new_offsets = new_offsets.finish();
185 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
186
187 let new_validity = Validity::from(new_validity.freeze());
188 Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
191}
192
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Dtype of a list of non-nullable `i32` elements whose list nullability is `nullability`.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// List scalar with `i32` elements `values` and the given list nullability.
    fn i32_list_scalar(values: &[i32], nullability: Nullability) -> Scalar {
        let element_dtype: Arc<DType> = Arc::new(I32.into());
        let children: Vec<Scalar> = values.iter().map(|&v| Scalar::from(v)).collect();
        Scalar::list(element_dtype, children, nullability)
    }

    /// Nulls in both the indices and the list array must surface as null rows in the result.
    #[test]
    fn nullable_take() {
        let elements = buffer![0i32, 5, 3, 4].into_array();
        let offsets = buffer![0, 2, 3, 4, 4].into_array();
        let validity =
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array());
        let list = ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 4);

        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[0, 5], Nullability::Nullable)
        );

        assert!(taken.is_invalid(1));

        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[3], Nullability::Nullable)
        );

        assert!(taken.is_valid(3));
        assert_eq!(
            taken.scalar_at(3),
            i32_list_scalar(&[], Nullability::Nullable)
        );
    }

    /// Nullable indices turn a non-nullable list array into a nullable result.
    #[test]
    fn change_validity() {
        let elements = buffer![0i32, 5, 3, 4].into_array();
        let offsets = buffer![0, 2, 3].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        let taken = take(&list, &idx).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Non-nullable inputs yield a non-nullable result with the lists in index order.
    #[test]
    fn non_nullable_take() {
        let elements = buffer![0i32, 5, 3, 4].into_array();
        let offsets = buffer![0, 2, 3, 3, 4].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        let idx = buffer![1, 0, 2].into_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_listview();
        assert_eq!(taken.len(), 3);

        assert!(taken.is_valid(0));
        assert_eq!(
            taken.scalar_at(0),
            i32_list_scalar(&[3], Nullability::NonNullable)
        );

        assert!(taken.is_valid(1));
        assert_eq!(
            taken.scalar_at(1),
            i32_list_scalar(&[0, 5], Nullability::NonNullable)
        );

        assert!(taken.is_valid(2));
        assert_eq!(
            taken.scalar_at(2),
            i32_list_scalar(&[], Nullability::NonNullable)
        );
    }

    /// Taking with an empty, nullable index array gives an empty, nullable result.
    #[test]
    fn test_take_empty_array() {
        let elements = buffer![0i32, 5, 3, 4].into_array();
        let offsets = buffer![0].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0,);
    }

    #[rstest]
    // Lists of differing lengths, one of them empty.
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // List array whose second list is null.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    // First and third lists are empty.
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Fifty lists with `u64` offsets, where list `i` has length `i % 5`.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50usize {
            let previous = offsets[i - 1];
            offsets.push(previous + (i as u64 % 5));
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists over nullable elements.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list_array: ListArray) {
        test_take_conformance(list_array.as_ref());
    }
}