// vortex_array/arrays/chunked/compute/take.rs

use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::Mask;

use crate::Array;
use crate::ArrayRef;
use crate::Canonical;
use crate::IntoArray;
use crate::arrays::ChunkedVTable;
use crate::arrays::PrimitiveArray;
use crate::arrays::TakeExecute;
use crate::arrays::chunked::ChunkedArray;
use crate::builtins::ArrayBuiltins;
use crate::canonical::ToCanonical;
use crate::dtype::DType;
use crate::dtype::PType;
use crate::executor::ExecutionCtx;
use crate::validity::Validity;

23fn take_chunked(
26 array: &ChunkedArray,
27 indices: &dyn Array,
28 ctx: &mut ExecutionCtx,
29) -> VortexResult<ArrayRef> {
30 let indices = indices
31 .to_array()
32 .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
33 .to_primitive();
34
35 let indices_mask = indices.validity_mask()?;
36 let indices_values = indices.as_slice::<u64>();
37 let n = indices_values.len();
38
39 let mut pairs: Vec<(u64, usize)> = indices_values
42 .iter()
43 .enumerate()
44 .filter(|&(i, _)| indices_mask.value(i))
45 .map(|(i, &v)| (v, i))
46 .collect();
47 pairs.sort_unstable();
48
49 let chunk_offsets = array.chunk_offsets();
53 let nchunks = array.nchunks();
54 let mut chunks = Vec::with_capacity(nchunks);
55 let mut final_take = BufferMut::<u64>::with_capacity(n);
56 final_take.push_n(0u64, n);
57
58 let mut cursor = 0usize;
59 let mut dedup_idx = 0u64;
60
61 for chunk_idx in 0..nchunks {
62 let chunk_start = chunk_offsets[chunk_idx];
63 let chunk_end = chunk_offsets[chunk_idx + 1];
64 let chunk_len = usize::try_from(chunk_end - chunk_start)?;
65
66 let range_end = cursor + pairs[cursor..].partition_point(|&(v, _)| v < chunk_end);
67 let chunk_pairs = &pairs[cursor..range_end];
68
69 if !chunk_pairs.is_empty() {
70 let mut local_indices: Vec<usize> = Vec::new();
71 for (i, &(val, orig_pos)) in chunk_pairs.iter().enumerate() {
72 if cursor + i > 0 && val != pairs[cursor + i - 1].0 {
73 dedup_idx += 1;
74 }
75 let local = usize::try_from(val - chunk_start)?;
76 if local_indices.last() != Some(&local) {
77 local_indices.push(local);
78 }
79 final_take[orig_pos] = dedup_idx;
80 }
81
82 let filter_mask = Mask::from_indices(chunk_len, local_indices);
83 chunks.push(array.chunk(chunk_idx).filter(filter_mask)?);
84 }
85
86 cursor = range_end;
87 }
88
89 let flat = unsafe { ChunkedArray::new_unchecked(chunks, array.dtype().clone()) }
92 .into_array()
93 .execute::<Canonical>(ctx)?
95 .into_array();
96
97 let take_validity = Validity::from_mask(indices_mask, indices.dtype().nullability());
100 flat.take(PrimitiveArray::new(final_take.freeze(), take_validity).into_array())
101}
102
103impl TakeExecute for ChunkedVTable {
104 fn take(
105 array: &ChunkedArray,
106 indices: &dyn Array,
107 ctx: &mut ExecutionCtx,
108 ) -> VortexResult<Option<ArrayRef>> {
109 take_chunked(array, indices, ctx).map(Some)
110 }
111}
112
113#[cfg(test)]
114mod test {
115 use vortex_buffer::bitbuffer;
116 use vortex_buffer::buffer;
117 use vortex_error::VortexResult;
118
119 use crate::IntoArray;
120 use crate::ToCanonical;
121 use crate::array::Array;
122 use crate::arrays::BoolArray;
123 use crate::arrays::PrimitiveArray;
124 use crate::arrays::StructArray;
125 use crate::arrays::chunked::ChunkedArray;
126 use crate::assert_arrays_eq;
127 use crate::compute::conformance::take::test_take_conformance;
128 use crate::dtype::FieldNames;
129 use crate::dtype::Nullability;
130 use crate::validity::Validity;
131
132 #[test]
133 fn test_take() {
134 let a = buffer![1i32, 2, 3].into_array();
135 let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
136 .unwrap();
137 assert_eq!(arr.nchunks(), 3);
138 assert_eq!(arr.len(), 9);
139 let indices = buffer![0u64, 0, 6, 4].into_array();
140
141 let result = arr.take(indices.to_array()).unwrap();
142 assert_arrays_eq!(result, PrimitiveArray::from_iter([1i32, 1, 1, 2]));
143 }
144
145 #[test]
146 fn test_take_nullable_values() {
147 let a = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
148 let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
149 .unwrap();
150 assert_eq!(arr.nchunks(), 3);
151 assert_eq!(arr.len(), 9);
152 let indices = PrimitiveArray::new(buffer![0u64, 0, 6, 4], Validity::NonNullable);
153
154 let result = arr.take(indices.to_array()).unwrap();
155 assert_arrays_eq!(
156 result,
157 PrimitiveArray::from_option_iter([1i32, 1, 1, 2].map(Some))
158 );
159 }
160
161 #[test]
162 fn test_take_nullable_indices() {
163 let a = buffer![1i32, 2, 3].into_array();
164 let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
165 .unwrap();
166 assert_eq!(arr.nchunks(), 3);
167 assert_eq!(arr.len(), 9);
168 let indices = PrimitiveArray::new(
169 buffer![0u64, 0, 6, 4],
170 Validity::Array(bitbuffer![1 0 0 1].into_array()),
171 );
172
173 let result = arr.take(indices.to_array()).unwrap();
174 assert_arrays_eq!(
175 result,
176 PrimitiveArray::from_option_iter([Some(1i32), None, None, Some(2)])
177 );
178 }
179
180 #[test]
181 fn test_take_nullable_struct() {
182 let struct_array =
183 StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
184 .unwrap();
185
186 let arr = ChunkedArray::from_iter(vec![struct_array.to_array(), struct_array.to_array()]);
187
188 let result = arr
189 .take(PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]).to_array())
190 .unwrap();
191
192 let expect = StructArray::try_new(
193 FieldNames::default(),
194 vec![],
195 3,
196 Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array()),
197 )
198 .unwrap();
199 assert_arrays_eq!(result, expect);
200 }
201
202 #[test]
203 fn test_empty_take() {
204 let a = buffer![1i32, 2, 3].into_array();
205 let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
206 .unwrap();
207 assert_eq!(arr.nchunks(), 3);
208 assert_eq!(arr.len(), 9);
209
210 let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
211 let result = arr.take(indices.to_array()).unwrap();
212
213 assert!(result.is_empty());
214 assert_eq!(result.dtype(), arr.dtype());
215 assert_arrays_eq!(
216 result,
217 PrimitiveArray::empty::<i32>(Nullability::NonNullable)
218 );
219 }
220
221 #[test]
222 fn test_take_shuffled_indices() -> VortexResult<()> {
223 let c0 = buffer![0i32, 1, 2].into_array();
224 let c1 = buffer![3i32, 4, 5].into_array();
225 let c2 = buffer![6i32, 7, 8].into_array();
226 let arr = ChunkedArray::try_new(
227 vec![c0, c1, c2],
228 PrimitiveArray::empty::<i32>(Nullability::NonNullable)
229 .dtype()
230 .clone(),
231 )?;
232
233 let indices = buffer![8u64, 0, 5, 3, 2, 7, 1, 6, 4].into_array();
235 let result = arr.take(indices.to_array())?;
236
237 assert_arrays_eq!(
238 result,
239 PrimitiveArray::from_iter([8i32, 0, 5, 3, 2, 7, 1, 6, 4])
240 );
241 Ok(())
242 }
243
244 #[test]
245 fn test_take_shuffled_large() -> VortexResult<()> {
246 let nchunks: i32 = 100;
247 let chunk_len: i32 = 1_000;
248 let total = nchunks * chunk_len;
249
250 let chunks: Vec<_> = (0..nchunks)
251 .map(|c| {
252 let start = c * chunk_len;
253 PrimitiveArray::from_iter(start..start + chunk_len).into_array()
254 })
255 .collect();
256 let dtype = chunks[0].dtype().clone();
257 let arr = ChunkedArray::try_new(chunks, dtype)?;
258
259 let mut indices: Vec<u64> = (0..u64::try_from(total)?).collect();
261 let mut seed: u64 = 0xdeadbeef;
262 for i in (1..indices.len()).rev() {
263 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
264 let j = (seed >> 33) as usize % (i + 1);
265 indices.swap(i, j);
266 }
267
268 let indices_arr = PrimitiveArray::new(
269 vortex_buffer::Buffer::from(indices.clone()),
270 Validity::NonNullable,
271 );
272 let result = arr.take(indices_arr.to_array())?;
273
274 let result = result.to_primitive();
276 let result_vals = result.as_slice::<i32>();
277 for (pos, &idx) in indices.iter().enumerate() {
278 assert_eq!(
279 result_vals[pos],
280 i32::try_from(idx)?,
281 "mismatch at position {pos}"
282 );
283 }
284 Ok(())
285 }
286
287 #[test]
288 fn test_take_null_indices() -> VortexResult<()> {
289 let c0 = buffer![10i32, 20, 30].into_array();
290 let c1 = buffer![40i32, 50, 60].into_array();
291 let arr = ChunkedArray::try_new(
292 vec![c0, c1],
293 PrimitiveArray::empty::<i32>(Nullability::NonNullable)
294 .dtype()
295 .clone(),
296 )?;
297
298 let indices =
300 PrimitiveArray::from_option_iter([Some(5u64), None, Some(0), Some(3), None, Some(2)]);
301 let result = arr.take(indices.to_array())?;
302
303 assert_arrays_eq!(
304 result,
305 PrimitiveArray::from_option_iter([
306 Some(60i32),
307 None,
308 Some(10),
309 Some(40),
310 None,
311 Some(30)
312 ])
313 );
314 Ok(())
315 }
316
317 #[test]
318 fn test_take_chunked_conformance() {
319 let a = buffer![1i32, 2, 3].into_array();
320 let b = buffer![4i32, 5].into_array();
321 let arr = ChunkedArray::try_new(
322 vec![a, b],
323 PrimitiveArray::empty::<i32>(Nullability::NonNullable)
324 .dtype()
325 .clone(),
326 )
327 .unwrap();
328 test_take_conformance(arr.as_ref());
329
330 let a = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
332 let b = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
333 let dtype = a.dtype().clone();
334 let arr = ChunkedArray::try_new(vec![a.into_array(), b.into_array()], dtype).unwrap();
335 test_take_conformance(arr.as_ref());
336
337 let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
339 let arr = ChunkedArray::try_new(
340 vec![chunk.clone(), chunk.clone(), chunk.clone()],
341 chunk.dtype().clone(),
342 )
343 .unwrap();
344 test_take_conformance(arr.as_ref());
345 }
346}