Skip to main content

vortex_array/arrays/chunked/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BufferMut;
5use vortex_error::VortexResult;
6use vortex_mask::Mask;
7
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::Chunked;
13use crate::arrays::ChunkedArray;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::chunked::ChunkedArrayExt;
16use crate::arrays::dict::TakeExecute;
17use crate::builtins::ArrayBuiltins;
18use crate::dtype::DType;
19use crate::dtype::PType;
20use crate::executor::ExecutionCtx;
21use crate::validity::Validity;
22
23// TODO(joe): this is pretty unoptimized but better than before. We want canonical using a builder
24// we also want to return a chunked array ideally.
25fn take_chunked(
26    array: ArrayView<'_, Chunked>,
27    indices: &ArrayRef,
28    ctx: &mut ExecutionCtx,
29) -> VortexResult<ArrayRef> {
30    let indices = indices
31        .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
32        .execute::<PrimitiveArray>(ctx)?;
33
34    let indices_mask = indices
35        .as_ref()
36        .validity()?
37        .execute_mask(indices.as_ref().len(), ctx)?;
38    let indices_values = indices.as_slice::<u64>();
39    let n = indices_values.len();
40
41    // 1. Sort (value, orig_pos) pairs so indices for the same chunk are contiguous.
42    //    Skip null indices — their final_take slots stay 0 and are masked null by validity.
43    let mut pairs: Vec<(u64, usize)> = indices_values
44        .iter()
45        .enumerate()
46        .filter(|&(i, _)| indices_mask.value(i))
47        .map(|(i, &v)| (v, i))
48        .collect();
49    pairs.sort_unstable();
50
51    // 2. Fused pass: walk sorted pairs against chunk boundaries.
52    //    - Dedup inline → build per-chunk filter masks
53    //    - Scatter final_take[orig_pos] = dedup_idx for every pair
54    let chunk_offsets = array.chunk_offsets();
55    let nchunks = array.nchunks();
56    let mut chunks = Vec::with_capacity(nchunks);
57    let mut final_take = BufferMut::<u64>::with_capacity(n);
58    final_take.push_n(0u64, n);
59
60    let mut cursor = 0usize;
61    let mut dedup_idx = 0u64;
62
63    for chunk_idx in 0..nchunks {
64        let chunk_start = chunk_offsets[chunk_idx];
65        let chunk_end = chunk_offsets[chunk_idx + 1];
66        let chunk_len = chunk_end - chunk_start;
67        let chunk_end_u64 = u64::try_from(chunk_end)?;
68
69        let range_end = cursor + pairs[cursor..].partition_point(|&(v, _)| v < chunk_end_u64);
70        let chunk_pairs = &pairs[cursor..range_end];
71
72        if !chunk_pairs.is_empty() {
73            let mut local_indices: Vec<usize> = Vec::new();
74            for (i, &(val, orig_pos)) in chunk_pairs.iter().enumerate() {
75                if cursor + i > 0 && val != pairs[cursor + i - 1].0 {
76                    dedup_idx += 1;
77                }
78                let local = usize::try_from(val)? - chunk_start;
79                if local_indices.last() != Some(&local) {
80                    local_indices.push(local);
81                }
82                final_take[orig_pos] = dedup_idx;
83            }
84
85            let filter_mask = Mask::from_indices(chunk_len, local_indices);
86            chunks.push(array.chunk(chunk_idx).filter(filter_mask)?);
87        }
88
89        cursor = range_end;
90    }
91
92    // SAFETY: every chunk came from a filter on a chunk with the same base dtype,
93    // unioned with the index nullability.
94    let flat = unsafe { ChunkedArray::new_unchecked(chunks, array.dtype().clone()) }
95        .into_array()
96        // TODO(joe): can we relax this.
97        .execute::<Canonical>(ctx)?
98        .into_array();
99
100    // 4. Single take to restore original order and expand duplicates.
101    //    Carry the original index validity so null indices produce null outputs.
102    let take_validity = Validity::from_mask(
103        indices
104            .as_ref()
105            .validity()?
106            .execute_mask(indices.as_ref().len(), ctx)?,
107        indices.dtype().nullability(),
108    );
109    flat.take(PrimitiveArray::new(final_take.freeze(), take_validity).into_array())
110}
111
112impl TakeExecute for Chunked {
113    fn take(
114        array: ArrayView<'_, Chunked>,
115        indices: &ArrayRef,
116        ctx: &mut ExecutionCtx,
117    ) -> VortexResult<Option<ArrayRef>> {
118        take_chunked(array, indices, ctx).map(Some)
119    }
120}
121
#[cfg(test)]
mod test {
    use vortex_buffer::bitbuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::chunked::ChunkedArrayExt;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Duplicate and cross-chunk indices over three identical non-nullable chunks.
    #[test]
    fn test_take() {
        let mut ctx = array_session().create_execution_ctx();
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk.clone()], dtype)
                .unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let positions = buffer![0u64, 0, 6, 4].into_array();
        let taken = chunked.take(positions).unwrap();

        let expected = PrimitiveArray::from_iter([1i32, 1, 1, 2]);
        assert_arrays_eq!(taken, expected, &mut ctx);
    }

    /// Same lookup as `test_take`, but the chunk values are nullable (all valid).
    #[test]
    fn test_take_nullable_values() {
        let mut ctx = array_session().create_execution_ctx();
        let chunk = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk.clone()], dtype)
                .unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let positions = PrimitiveArray::new(buffer![0u64, 0, 6, 4], Validity::NonNullable);
        let taken = chunked.take(positions.into_array()).unwrap();

        let expected = PrimitiveArray::from_option_iter([1i32, 1, 1, 2].map(Some));
        assert_arrays_eq!(taken, expected, &mut ctx);
    }

    /// Indices carrying an explicit validity bitmap: invalid slots must come back null.
    #[test]
    fn test_take_nullable_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk.clone()], dtype)
                .unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let index_validity = Validity::Array(bitbuffer![1 0 0 1].into_array());
        let positions = PrimitiveArray::new(buffer![0u64, 0, 6, 4], index_validity);
        let taken = chunked.take(positions.into_array()).unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1i32), None, None, Some(2)]);
        assert_arrays_eq!(taken, expected, &mut ctx);
    }

    /// Field-less struct chunks taken with a null index in the middle.
    #[test]
    fn test_take_nullable_struct() {
        let mut ctx = array_session().create_execution_ctx();
        let empty_struct =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();

        let first = empty_struct.clone().into_array();
        let second = empty_struct.into_array();
        let chunked = ChunkedArray::from_iter(vec![first, second]);

        let positions = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let taken = chunked.take(positions.into_array()).unwrap();

        let expected_validity =
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array());
        let expected =
            StructArray::try_new(FieldNames::default(), vec![], 3, expected_validity).unwrap();
        assert_arrays_eq!(taken, expected, &mut ctx);
    }

    /// Taking zero indices yields an empty array with the source dtype.
    #[test]
    fn test_empty_take() {
        let mut ctx = array_session().create_execution_ctx();
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk.clone()], dtype)
                .unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let no_positions = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
        let taken = chunked.take(no_positions.into_array()).unwrap();

        assert!(taken.is_empty());
        assert_eq!(taken.dtype(), chunked.dtype());
        let expected = PrimitiveArray::empty::<i32>(Nullability::NonNullable);
        assert_arrays_eq!(taken, expected, &mut ctx);
    }

    /// A full permutation of the positions, hopping between all three chunks.
    #[test]
    fn test_take_shuffled_indices() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let parts = vec![
            buffer![0i32, 1, 2].into_array(),
            buffer![3i32, 4, 5].into_array(),
            buffer![6i32, 7, 8].into_array(),
        ];
        let chunked = ChunkedArray::try_new(parts, dtype)?;

        // Fully shuffled indices that cross every chunk boundary.
        let positions = buffer![8u64, 0, 5, 3, 2, 7, 1, 6, 4].into_array();
        let taken = chunked.take(positions)?;

        let expected = PrimitiveArray::from_iter([8i32, 0, 5, 3, 2, 7, 1, 6, 4]);
        assert_arrays_eq!(taken, expected, &mut ctx);
        Ok(())
    }

    /// 100 chunks of 1,000 consecutive values, read back through a seeded shuffle.
    #[test]
    fn test_take_shuffled_large() -> VortexResult<()> {
        let chunk_count: i32 = 100;
        let rows_per_chunk: i32 = 1_000;
        let total_rows = chunk_count * rows_per_chunk;

        // Chunk `c` holds the values `c * rows_per_chunk .. (c + 1) * rows_per_chunk`, so the
        // value stored at every global position equals that position.
        let mut parts = Vec::new();
        for c in 0..chunk_count {
            let first_value = c * rows_per_chunk;
            let part =
                PrimitiveArray::from_iter(first_value..first_value + rows_per_chunk).into_array();
            parts.push(part);
        }
        let dtype = parts[0].dtype().clone();
        let chunked = ChunkedArray::try_new(parts, dtype)?;

        // Fisher-Yates shuffle with a fixed seed for determinism.
        let mut positions: Vec<u64> = (0..u64::try_from(total_rows)?).collect();
        let mut state: u64 = 0xdeadbeef;
        for i in (1..positions.len()).rev() {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let j = (state >> 33) as usize % (i + 1);
            positions.swap(i, j);
        }

        let positions_arr = PrimitiveArray::new(
            vortex_buffer::Buffer::from(positions.clone()),
            Validity::NonNullable,
        );
        let taken = chunked.take(positions_arr.into_array())?;

        // Verify every element.
        #[expect(deprecated)]
        let taken = taken.to_primitive();
        let taken_vals = taken.as_slice::<i32>();
        for (pos, &wanted) in positions.iter().enumerate() {
            let expected = i32::try_from(wanted)?;
            assert_eq!(taken_vals[pos], expected, "mismatch at position {pos}");
        }
        Ok(())
    }

    /// Null indices interleaved with valid ones on both sides of a chunk boundary.
    #[test]
    fn test_take_null_indices() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let parts = vec![
            buffer![10i32, 20, 30].into_array(),
            buffer![40i32, 50, 60].into_array(),
        ];
        let chunked = ChunkedArray::try_new(parts, dtype)?;

        // Indices with nulls scattered across chunk boundaries.
        let positions =
            PrimitiveArray::from_option_iter([Some(5u64), None, Some(0), Some(3), None, Some(2)]);
        let taken = chunked.take(positions.into_array())?;

        let expected = PrimitiveArray::from_option_iter([
            Some(60i32),
            None,
            Some(10),
            Some(40),
            None,
            Some(30),
        ]);
        assert_arrays_eq!(taken, expected, &mut ctx);
        Ok(())
    }

    /// Runs the shared take conformance suite over three chunked layouts.
    #[test]
    fn test_take_chunked_conformance() {
        // Two non-nullable chunks of different lengths.
        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let uneven = ChunkedArray::try_new(
            vec![
                buffer![1i32, 2, 3].into_array(),
                buffer![4i32, 5].into_array(),
            ],
            dtype,
        )
        .unwrap();
        test_take_conformance(&uneven.into_array());

        // Test with nullable chunked array
        let head = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let tail = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
        let nullable_dtype = head.dtype().clone();
        let nullable =
            ChunkedArray::try_new(vec![head.into_array(), tail.into_array()], nullable_dtype)
                .unwrap();
        test_take_conformance(&nullable.into_array());

        // Test with multiple identical chunks
        let repeated_chunk = buffer![10i32, 20, 30, 40, 50].into_array();
        let repeated_dtype = repeated_chunk.dtype().clone();
        let repeated = ChunkedArray::try_new(
            vec![
                repeated_chunk.clone(),
                repeated_chunk.clone(),
                repeated_chunk.clone(),
            ],
            repeated_dtype,
        )
        .unwrap();
        test_take_conformance(&repeated.into_array());
    }
}