use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::Mask;

use crate::ArrayRef;
use crate::Canonical;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::Chunked;
use crate::arrays::ChunkedArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::chunked::ChunkedArrayExt;
use crate::arrays::dict::TakeExecute;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::PType;
use crate::executor::ExecutionCtx;
use crate::validity::Validity;
22
23fn take_chunked(
26 array: ArrayView<'_, Chunked>,
27 indices: &ArrayRef,
28 ctx: &mut ExecutionCtx,
29) -> VortexResult<ArrayRef> {
30 let indices = indices
31 .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
32 .execute::<PrimitiveArray>(ctx)?;
33
34 let indices_mask = indices.validity_mask()?;
35 let indices_values = indices.as_slice::<u64>();
36 let n = indices_values.len();
37
38 let mut pairs: Vec<(u64, usize)> = indices_values
41 .iter()
42 .enumerate()
43 .filter(|&(i, _)| indices_mask.value(i))
44 .map(|(i, &v)| (v, i))
45 .collect();
46 pairs.sort_unstable();
47
48 let chunk_offsets = array.chunk_offsets();
52 let nchunks = array.nchunks();
53 let mut chunks = Vec::with_capacity(nchunks);
54 let mut final_take = BufferMut::<u64>::with_capacity(n);
55 final_take.push_n(0u64, n);
56
57 let mut cursor = 0usize;
58 let mut dedup_idx = 0u64;
59
60 for chunk_idx in 0..nchunks {
61 let chunk_start = chunk_offsets[chunk_idx];
62 let chunk_end = chunk_offsets[chunk_idx + 1];
63 let chunk_len = chunk_end - chunk_start;
64 let chunk_end_u64 = u64::try_from(chunk_end)?;
65
66 let range_end = cursor + pairs[cursor..].partition_point(|&(v, _)| v < chunk_end_u64);
67 let chunk_pairs = &pairs[cursor..range_end];
68
69 if !chunk_pairs.is_empty() {
70 let mut local_indices: Vec<usize> = Vec::new();
71 for (i, &(val, orig_pos)) in chunk_pairs.iter().enumerate() {
72 if cursor + i > 0 && val != pairs[cursor + i - 1].0 {
73 dedup_idx += 1;
74 }
75 let local = usize::try_from(val)? - chunk_start;
76 if local_indices.last() != Some(&local) {
77 local_indices.push(local);
78 }
79 final_take[orig_pos] = dedup_idx;
80 }
81
82 let filter_mask = Mask::from_indices(chunk_len, local_indices);
83 chunks.push(array.chunk(chunk_idx).filter(filter_mask)?);
84 }
85
86 cursor = range_end;
87 }
88
89 let flat = unsafe { ChunkedArray::new_unchecked(chunks, array.dtype().clone()) }
92 .into_array()
93 .execute::<Canonical>(ctx)?
95 .into_array();
96
97 let take_validity =
100 Validity::from_mask(indices.validity_mask()?, indices.dtype().nullability());
101 flat.take(PrimitiveArray::new(final_take.freeze(), take_validity).into_array())
102}
103
104impl TakeExecute for Chunked {
105 fn take(
106 array: ArrayView<'_, Chunked>,
107 indices: &ArrayRef,
108 ctx: &mut ExecutionCtx,
109 ) -> VortexResult<Option<ArrayRef>> {
110 take_chunked(array, indices, ctx).map(Some)
111 }
112}
113
114#[cfg(test)]
115mod test {
116 use vortex_buffer::bitbuffer;
117 use vortex_buffer::buffer;
118 use vortex_error::VortexResult;
119
120 use crate::IntoArray;
121 use crate::ToCanonical;
122 use crate::arrays::BoolArray;
123 use crate::arrays::ChunkedArray;
124 use crate::arrays::PrimitiveArray;
125 use crate::arrays::StructArray;
126 use crate::arrays::chunked::ChunkedArrayExt;
127 use crate::assert_arrays_eq;
128 use crate::compute::conformance::take::test_take_conformance;
129 use crate::dtype::FieldNames;
130 use crate::dtype::Nullability;
131 use crate::validity::Validity;
132
133 #[test]
134 fn test_take() {
135 let a = buffer![1i32, 2, 3].into_array();
136 let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
137 .unwrap();
138 assert_eq!(arr.nchunks(), 3);
139 assert_eq!(arr.len(), 9);
140 let indices = buffer![0u64, 0, 6, 4].into_array();
141
142 let result = arr.take(indices).unwrap();
143 assert_arrays_eq!(result, PrimitiveArray::from_iter([1i32, 1, 1, 2]));
144 }
145
146 #[test]
147 fn test_take_nullable_values() {
148 let a = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
149 let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
150 .unwrap();
151 assert_eq!(arr.nchunks(), 3);
152 assert_eq!(arr.len(), 9);
153 let indices = PrimitiveArray::new(buffer![0u64, 0, 6, 4], Validity::NonNullable);
154
155 let result = arr.take(indices.into_array()).unwrap();
156 assert_arrays_eq!(
157 result,
158 PrimitiveArray::from_option_iter([1i32, 1, 1, 2].map(Some))
159 );
160 }
161
162 #[test]
163 fn test_take_nullable_indices() {
164 let a = buffer![1i32, 2, 3].into_array();
165 let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
166 .unwrap();
167 assert_eq!(arr.nchunks(), 3);
168 assert_eq!(arr.len(), 9);
169 let indices = PrimitiveArray::new(
170 buffer![0u64, 0, 6, 4],
171 Validity::Array(bitbuffer![1 0 0 1].into_array()),
172 );
173
174 let result = arr.take(indices.into_array()).unwrap();
175 assert_arrays_eq!(
176 result,
177 PrimitiveArray::from_option_iter([Some(1i32), None, None, Some(2)])
178 );
179 }
180
181 #[test]
182 fn test_take_nullable_struct() {
183 let struct_array =
184 StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
185 .unwrap();
186
187 let arr = ChunkedArray::from_iter(vec![
188 struct_array.clone().into_array(),
189 struct_array.into_array(),
190 ]);
191
192 let result = arr
193 .take(PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]).into_array())
194 .unwrap();
195
196 let expect = StructArray::try_new(
197 FieldNames::default(),
198 vec![],
199 3,
200 Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()),
201 )
202 .unwrap();
203 assert_arrays_eq!(result, expect);
204 }
205
206 #[test]
207 fn test_empty_take() {
208 let a = buffer![1i32, 2, 3].into_array();
209 let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
210 .unwrap();
211 assert_eq!(arr.nchunks(), 3);
212 assert_eq!(arr.len(), 9);
213
214 let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
215 let result = arr.take(indices.into_array()).unwrap();
216
217 assert!(result.is_empty());
218 assert_eq!(result.dtype(), arr.dtype());
219 assert_arrays_eq!(
220 result,
221 PrimitiveArray::empty::<i32>(Nullability::NonNullable)
222 );
223 }
224
225 #[test]
226 fn test_take_shuffled_indices() -> VortexResult<()> {
227 let c0 = buffer![0i32, 1, 2].into_array();
228 let c1 = buffer![3i32, 4, 5].into_array();
229 let c2 = buffer![6i32, 7, 8].into_array();
230 let arr = ChunkedArray::try_new(
231 vec![c0, c1, c2],
232 PrimitiveArray::empty::<i32>(Nullability::NonNullable)
233 .dtype()
234 .clone(),
235 )?;
236
237 let indices = buffer![8u64, 0, 5, 3, 2, 7, 1, 6, 4].into_array();
239 let result = arr.take(indices)?;
240
241 assert_arrays_eq!(
242 result,
243 PrimitiveArray::from_iter([8i32, 0, 5, 3, 2, 7, 1, 6, 4])
244 );
245 Ok(())
246 }
247
248 #[test]
249 fn test_take_shuffled_large() -> VortexResult<()> {
250 let nchunks: i32 = 100;
251 let chunk_len: i32 = 1_000;
252 let total = nchunks * chunk_len;
253
254 let chunks: Vec<_> = (0..nchunks)
255 .map(|c| {
256 let start = c * chunk_len;
257 PrimitiveArray::from_iter(start..start + chunk_len).into_array()
258 })
259 .collect();
260 let dtype = chunks[0].dtype().clone();
261 let arr = ChunkedArray::try_new(chunks, dtype)?;
262
263 let mut indices: Vec<u64> = (0..u64::try_from(total)?).collect();
265 let mut seed: u64 = 0xdeadbeef;
266 for i in (1..indices.len()).rev() {
267 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
268 let j = (seed >> 33) as usize % (i + 1);
269 indices.swap(i, j);
270 }
271
272 let indices_arr = PrimitiveArray::new(
273 vortex_buffer::Buffer::from(indices.clone()),
274 Validity::NonNullable,
275 );
276 let result = arr.take(indices_arr.into_array())?;
277
278 let result = result.to_primitive();
280 let result_vals = result.as_slice::<i32>();
281 for (pos, &idx) in indices.iter().enumerate() {
282 assert_eq!(
283 result_vals[pos],
284 i32::try_from(idx)?,
285 "mismatch at position {pos}"
286 );
287 }
288 Ok(())
289 }
290
291 #[test]
292 fn test_take_null_indices() -> VortexResult<()> {
293 let c0 = buffer![10i32, 20, 30].into_array();
294 let c1 = buffer![40i32, 50, 60].into_array();
295 let arr = ChunkedArray::try_new(
296 vec![c0, c1],
297 PrimitiveArray::empty::<i32>(Nullability::NonNullable)
298 .dtype()
299 .clone(),
300 )?;
301
302 let indices =
304 PrimitiveArray::from_option_iter([Some(5u64), None, Some(0), Some(3), None, Some(2)]);
305 let result = arr.take(indices.into_array())?;
306
307 assert_arrays_eq!(
308 result,
309 PrimitiveArray::from_option_iter([
310 Some(60i32),
311 None,
312 Some(10),
313 Some(40),
314 None,
315 Some(30)
316 ])
317 );
318 Ok(())
319 }
320
321 #[test]
322 fn test_take_chunked_conformance() {
323 let a = buffer![1i32, 2, 3].into_array();
324 let b = buffer![4i32, 5].into_array();
325 let arr = ChunkedArray::try_new(
326 vec![a, b],
327 PrimitiveArray::empty::<i32>(Nullability::NonNullable)
328 .dtype()
329 .clone(),
330 )
331 .unwrap();
332 test_take_conformance(&arr.into_array());
333
334 let a = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
336 let b = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
337 let dtype = a.dtype().clone();
338 let arr = ChunkedArray::try_new(vec![a.into_array(), b.into_array()], dtype).unwrap();
339 test_take_conformance(&arr.into_array());
340
341 let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
343 let arr = ChunkedArray::try_new(
344 vec![chunk.clone(), chunk.clone(), chunk.clone()],
345 chunk.dtype().clone(),
346 )
347 .unwrap();
348 test_take_conformance(&arr.into_array());
349 }
350}