// vortex_array/arrays/chunked/compute/take.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::Mask;

use crate::ArrayRef;
use crate::Canonical;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::Chunked;
use crate::arrays::ChunkedArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::chunked::ChunkedArrayExt;
use crate::arrays::dict::TakeExecute;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::PType;
use crate::executor::ExecutionCtx;
use crate::validity::Validity;
22
23// TODO(joe): this is pretty unoptimized but better than before. We want canonical using a builder
24// we also want to return a chunked array ideally.
25fn take_chunked(
26    array: ArrayView<'_, Chunked>,
27    indices: &ArrayRef,
28    ctx: &mut ExecutionCtx,
29) -> VortexResult<ArrayRef> {
30    let indices = indices
31        .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
32        .execute::<PrimitiveArray>(ctx)?;
33
34    let indices_mask = indices.validity_mask()?;
35    let indices_values = indices.as_slice::<u64>();
36    let n = indices_values.len();
37
38    // 1. Sort (value, orig_pos) pairs so indices for the same chunk are contiguous.
39    //    Skip null indices — their final_take slots stay 0 and are masked null by validity.
40    let mut pairs: Vec<(u64, usize)> = indices_values
41        .iter()
42        .enumerate()
43        .filter(|&(i, _)| indices_mask.value(i))
44        .map(|(i, &v)| (v, i))
45        .collect();
46    pairs.sort_unstable();
47
48    // 2. Fused pass: walk sorted pairs against chunk boundaries.
49    //    - Dedup inline → build per-chunk filter masks
50    //    - Scatter final_take[orig_pos] = dedup_idx for every pair
51    let chunk_offsets = array.chunk_offsets();
52    let nchunks = array.nchunks();
53    let mut chunks = Vec::with_capacity(nchunks);
54    let mut final_take = BufferMut::<u64>::with_capacity(n);
55    final_take.push_n(0u64, n);
56
57    let mut cursor = 0usize;
58    let mut dedup_idx = 0u64;
59
60    for chunk_idx in 0..nchunks {
61        let chunk_start = chunk_offsets[chunk_idx];
62        let chunk_end = chunk_offsets[chunk_idx + 1];
63        let chunk_len = chunk_end - chunk_start;
64        let chunk_end_u64 = u64::try_from(chunk_end)?;
65
66        let range_end = cursor + pairs[cursor..].partition_point(|&(v, _)| v < chunk_end_u64);
67        let chunk_pairs = &pairs[cursor..range_end];
68
69        if !chunk_pairs.is_empty() {
70            let mut local_indices: Vec<usize> = Vec::new();
71            for (i, &(val, orig_pos)) in chunk_pairs.iter().enumerate() {
72                if cursor + i > 0 && val != pairs[cursor + i - 1].0 {
73                    dedup_idx += 1;
74                }
75                let local = usize::try_from(val)? - chunk_start;
76                if local_indices.last() != Some(&local) {
77                    local_indices.push(local);
78                }
79                final_take[orig_pos] = dedup_idx;
80            }
81
82            let filter_mask = Mask::from_indices(chunk_len, local_indices);
83            chunks.push(array.chunk(chunk_idx).filter(filter_mask)?);
84        }
85
86        cursor = range_end;
87    }
88
89    // SAFETY: every chunk came from a filter on a chunk with the same base dtype,
90    // unioned with the index nullability.
91    let flat = unsafe { ChunkedArray::new_unchecked(chunks, array.dtype().clone()) }
92        .into_array()
93        // TODO(joe): can we relax this.
94        .execute::<Canonical>(ctx)?
95        .into_array();
96
97    // 4. Single take to restore original order and expand duplicates.
98    //    Carry the original index validity so null indices produce null outputs.
99    let take_validity =
100        Validity::from_mask(indices.validity_mask()?, indices.dtype().nullability());
101    flat.take(PrimitiveArray::new(final_take.freeze(), take_validity).into_array())
102}
103
104impl TakeExecute for Chunked {
105    fn take(
106        array: ArrayView<'_, Chunked>,
107        indices: &ArrayRef,
108        ctx: &mut ExecutionCtx,
109    ) -> VortexResult<Option<ArrayRef>> {
110        take_chunked(array, indices, ctx).map(Some)
111    }
112}

#[cfg(test)]
mod test {
    use vortex_buffer::bitbuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::chunked::ChunkedArrayExt;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Builds a chunked array from three copies of a 3-row `chunk` and checks its shape
    /// (3 chunks, 9 rows).
    fn three_copies(chunk: impl IntoArray) -> ChunkedArray {
        let chunk = chunk.into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);
        chunked
    }

    /// Wraps `chunks` into a chunked array with the non-nullable `i32` dtype.
    fn non_nullable_i32_chunks<A: IntoArray>(
        chunks: impl IntoIterator<Item = A>,
    ) -> VortexResult<ChunkedArray> {
        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let chunks = chunks.into_iter().map(|c| c.into_array()).collect();
        ChunkedArray::try_new(chunks, dtype)
    }

    #[test]
    fn test_take() {
        let arr = three_copies(buffer![1i32, 2, 3]);
        let taken = arr.take(buffer![0u64, 0, 6, 4].into_array()).unwrap();
        assert_arrays_eq!(taken, PrimitiveArray::from_iter([1i32, 1, 1, 2]));
    }

    #[test]
    fn test_take_nullable_values() {
        let arr = three_copies(PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid));
        let indices =
            PrimitiveArray::new(buffer![0u64, 0, 6, 4], Validity::NonNullable).into_array();

        let taken = arr.take(indices).unwrap();
        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_option_iter([1i32, 1, 1, 2].map(Some))
        );
    }

    #[test]
    fn test_take_nullable_indices() {
        let arr = three_copies(buffer![1i32, 2, 3]);
        let validity = Validity::Array(bitbuffer![1 0 0 1].into_array());
        let indices = PrimitiveArray::new(buffer![0u64, 0, 6, 4], validity).into_array();

        let taken = arr.take(indices).unwrap();
        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_option_iter([Some(1i32), None, None, Some(2)])
        );
    }

    #[test]
    fn test_take_nullable_struct() {
        let empty_struct =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let arr = ChunkedArray::from_iter(vec![
            empty_struct.clone().into_array(),
            empty_struct.into_array(),
        ]);

        let indices =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]).into_array();
        let taken = arr.take(indices).unwrap();

        let expected_validity =
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array());
        let expected =
            StructArray::try_new(FieldNames::default(), vec![], 3, expected_validity).unwrap();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn test_empty_take() {
        let arr = three_copies(buffer![1i32, 2, 3]);

        let taken = arr
            .take(PrimitiveArray::empty::<u64>(Nullability::NonNullable).into_array())
            .unwrap();

        assert!(taken.is_empty());
        assert_eq!(taken.dtype(), arr.dtype());
        assert_arrays_eq!(
            taken,
            PrimitiveArray::empty::<i32>(Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_shuffled_indices() -> VortexResult<()> {
        let arr = non_nullable_i32_chunks([
            buffer![0i32, 1, 2],
            buffer![3i32, 4, 5],
            buffer![6i32, 7, 8],
        ])?;

        // Fully shuffled indices that cross every chunk boundary.
        let taken = arr.take(buffer![8u64, 0, 5, 3, 2, 7, 1, 6, 4].into_array())?;

        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_iter([8i32, 0, 5, 3, 2, 7, 1, 6, 4])
        );
        Ok(())
    }

    #[test]
    fn test_take_shuffled_large() -> VortexResult<()> {
        let nchunks: i32 = 100;
        let chunk_len: i32 = 1_000;
        let total = nchunks * chunk_len;

        // Chunk `c` holds the consecutive values [c * chunk_len, (c + 1) * chunk_len).
        let arr = {
            let chunks: Vec<_> = (0..nchunks)
                .map(|c| {
                    PrimitiveArray::from_iter(c * chunk_len..(c + 1) * chunk_len).into_array()
                })
                .collect();
            let dtype = chunks[0].dtype().clone();
            ChunkedArray::try_new(chunks, dtype)?
        };

        // Deterministic Fisher-Yates shuffle driven by a fixed-seed LCG.
        let mut perm: Vec<u64> = (0..u64::try_from(total)?).collect();
        let mut state: u64 = 0xdeadbeef;
        for upper in (1..perm.len()).rev() {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let pick = (state >> 33) as usize % (upper + 1);
            perm.swap(upper, pick);
        }

        let indices = PrimitiveArray::new(
            vortex_buffer::Buffer::from(perm.clone()),
            Validity::NonNullable,
        );
        let taken = arr.take(indices.into_array())?.to_primitive();
        let values = taken.as_slice::<i32>();

        // Since every value equals its row number, each output must equal its index.
        for (pos, &idx) in perm.iter().enumerate() {
            assert_eq!(
                values[pos],
                i32::try_from(idx)?,
                "mismatch at position {pos}"
            );
        }
        Ok(())
    }

    #[test]
    fn test_take_null_indices() -> VortexResult<()> {
        let arr = non_nullable_i32_chunks([buffer![10i32, 20, 30], buffer![40i32, 50, 60]])?;

        // Indices with nulls scattered across chunk boundaries.
        let indices =
            PrimitiveArray::from_option_iter([Some(5u64), None, Some(0), Some(3), None, Some(2)]);
        let taken = arr.take(indices.into_array())?;

        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_option_iter([
                Some(60i32),
                None,
                Some(10),
                Some(40),
                None,
                Some(30)
            ])
        );
        Ok(())
    }

    #[test]
    fn test_take_chunked_conformance() {
        // Non-nullable chunks of unequal length.
        let uneven = non_nullable_i32_chunks([buffer![1i32, 2, 3], buffer![4i32, 5]]).unwrap();
        test_take_conformance(&uneven.into_array());

        // Nullable chunks, with a null inside the first chunk.
        let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let second = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
        let dtype = first.dtype().clone();
        let nullable =
            ChunkedArray::try_new(vec![first.into_array(), second.into_array()], dtype).unwrap();
        test_take_conformance(&nullable.into_array());

        // Several identical chunks.
        let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
        let repeated = ChunkedArray::try_new(
            vec![chunk.clone(), chunk.clone(), chunk.clone()],
            chunk.dtype().clone(),
        )
        .unwrap();
        test_take_conformance(&repeated.into_array());
    }
}