vortex_array/arrays/chunked/compute/
take.rs1use vortex_buffer::BufferMut;
5use vortex_error::VortexResult;
6use vortex_mask::Mask;
7
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::DynArray;
11use crate::IntoArray;
12use crate::arrays::ChunkedArray;
13use crate::arrays::ChunkedVTable;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::dict::TakeExecute;
16use crate::builtins::ArrayBuiltins;
17use crate::dtype::DType;
18use crate::dtype::PType;
19use crate::executor::ExecutionCtx;
20use crate::validity::Validity;
21
22fn take_chunked(
25 array: &ChunkedArray,
26 indices: &ArrayRef,
27 ctx: &mut ExecutionCtx,
28) -> VortexResult<ArrayRef> {
29 let indices = indices
30 .to_array()
31 .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
32 .execute::<PrimitiveArray>(ctx)?;
33
34 let indices_mask = indices.validity_mask()?;
35 let indices_values = indices.as_slice::<u64>();
36 let n = indices_values.len();
37
38 let mut pairs: Vec<(u64, usize)> = indices_values
41 .iter()
42 .enumerate()
43 .filter(|&(i, _)| indices_mask.value(i))
44 .map(|(i, &v)| (v, i))
45 .collect();
46 pairs.sort_unstable();
47
48 let chunk_offsets = array.chunk_offsets();
52 let nchunks = array.nchunks();
53 let mut chunks = Vec::with_capacity(nchunks);
54 let mut final_take = BufferMut::<u64>::with_capacity(n);
55 final_take.push_n(0u64, n);
56
57 let mut cursor = 0usize;
58 let mut dedup_idx = 0u64;
59
60 for chunk_idx in 0..nchunks {
61 let chunk_start = chunk_offsets[chunk_idx];
62 let chunk_end = chunk_offsets[chunk_idx + 1];
63 let chunk_len = usize::try_from(chunk_end - chunk_start)?;
64
65 let range_end = cursor + pairs[cursor..].partition_point(|&(v, _)| v < chunk_end);
66 let chunk_pairs = &pairs[cursor..range_end];
67
68 if !chunk_pairs.is_empty() {
69 let mut local_indices: Vec<usize> = Vec::new();
70 for (i, &(val, orig_pos)) in chunk_pairs.iter().enumerate() {
71 if cursor + i > 0 && val != pairs[cursor + i - 1].0 {
72 dedup_idx += 1;
73 }
74 let local = usize::try_from(val - chunk_start)?;
75 if local_indices.last() != Some(&local) {
76 local_indices.push(local);
77 }
78 final_take[orig_pos] = dedup_idx;
79 }
80
81 let filter_mask = Mask::from_indices(chunk_len, local_indices);
82 chunks.push(array.chunk(chunk_idx).filter(filter_mask)?);
83 }
84
85 cursor = range_end;
86 }
87
88 let flat = unsafe { ChunkedArray::new_unchecked(chunks, array.dtype().clone()) }
91 .into_array()
92 .execute::<Canonical>(ctx)?
94 .into_array();
95
96 let take_validity = Validity::from_mask(indices_mask, indices.dtype().nullability());
99 flat.take(PrimitiveArray::new(final_take.freeze(), take_validity).into_array())
100}
101
102impl TakeExecute for ChunkedVTable {
103 fn take(
104 array: &ChunkedArray,
105 indices: &ArrayRef,
106 ctx: &mut ExecutionCtx,
107 ) -> VortexResult<Option<ArrayRef>> {
108 take_chunked(array, indices, ctx).map(Some)
109 }
110}
111
#[cfg(test)]
mod test {
    use vortex_buffer::bitbuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::array::DynArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Builds a chunked array made of three identical copies of `chunk`.
    fn three_copies(chunk: impl IntoArray) -> ChunkedArray {
        let chunk = chunk.into_array();
        ChunkedArray::try_new(
            vec![chunk.clone(), chunk.clone(), chunk.clone()],
            chunk.dtype().clone(),
        )
        .unwrap()
    }

    #[test]
    fn test_take() {
        let chunked = three_copies(buffer![1i32, 2, 3]);
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let idx = buffer![0u64, 0, 6, 4].into_array();
        let taken = chunked.take(idx.to_array()).unwrap();
        assert_arrays_eq!(taken, PrimitiveArray::from_iter([1i32, 1, 1, 2]));
    }

    #[test]
    fn test_take_nullable_values() {
        let chunked = three_copies(PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid));
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let idx = PrimitiveArray::new(buffer![0u64, 0, 6, 4], Validity::NonNullable);
        let taken = chunked.take(idx.into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([1i32, 1, 1, 2].map(Some));
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn test_take_nullable_indices() {
        let chunked = three_copies(buffer![1i32, 2, 3]);
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let idx = PrimitiveArray::new(
            buffer![0u64, 0, 6, 4],
            Validity::Array(bitbuffer![1 0 0 1].into_array()),
        );
        let taken = chunked.take(idx.into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Some(1i32), None, None, Some(2)]);
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn test_take_nullable_struct() {
        let empty_struct =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();

        let chunked = ChunkedArray::from_iter(vec![
            empty_struct.clone().into_array(),
            empty_struct.into_array(),
        ]);

        let idx = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let taken = chunked.take(idx.into_array()).unwrap();

        let expected = StructArray::try_new(
            FieldNames::default(),
            vec![],
            3,
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()),
        )
        .unwrap();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn test_empty_take() {
        let chunked = three_copies(buffer![1i32, 2, 3]);
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let taken = chunked
            .take(PrimitiveArray::empty::<u64>(Nullability::NonNullable).into_array())
            .unwrap();

        assert!(taken.is_empty());
        assert_eq!(taken.dtype(), chunked.dtype());
        assert_arrays_eq!(
            taken,
            PrimitiveArray::empty::<i32>(Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_shuffled_indices() -> VortexResult<()> {
        let i32_dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let chunked = ChunkedArray::try_new(
            vec![
                buffer![0i32, 1, 2].into_array(),
                buffer![3i32, 4, 5].into_array(),
                buffer![6i32, 7, 8].into_array(),
            ],
            i32_dtype,
        )?;

        let idx = buffer![8u64, 0, 5, 3, 2, 7, 1, 6, 4].into_array();
        let taken = chunked.take(idx.to_array())?;

        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_iter([8i32, 0, 5, 3, 2, 7, 1, 6, 4])
        );
        Ok(())
    }

    #[test]
    fn test_take_shuffled_large() -> VortexResult<()> {
        const NCHUNKS: i32 = 100;
        const CHUNK_LEN: i32 = 1_000;

        let chunks: Vec<_> = (0..NCHUNKS)
            .map(|c| {
                let lo = c * CHUNK_LEN;
                PrimitiveArray::from_iter(lo..lo + CHUNK_LEN).into_array()
            })
            .collect();
        let dtype = chunks[0].dtype().clone();
        let chunked = ChunkedArray::try_new(chunks, dtype)?;

        let total = NCHUNKS * CHUNK_LEN;
        let mut order: Vec<u64> = (0..u64::try_from(total)?).collect();

        // Deterministic Fisher-Yates shuffle driven by a simple linear congruential generator.
        let mut state: u64 = 0xdeadbeef;
        for i in (1..order.len()).rev() {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let j = (state >> 33) as usize % (i + 1);
            order.swap(i, j);
        }

        let idx = PrimitiveArray::new(
            vortex_buffer::Buffer::from(order.clone()),
            Validity::NonNullable,
        );
        let taken = chunked.take(idx.into_array())?.to_primitive();
        let values = taken.as_slice::<i32>();

        for (pos, &want) in order.iter().enumerate() {
            assert_eq!(
                values[pos],
                i32::try_from(want)?,
                "mismatch at position {pos}"
            );
        }
        Ok(())
    }

    #[test]
    fn test_take_null_indices() -> VortexResult<()> {
        let i32_dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let chunked = ChunkedArray::try_new(
            vec![
                buffer![10i32, 20, 30].into_array(),
                buffer![40i32, 50, 60].into_array(),
            ],
            i32_dtype,
        )?;

        let idx =
            PrimitiveArray::from_option_iter([Some(5u64), None, Some(0), Some(3), None, Some(2)]);
        let taken = chunked.take(idx.into_array())?;

        let expected = PrimitiveArray::from_option_iter([
            Some(60i32),
            None,
            Some(10),
            Some(40),
            None,
            Some(30),
        ]);
        assert_arrays_eq!(taken, expected);
        Ok(())
    }

    #[test]
    fn test_take_chunked_conformance() {
        // Non-nullable chunks of unequal length.
        let i32_dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let uneven = ChunkedArray::try_new(
            vec![buffer![1i32, 2, 3].into_array(), buffer![4i32, 5].into_array()],
            i32_dtype,
        )
        .unwrap();
        test_take_conformance(&uneven.into_array());

        // Nullable chunks containing a null value.
        let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let second = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
        let nullable_dtype = first.dtype().clone();
        let with_nulls =
            ChunkedArray::try_new(vec![first.into_array(), second.into_array()], nullable_dtype)
                .unwrap();
        test_take_conformance(&with_nulls.into_array());

        // Several identical chunks.
        let repeated = three_copies(buffer![10i32, 20, 30, 40, 50]);
        test_take_conformance(&repeated.into_array());
    }
}