1use arrow_buffer::BooleanBufferBuilder;
5use num_traits::PrimInt;
6use vortex_dtype::{NativePType, Nullability, match_each_integer_ptype};
7use vortex_error::{VortexExpect, VortexResult, vortex_panic};
8use vortex_mask::Mask;
9
10use crate::arrays::{ListArray, ListVTable, OffsetPType, PrimitiveArray};
11use crate::builders::{ArrayBuilder, PrimitiveBuilder};
12use crate::compute::{TakeKernel, TakeKernelAdapter, take};
13use crate::validity::Validity;
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, ToCanonical, register_kernel};
16
17impl TakeKernel for ListVTable {
18 fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19 let indices = indices.to_primitive()?;
20 let offsets = array.offsets().to_primitive()?;
21
22 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
23 match_each_integer_ptype!(indices.ptype(), |I| {
24 _take::<I, O>(
25 array,
26 offsets.as_slice::<O>(),
27 &indices,
28 array.validity_mask()?,
29 indices.validity_mask()?,
30 )
31 })
32 })
33 }
34}
35
// Hands the `ListVTable` take kernel, wrapped in `TakeKernelAdapter` and lifted, to
// `register_kernel!`.
// NOTE(review): presumably this is what lets the generic `take` compute function find the
// list implementation above - confirm against the `register_kernel!` definition.
register_kernel!(TakeKernelAdapter(ListVTable).lift());
37
38fn _take<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
39 array: &ListArray,
40 offsets: &[O],
41 indices_array: &PrimitiveArray,
42 data_validity: Mask,
43 indices_validity_mask: Mask,
44) -> VortexResult<ArrayRef> {
45 let indices: &[I] = indices_array.as_slice::<I>();
46
47 if !indices_validity_mask.all_true() || !data_validity.all_true() {
48 return _take_nullable::<I, O>(
49 array,
50 offsets,
51 indices,
52 data_validity,
53 indices_validity_mask,
54 );
55 }
56
57 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
58 let mut elements_to_take =
59 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
60
61 let mut current_offset = O::zero();
62 new_offsets.append_zero();
63
64 for &data_idx in indices {
65 let data_idx = data_idx
66 .to_usize()
67 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
68
69 let start = offsets[data_idx];
70 let stop = offsets[data_idx + 1];
71
72 let additional = (stop - start).to_usize().unwrap_or_else(|| {
78 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
79 });
80
81 elements_to_take.ensure_capacity(elements_to_take.len() + additional);
82 for i in 0..additional {
83 elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
84 }
85 current_offset = current_offset + (stop - start);
86 new_offsets.append_value(current_offset);
87 }
88
89 let elements_to_take = elements_to_take.finish();
90 let new_offsets = new_offsets.finish();
91
92 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
93
94 Ok(ListArray::try_new(
95 new_elements,
96 new_offsets,
97 indices_array
98 .validity()
99 .clone()
100 .and(array.validity().clone())?,
101 )?
102 .to_array())
103}
104
105fn _take_nullable<I: NativePType, O: OffsetPType + NativePType + PrimInt>(
106 array: &ListArray,
107 offsets: &[O],
108 indices: &[I],
109 data_validity: Mask,
110 indices_validity: Mask,
111) -> VortexResult<ArrayRef> {
112 let mut new_offsets = PrimitiveBuilder::with_capacity(Nullability::NonNullable, indices.len());
113 let mut elements_to_take =
114 PrimitiveBuilder::with_capacity(Nullability::NonNullable, 2 * indices.len());
115
116 let mut current_offset = O::zero();
117 new_offsets.append_zero();
118 let mut new_validity = BooleanBufferBuilder::new(2 * indices.len());
119
120 for (idx, data_idx) in indices.iter().enumerate() {
121 if !indices_validity.value(idx) {
122 new_offsets.append_value(current_offset);
123 new_validity.append(false);
124 continue;
125 }
126
127 let data_idx = data_idx
128 .to_usize()
129 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
130
131 if data_validity.value(data_idx) {
132 let start = offsets[data_idx];
133 let stop = offsets[data_idx + 1];
134
135 let additional = (stop - start).to_usize().unwrap_or_else(|| {
137 vortex_panic!("Failed to convert range length to usize: {}", stop - start)
138 });
139
140 elements_to_take.ensure_capacity(elements_to_take.len() + additional);
141 for i in 0..additional {
142 elements_to_take
143 .append_value(start + O::from_usize(i).vortex_expect("i < additional"));
144 }
145 current_offset = current_offset + (stop - start);
146 new_offsets.append_value(current_offset);
147 new_validity.append(true);
148 } else {
149 new_offsets.append_value(current_offset);
150 new_validity.append(false);
151 }
152 }
153
154 let elements_to_take = elements_to_take.finish();
155 let new_offsets = new_offsets.finish();
156 let new_elements = take(array.elements(), elements_to_take.as_ref())?;
157
158 let new_validity: Validity = Validity::from(new_validity.finish());
159 Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
162}
163
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Dtype of a list of non-nullable `i32` whose outer nullability is `nullability`.
    fn i32_list_dtype(nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            nullability,
        )
    }

    /// Builds a list array over the shared elements `[0, 5, 3, 4]` with the given `i32`
    /// offsets and validity.
    fn sample_list(offsets: Vec<i32>, validity: Validity) -> crate::ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 5, 3, 4]).to_array(),
            PrimitiveArray::from_iter(offsets).to_array(),
            validity,
        )
        .unwrap()
        .to_array()
    }

    /// Null indices produce null slots; valid indices copy the selected list.
    #[test]
    fn nullable_take() {
        let list_validity =
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).to_array());
        let list = sample_list(vec![0, 2, 3, 4, 4], list_validity);

        let idx =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let taken = taken.to_list().unwrap();
        assert_eq!(taken.len(), 4);

        let i32_dtype: Arc<DType> = Arc::new(I32.into());

        // Slot 0 <- list 0.
        let expected_first = Scalar::list(
            i32_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::Nullable,
        );
        assert!(taken.is_valid(0).unwrap());
        assert_eq!(taken.scalar_at(0), expected_first);

        // Slot 1 <- null index.
        assert!(taken.is_invalid(1).unwrap());

        // Slot 2 <- list 1.
        let expected_third =
            Scalar::list(i32_dtype.clone(), vec![3i32.into()], Nullability::Nullable);
        assert!(taken.is_valid(2).unwrap());
        assert_eq!(taken.scalar_at(2), expected_third);

        // Slot 3 <- list 3, which is empty.
        let expected_fourth = Scalar::list(i32_dtype, vec![], Nullability::Nullable);
        assert!(taken.is_valid(3).unwrap());
        assert_eq!(taken.scalar_at(3), expected_fourth);
    }

    /// Taking a non-nullable list with nullable indices yields a nullable list dtype.
    #[test]
    fn change_validity() {
        let list = sample_list(vec![0, 2, 3], Validity::NonNullable);
        let idx = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).to_array();

        let taken = take(&list, &idx).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Non-nullable list and non-nullable indices stay non-nullable and reorder the lists.
    #[test]
    fn non_nullable_take() {
        let list = sample_list(vec![0, 2, 3, 3, 4], Validity::NonNullable);
        let idx = PrimitiveArray::from_iter([1, 0, 2]).to_array();

        let taken = take(&list, &idx).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let taken = taken.to_list().unwrap();
        assert_eq!(taken.len(), 3);

        let i32_dtype: Arc<DType> = Arc::new(I32.into());

        // Slot 0 <- list 1.
        let expected_first = Scalar::list(
            i32_dtype.clone(),
            vec![3i32.into()],
            Nullability::NonNullable,
        );
        assert!(taken.is_valid(0).unwrap());
        assert_eq!(taken.scalar_at(0), expected_first);

        // Slot 1 <- list 0.
        let expected_second = Scalar::list(
            i32_dtype.clone(),
            vec![0i32.into(), 5.into()],
            Nullability::NonNullable,
        );
        assert!(taken.is_valid(1).unwrap());
        assert_eq!(taken.scalar_at(1), expected_second);

        // Slot 2 <- list 2, which is empty.
        let expected_third = Scalar::list(i32_dtype, vec![], Nullability::NonNullable);
        assert!(taken.is_valid(2).unwrap());
        assert_eq!(taken.scalar_at(2), expected_third);
    }

    /// An empty, nullable index array yields an empty, nullable list array.
    #[test]
    fn test_take_empty_array() {
        let list = sample_list(vec![0], Validity::NonNullable);
        let idx = PrimitiveArray::empty::<i32>(Nullability::Nullable).to_array();

        let taken = take(&list, &idx).unwrap();

        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    /// Runs the shared take conformance checks against a range of list layouts.
    #[rstest]
    // Non-nullable lists, including an empty one.
    #[case(ListArray::try_new(
        PrimitiveArray::from_iter([0i32, 1, 2, 3, 4, 5]).to_array(),
        PrimitiveArray::from_iter([0, 2, 3, 5, 5, 6]).to_array(),
        Validity::NonNullable,
    ).unwrap())]
    // One null list.
    #[case(ListArray::try_new(
        PrimitiveArray::from_iter([10i32, 20, 30, 40, 50]).to_array(),
        PrimitiveArray::from_iter([0, 2, 3, 4, 5]).to_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).to_array()),
    ).unwrap())]
    // Several empty lists.
    #[case(ListArray::try_new(
        PrimitiveArray::from_iter([1i32, 2, 3]).to_array(),
        PrimitiveArray::from_iter([0, 0, 2, 2, 3]).to_array(),
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        PrimitiveArray::from_iter([42i32, 43]).to_array(),
        PrimitiveArray::from_iter([0, 2]).to_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Fifty lists with `u64` offsets; list `i` (1-based) has `i % 5` elements.
    #[case({
        let offsets: Vec<u64> = std::iter::once(0u64)
            .chain((1u64..=50).scan(0u64, |end, i| {
                *end += i % 5;
                Some(*end)
            }))
            .collect();
        ListArray::try_new(
            PrimitiveArray::from_iter(0i32..200).to_array(),
            PrimitiveArray::from_iter(offsets).to_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Nullable elements inside non-nullable lists.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).to_array(),
        PrimitiveArray::from_iter([0, 2, 3, 5]).to_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(list.as_ref());
    }
}