1use vortex_buffer::BufferMut;
5use vortex_error::VortexResult;
6use vortex_mask::Mask;
7
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::Chunked;
13use crate::arrays::ChunkedArray;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::chunked::ChunkedArrayExt;
16use crate::arrays::dict::TakeExecute;
17use crate::builtins::ArrayBuiltins;
18use crate::dtype::DType;
19use crate::dtype::PType;
20use crate::executor::ExecutionCtx;
21use crate::validity::Validity;
22
23fn take_chunked(
26 array: ArrayView<'_, Chunked>,
27 indices: &ArrayRef,
28 ctx: &mut ExecutionCtx,
29) -> VortexResult<ArrayRef> {
30 let indices = indices
31 .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
32 .execute::<PrimitiveArray>(ctx)?;
33
34 let indices_mask = indices
35 .as_ref()
36 .validity()?
37 .execute_mask(indices.as_ref().len(), ctx)?;
38 let indices_values = indices.as_slice::<u64>();
39 let n = indices_values.len();
40
41 let mut pairs: Vec<(u64, usize)> = indices_values
44 .iter()
45 .enumerate()
46 .filter(|&(i, _)| indices_mask.value(i))
47 .map(|(i, &v)| (v, i))
48 .collect();
49 pairs.sort_unstable();
50
51 let chunk_offsets = array.chunk_offsets();
55 let nchunks = array.nchunks();
56 let mut chunks = Vec::with_capacity(nchunks);
57 let mut final_take = BufferMut::<u64>::with_capacity(n);
58 final_take.push_n(0u64, n);
59
60 let mut cursor = 0usize;
61 let mut dedup_idx = 0u64;
62
63 for chunk_idx in 0..nchunks {
64 let chunk_start = chunk_offsets[chunk_idx];
65 let chunk_end = chunk_offsets[chunk_idx + 1];
66 let chunk_len = chunk_end - chunk_start;
67 let chunk_end_u64 = u64::try_from(chunk_end)?;
68
69 let range_end = cursor + pairs[cursor..].partition_point(|&(v, _)| v < chunk_end_u64);
70 let chunk_pairs = &pairs[cursor..range_end];
71
72 if !chunk_pairs.is_empty() {
73 let mut local_indices: Vec<usize> = Vec::new();
74 for (i, &(val, orig_pos)) in chunk_pairs.iter().enumerate() {
75 if cursor + i > 0 && val != pairs[cursor + i - 1].0 {
76 dedup_idx += 1;
77 }
78 let local = usize::try_from(val)? - chunk_start;
79 if local_indices.last() != Some(&local) {
80 local_indices.push(local);
81 }
82 final_take[orig_pos] = dedup_idx;
83 }
84
85 let filter_mask = Mask::from_indices(chunk_len, local_indices);
86 chunks.push(array.chunk(chunk_idx).filter(filter_mask)?);
87 }
88
89 cursor = range_end;
90 }
91
92 let flat = unsafe { ChunkedArray::new_unchecked(chunks, array.dtype().clone()) }
95 .into_array()
96 .execute::<Canonical>(ctx)?
98 .into_array();
99
100 let take_validity = Validity::from_mask(
103 indices
104 .as_ref()
105 .validity()?
106 .execute_mask(indices.as_ref().len(), ctx)?,
107 indices.dtype().nullability(),
108 );
109 flat.take(PrimitiveArray::new(final_take.freeze(), take_validity).into_array())
110}
111
112impl TakeExecute for Chunked {
113 fn take(
114 array: ArrayView<'_, Chunked>,
115 indices: &ArrayRef,
116 ctx: &mut ExecutionCtx,
117 ) -> VortexResult<Option<ArrayRef>> {
118 take_chunked(array, indices, ctx).map(Some)
119 }
120}
121
#[cfg(test)]
mod test {
    use vortex_buffer::bitbuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::chunked::ChunkedArrayExt;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Non-nullable values, non-nullable indices, including a repeated index.
    #[test]
    fn test_take() {
        let mut ctx = array_session().create_execution_ctx();

        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let arr = ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk.clone()], dtype)
            .unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);

        let result = arr.take(buffer![0u64, 0, 6, 4].into_array()).unwrap();

        let expected = PrimitiveArray::from_iter([1i32, 1, 1, 2]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    /// Values declared with `Validity::AllValid`, taken with non-nullable indices.
    #[test]
    fn test_take_nullable_values() {
        let mut ctx = array_session().create_execution_ctx();

        let chunk = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dtype = chunk.dtype().clone();
        let arr = ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk.clone()], dtype)
            .unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);

        let indices =
            PrimitiveArray::new(buffer![0u64, 0, 6, 4], Validity::NonNullable).into_array();
        let result = arr.take(indices).unwrap();

        let expected = PrimitiveArray::from_option_iter([1i32, 1, 1, 2].map(Some));
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    /// Indices with an explicit validity bitmap: invalid positions must come back as null.
    #[test]
    fn test_take_nullable_indices() {
        let mut ctx = array_session().create_execution_ctx();

        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let arr = ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk.clone()], dtype)
            .unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);

        let index_validity = Validity::Array(bitbuffer![1 0 0 1].into_array());
        let indices = PrimitiveArray::new(buffer![0u64, 0, 6, 4], index_validity).into_array();
        let result = arr.take(indices).unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1i32), None, None, Some(2)]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    /// Field-less struct chunks taken with nullable indices that span both chunks.
    #[test]
    fn test_take_nullable_struct() {
        let mut ctx = array_session().create_execution_ctx();

        let empty_struct =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let arr = ChunkedArray::from_iter(vec![
            empty_struct.clone().into_array(),
            empty_struct.into_array(),
        ]);

        let indices =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]).into_array();
        let result = arr.take(indices).unwrap();

        let expected_validity =
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array());
        let expect =
            StructArray::try_new(FieldNames::default(), vec![], 3, expected_validity).unwrap();
        assert_arrays_eq!(result, expect, &mut ctx);
    }

    /// Taking zero indices yields an empty array that keeps the source dtype.
    #[test]
    fn test_empty_take() {
        let mut ctx = array_session().create_execution_ctx();

        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let arr = ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk.clone()], dtype)
            .unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);

        let no_indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable).into_array();
        let result = arr.take(no_indices).unwrap();

        assert!(result.is_empty());
        assert_eq!(result.dtype(), arr.dtype());
        let expected = PrimitiveArray::empty::<i32>(Nullability::NonNullable);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    /// A permutation of every row: the output must follow the index order exactly.
    #[test]
    fn test_take_shuffled_indices() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();

        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let parts = vec![
            buffer![0i32, 1, 2].into_array(),
            buffer![3i32, 4, 5].into_array(),
            buffer![6i32, 7, 8].into_array(),
        ];
        let arr = ChunkedArray::try_new(parts, dtype)?;

        let result = arr.take(buffer![8u64, 0, 5, 3, 2, 7, 1, 6, 4].into_array())?;

        let expected = PrimitiveArray::from_iter([8i32, 0, 5, 3, 2, 7, 1, 6, 4]);
        assert_arrays_eq!(result, expected, &mut ctx);
        Ok(())
    }

    /// 100 chunks of 1,000 rows each, taken through a deterministic pseudo-random permutation.
    #[test]
    fn test_take_shuffled_large() -> VortexResult<()> {
        let nchunks: i32 = 100;
        let chunk_len: i32 = 1_000;
        let total = nchunks * chunk_len;

        // Chunk `c` holds the consecutive values `c * chunk_len .. (c + 1) * chunk_len`,
        // so each row's value equals its global position.
        let chunks: Vec<_> = (0..nchunks)
            .map(|c| {
                let lo = c * chunk_len;
                let hi = lo + chunk_len;
                PrimitiveArray::from_iter(lo..hi).into_array()
            })
            .collect();
        let dtype = chunks[0].dtype().clone();
        let arr = ChunkedArray::try_new(chunks, dtype)?;

        // Back-to-front swap shuffle driven by a fixed-seed linear congruential generator.
        let mut indices: Vec<u64> = (0..u64::try_from(total)?).collect();
        let mut seed: u64 = 0xdeadbeef;
        for i in (1..indices.len()).rev() {
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let j = (seed >> 33) as usize % (i + 1);
            indices.swap(i, j);
        }

        let index_buffer = vortex_buffer::Buffer::from(indices.clone());
        let indices_arr = PrimitiveArray::new(index_buffer, Validity::NonNullable);
        let result = arr.take(indices_arr.into_array())?;

        #[expect(deprecated)]
        let result = result.to_primitive();
        let result_vals = result.as_slice::<i32>();
        for (pos, &idx) in indices.iter().enumerate() {
            let actual = result_vals[pos];
            let wanted = i32::try_from(idx)?;
            assert_eq!(actual, wanted, "mismatch at position {pos}");
        }
        Ok(())
    }

    /// Null indices interleaved with out-of-order valid ones across two chunks.
    #[test]
    fn test_take_null_indices() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();

        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let parts = vec![
            buffer![10i32, 20, 30].into_array(),
            buffer![40i32, 50, 60].into_array(),
        ];
        let arr = ChunkedArray::try_new(parts, dtype)?;

        let indices =
            PrimitiveArray::from_option_iter([Some(5u64), None, Some(0), Some(3), None, Some(2)]);
        let result = arr.take(indices.into_array())?;

        let expected = PrimitiveArray::from_option_iter([
            Some(60i32),
            None,
            Some(10),
            Some(40),
            None,
            Some(30),
        ]);
        assert_arrays_eq!(result, expected, &mut ctx);
        Ok(())
    }

    /// Runs the shared take conformance suite over three chunked layouts.
    #[test]
    fn test_take_chunked_conformance() {
        // Two non-nullable chunks of different lengths.
        let non_nullable_dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let uneven = ChunkedArray::try_new(
            vec![
                buffer![1i32, 2, 3].into_array(),
                buffer![4i32, 5].into_array(),
            ],
            non_nullable_dtype,
        )
        .unwrap();
        test_take_conformance(&uneven.into_array());

        // Nullable chunks, the first of which contains a null.
        let head = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let tail = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
        let nullable_dtype = head.dtype().clone();
        let nullable =
            ChunkedArray::try_new(vec![head.into_array(), tail.into_array()], nullable_dtype)
                .unwrap();
        test_take_conformance(&nullable.into_array());

        // The same five-element chunk repeated three times.
        let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
        let repeated_dtype = chunk.dtype().clone();
        let repeated = ChunkedArray::try_new(
            vec![chunk.clone(), chunk.clone(), chunk.clone()],
            repeated_dtype,
        )
        .unwrap();
        test_take_conformance(&repeated.into_array());
    }
}