Skip to main content

vortex_array/arrays/chunked/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BufferMut;
5use vortex_error::VortexResult;
6use vortex_mask::Mask;
7
8use crate::Array;
9use crate::ArrayRef;
10use crate::Canonical;
11use crate::IntoArray;
12use crate::arrays::ChunkedVTable;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::TakeExecute;
15use crate::arrays::chunked::ChunkedArray;
16use crate::builtins::ArrayBuiltins;
17use crate::canonical::ToCanonical;
18use crate::dtype::DType;
19use crate::dtype::PType;
20use crate::executor::ExecutionCtx;
21use crate::validity::Validity;
22
23// TODO(joe): this is pretty unoptimized but better than before. We want canonical using a builder
24// we also want to return a chunked array ideally.
25fn take_chunked(
26    array: &ChunkedArray,
27    indices: &dyn Array,
28    ctx: &mut ExecutionCtx,
29) -> VortexResult<ArrayRef> {
30    let indices = indices
31        .to_array()
32        .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
33        .to_primitive();
34
35    let indices_mask = indices.validity_mask()?;
36    let indices_values = indices.as_slice::<u64>();
37    let n = indices_values.len();
38
39    // 1. Sort (value, orig_pos) pairs so indices for the same chunk are contiguous.
40    //    Skip null indices — their final_take slots stay 0 and are masked null by validity.
41    let mut pairs: Vec<(u64, usize)> = indices_values
42        .iter()
43        .enumerate()
44        .filter(|&(i, _)| indices_mask.value(i))
45        .map(|(i, &v)| (v, i))
46        .collect();
47    pairs.sort_unstable();
48
49    // 2. Fused pass: walk sorted pairs against chunk boundaries.
50    //    - Dedup inline → build per-chunk filter masks
51    //    - Scatter final_take[orig_pos] = dedup_idx for every pair
52    let chunk_offsets = array.chunk_offsets();
53    let nchunks = array.nchunks();
54    let mut chunks = Vec::with_capacity(nchunks);
55    let mut final_take = BufferMut::<u64>::with_capacity(n);
56    final_take.push_n(0u64, n);
57
58    let mut cursor = 0usize;
59    let mut dedup_idx = 0u64;
60
61    for chunk_idx in 0..nchunks {
62        let chunk_start = chunk_offsets[chunk_idx];
63        let chunk_end = chunk_offsets[chunk_idx + 1];
64        let chunk_len = usize::try_from(chunk_end - chunk_start)?;
65
66        let range_end = cursor + pairs[cursor..].partition_point(|&(v, _)| v < chunk_end);
67        let chunk_pairs = &pairs[cursor..range_end];
68
69        if !chunk_pairs.is_empty() {
70            let mut local_indices: Vec<usize> = Vec::new();
71            for (i, &(val, orig_pos)) in chunk_pairs.iter().enumerate() {
72                if cursor + i > 0 && val != pairs[cursor + i - 1].0 {
73                    dedup_idx += 1;
74                }
75                let local = usize::try_from(val - chunk_start)?;
76                if local_indices.last() != Some(&local) {
77                    local_indices.push(local);
78                }
79                final_take[orig_pos] = dedup_idx;
80            }
81
82            let filter_mask = Mask::from_indices(chunk_len, local_indices);
83            chunks.push(array.chunk(chunk_idx).filter(filter_mask)?);
84        }
85
86        cursor = range_end;
87    }
88
89    // SAFETY: every chunk came from a filter on a chunk with the same base dtype,
90    // unioned with the index nullability.
91    let flat = unsafe { ChunkedArray::new_unchecked(chunks, array.dtype().clone()) }
92        .into_array()
93        // TODO(joe): can we relax this.
94        .execute::<Canonical>(ctx)?
95        .into_array();
96
97    // 4. Single take to restore original order and expand duplicates.
98    //    Carry the original index validity so null indices produce null outputs.
99    let take_validity = Validity::from_mask(indices_mask, indices.dtype().nullability());
100    flat.take(PrimitiveArray::new(final_take.freeze(), take_validity).into_array())
101}
102
103impl TakeExecute for ChunkedVTable {
104    fn take(
105        array: &ChunkedArray,
106        indices: &dyn Array,
107        ctx: &mut ExecutionCtx,
108    ) -> VortexResult<Option<ArrayRef>> {
109        take_chunked(array, indices, ctx).map(Some)
110    }
111}
112
#[cfg(test)]
mod test {
    use vortex_buffer::bitbuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::array::Array;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::chunked::ChunkedArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Non-nullable values, non-nullable indices, with a repeated index.
    #[test]
    fn test_take() {
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk; 3], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let positions = buffer![0u64, 0, 6, 4].into_array();
        let taken = chunked.take(positions.to_array()).unwrap();

        assert_arrays_eq!(taken, PrimitiveArray::from_iter([1i32, 1, 1, 2]));
    }

    /// Nullable (but all-valid) values with non-nullable indices.
    #[test]
    fn test_take_nullable_values() {
        let chunk = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dtype = chunk.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk; 3], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let positions = PrimitiveArray::new(buffer![0u64, 0, 6, 4], Validity::NonNullable);
        let taken = chunked.take(positions.to_array()).unwrap();

        let expected = PrimitiveArray::from_option_iter([1i32, 1, 1, 2].map(Some));
        assert_arrays_eq!(taken, expected);
    }

    /// Non-nullable values with indices whose validity bitmap masks out two positions.
    #[test]
    fn test_take_nullable_indices() {
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk; 3], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let index_validity = Validity::Array(bitbuffer![1 0 0 1].into_array());
        let positions = PrimitiveArray::new(buffer![0u64, 0, 6, 4], index_validity);
        let taken = chunked.take(positions.to_array()).unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1i32), None, None, Some(2)]);
        assert_arrays_eq!(taken, expected);
    }

    /// Field-less struct chunks taken with nullable indices yield a nullable struct result.
    #[test]
    fn test_take_nullable_struct() {
        let empty_struct =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let halves = vec![empty_struct.to_array(), empty_struct.to_array()];
        let chunked = ChunkedArray::from_iter(halves);

        let positions = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let taken = chunked.take(positions.to_array()).unwrap();

        let expected_validity =
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array());
        let expected =
            StructArray::try_new(FieldNames::default(), vec![], 3, expected_validity).unwrap();
        assert_arrays_eq!(taken, expected);
    }

    /// Taking zero indices gives an empty array of the source dtype.
    #[test]
    fn test_empty_take() {
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk; 3], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let no_positions = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
        let taken = chunked.take(no_positions.to_array()).unwrap();

        assert!(taken.is_empty());
        assert_eq!(taken.dtype(), chunked.dtype());
        assert_arrays_eq!(
            taken,
            PrimitiveArray::empty::<i32>(Nullability::NonNullable)
        );
    }

    /// A permutation of every index; values equal their global position.
    #[test]
    fn test_take_shuffled_indices() -> VortexResult<()> {
        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let parts = vec![
            buffer![0i32, 1, 2].into_array(),
            buffer![3i32, 4, 5].into_array(),
            buffer![6i32, 7, 8].into_array(),
        ];
        let chunked = ChunkedArray::try_new(parts, dtype)?;

        // Fully shuffled indices that cross every chunk boundary.
        let positions = buffer![8u64, 0, 5, 3, 2, 7, 1, 6, 4].into_array();
        let taken = chunked.take(positions.to_array())?;

        let expected = PrimitiveArray::from_iter([8i32, 0, 5, 3, 2, 7, 1, 6, 4]);
        assert_arrays_eq!(taken, expected);
        Ok(())
    }

    /// 100 chunks of 1000 values, taken with a deterministic pseudo-random permutation.
    #[test]
    fn test_take_shuffled_large() -> VortexResult<()> {
        let nchunks: i32 = 100;
        let chunk_len: i32 = 1_000;
        let total = nchunks * chunk_len;

        // Chunk `c` holds the consecutive values `c * chunk_len .. (c + 1) * chunk_len`.
        let parts: Vec<_> = (0..nchunks)
            .map(|c| {
                let first = c * chunk_len;
                PrimitiveArray::from_iter(first..first + chunk_len).into_array()
            })
            .collect();
        let dtype = parts[0].dtype().clone();
        let chunked = ChunkedArray::try_new(parts, dtype)?;

        // Fisher-Yates shuffle with a fixed seed for determinism.
        let mut shuffled: Vec<u64> = (0..u64::try_from(total)?).collect();
        let mut state: u64 = 0xdeadbeef;
        for i in (1..shuffled.len()).rev() {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let j = (state >> 33) as usize % (i + 1);
            shuffled.swap(i, j);
        }

        let positions = PrimitiveArray::new(
            vortex_buffer::Buffer::from(shuffled.clone()),
            Validity::NonNullable,
        );
        let taken = chunked.take(positions.to_array())?;

        // Verify every element.
        let taken = taken.to_primitive();
        let taken_vals = taken.as_slice::<i32>();
        for (pos, &idx) in shuffled.iter().enumerate() {
            assert_eq!(
                taken_vals[pos],
                i32::try_from(idx)?,
                "mismatch at position {pos}"
            );
        }
        Ok(())
    }

    /// Null indices interleaved with valid ones that hit both chunks.
    #[test]
    fn test_take_null_indices() -> VortexResult<()> {
        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let parts = vec![
            buffer![10i32, 20, 30].into_array(),
            buffer![40i32, 50, 60].into_array(),
        ];
        let chunked = ChunkedArray::try_new(parts, dtype)?;

        // Indices with nulls scattered across chunk boundaries.
        let positions =
            PrimitiveArray::from_option_iter([Some(5u64), None, Some(0), Some(3), None, Some(2)]);
        let taken = chunked.take(positions.to_array())?;

        let expected = PrimitiveArray::from_option_iter([
            Some(60i32),
            None,
            Some(10),
            Some(40),
            None,
            Some(30),
        ]);
        assert_arrays_eq!(taken, expected);
        Ok(())
    }

    /// Runs the shared take conformance suite over three chunked layouts.
    #[test]
    fn test_take_chunked_conformance() {
        // Non-nullable chunks of unequal length.
        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let parts = vec![
            buffer![1i32, 2, 3].into_array(),
            buffer![4i32, 5].into_array(),
        ];
        let chunked = ChunkedArray::try_new(parts, dtype).unwrap();
        test_take_conformance(chunked.as_ref());

        // Test with nullable chunked array
        let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let second = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
        let dtype = first.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![first.into_array(), second.into_array()], dtype).unwrap();
        test_take_conformance(chunked.as_ref());

        // Test with multiple identical chunks
        let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
        let dtype = chunk.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk; 3], dtype).unwrap();
        test_take_conformance(chunked.as_ref());
    }
}