1use vortex_buffer::BitBufferMut;
5use vortex_dtype::{IntegerPType, Nullability, match_each_integer_ptype};
6use vortex_error::{VortexExpect, VortexResult, vortex_panic};
7use vortex_mask::Mask;
8
9use crate::arrays::{ListArray, ListVTable, PrimitiveArray};
10use crate::builders::{ArrayBuilder, PrimitiveBuilder};
11use crate::compute::{TakeKernel, TakeKernelAdapter, take};
12use crate::validity::Validity;
13use crate::vtable::ValidityHelper;
14use crate::{Array, ArrayRef, ToCanonical, register_kernel};
15
16impl TakeKernel for ListVTable {
25 fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
26 let indices = indices.to_primitive();
27 let offsets = array.offsets().to_primitive();
28
29 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
30 match_each_integer_ptype!(indices.ptype(), |I| {
31 _take::<I, O>(
32 array,
33 offsets.as_slice::<O>(),
34 &indices,
35 array.validity_mask(),
36 indices.validity_mask(),
37 )
38 })
39 })
40 }
41}
42
43register_kernel!(TakeKernelAdapter(ListVTable).lift());
44
45fn _take<I: IntegerPType, O: IntegerPType>(
46 array: &ListArray,
47 offsets: &[O],
48 indices_array: &PrimitiveArray,
49 data_validity: Mask,
50 indices_validity_mask: Mask,
51) -> VortexResult<ArrayRef> {
52 let indices: &[I] = indices_array.as_slice::<I>();
53
54 if !indices_validity_mask.all_true() || !data_validity.all_true() {
55 return _take_nullable::<I, O>(
56 array,
57 offsets,
58 indices,
59 data_validity,
60 indices_validity_mask,
61 );
62 }
63
64 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
65 let mut elements_to_take =
66 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
67
68 let mut current_offset = O::zero();
69 new_offsets.append_zero();
70
71 for &data_idx in indices {
72 let data_idx = data_idx
73 .to_usize()
74 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
75
76 let start = offsets[data_idx];
77 let stop = offsets[data_idx + 1];
78
79 let additional = (stop - start).to_usize().unwrap_or_else(|| {
85 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
86 });
87
88 elements_to_take.reserve_exact(additional);
89 for i in 0..additional {
90 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
91 }
92 current_offset += stop - start;
93 new_offsets.append_value(current_offset);
94 }
95
96 let elements_to_take = elements_to_take.finish();
97 let new_offsets = new_offsets.finish();
98
99 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
100
101 Ok(ListArray::try_new(
102 new_elements,
103 new_offsets,
104 indices_array
105 .validity()
106 .clone()
107 .and(array.validity().clone()),
108 )?
109 .to_array())
110}
111
112fn _take_nullable<I: IntegerPType, O: IntegerPType>(
113 array: &ListArray,
114 offsets: &[O],
115 indices: &[I],
116 data_validity: Mask,
117 indices_validity: Mask,
118) -> VortexResult<ArrayRef> {
119 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
120
121 let mut elements_to_take =
129 PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());
130
131 let mut current_offset = O::zero();
132 new_offsets.append_zero();
133
134 let mut new_validity = BitBufferMut::new_unset(indices.len());
136
137 for (idx, data_idx) in indices.iter().enumerate() {
138 if !indices_validity.value(idx) {
139 new_offsets.append_value(current_offset);
140 continue;
142 }
143
144 let data_idx = data_idx
145 .to_usize()
146 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
147
148 if !data_validity.value(data_idx) {
149 new_offsets.append_value(current_offset);
150 continue;
152 }
153
154 let start = offsets[data_idx];
155 let stop = offsets[data_idx + 1];
156
157 let additional = (stop - start).to_usize().unwrap_or_else(|| {
159 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
160 });
161
162 elements_to_take.reserve_exact(additional);
163 for i in 0..additional {
164 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
165 }
166 current_offset += stop - start;
167 new_offsets.append_value(current_offset);
168 new_validity.set(idx);
169 }
170
171 let elements_to_take = elements_to_take.finish();
172 let new_offsets = new_offsets.finish();
173 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
174
175 let new_validity = Validity::from(new_validity.freeze());
176 Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
179}
180
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _, ToCanonical};

    /// Null indices and null list rows both produce null output rows.
    #[test]
    fn nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array()),
        )
        .unwrap()
        .to_array();

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_invalid(1));

        assert!(result.is_valid(2));
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(3));
        assert_eq!(
            result.scalar_at(3),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// Taking a non-nullable list with nullable indices yields a nullable list whose null index
    /// becomes a null row.
    #[test]
    fn change_validity() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();
        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        assert!(result.is_valid(1));
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(element_dtype, vec![3i32.into()], Nullability::Nullable)
        );

        assert!(result.is_invalid(2));
    }

    /// All-valid inputs keep the result non-nullable and reorder lists as requested.
    #[test]
    fn non_nullable_take() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = buffer![1, 0, 2].into_array();

        let result = take(&list, &idx).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::NonNullable
            )
        );

        let result = result.to_listview();

        assert_eq!(result.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        assert!(result.is_valid(0));
        assert_eq!(
            result.scalar_at(0),
            Scalar::list(
                element_dtype.clone(),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(1));
        assert_eq!(
            result.scalar_at(1),
            Scalar::list(
                element_dtype.clone(),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        assert!(result.is_valid(2));
        assert_eq!(
            result.scalar_at(2),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// An empty index array produces an empty list whose nullability follows the indices.
    #[test]
    fn test_take_empty_array() {
        let list = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .to_array();

        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let result = take(&list, &idx).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::List(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                Nullability::Nullable
            )
        );
        assert_eq!(result.len(), 0);
    }

    #[rstest]
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut offsets = vec![0u64];
        for i in 1..=50 {
            offsets.push(offsets[i - 1] + (i as u64 % 5));
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }
}