vortex_array/arrays/chunked/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BufferMut;
5use vortex_dtype::{DType, PType};
6use vortex_error::VortexResult;
7
8use crate::arrays::chunked::ChunkedArray;
9use crate::arrays::{ChunkedVTable, PrimitiveArray};
10use crate::compute::{TakeKernel, TakeKernelAdapter, cast, take};
11use crate::validity::Validity;
12use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
13
14impl TakeKernel for ChunkedVTable {
15    fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
16        let indices = cast(
17            indices,
18            &DType::Primitive(PType::U64, indices.dtype().nullability()),
19        )?
20        .to_primitive()?;
21
22        // TODO(joe): Should we split this implementation based on indices nullability?
23        let nullability = indices.dtype().nullability();
24        let indices_mask = indices.validity_mask()?;
25        let indices = indices.as_slice::<u64>();
26
27        let mut chunks = Vec::new();
28        let mut indices_in_chunk = BufferMut::<u64>::empty();
29        let mut start = 0;
30        let mut stop = 0;
31        // We assume indices are non-empty as it's handled in the top-level `take` function
32        let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
33        for idx in indices {
34            let idx = usize::try_from(*idx)?;
35            let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
36
37            if chunk_idx != prev_chunk_idx {
38                // Start a new chunk
39                let indices_in_chunk_array = PrimitiveArray::new(
40                    indices_in_chunk.clone().freeze(),
41                    Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
42                );
43                chunks.push(take(
44                    array.chunk(prev_chunk_idx),
45                    indices_in_chunk_array.as_ref(),
46                )?);
47                indices_in_chunk.clear();
48                start = stop;
49            }
50
51            indices_in_chunk.push(idx_in_chunk as u64);
52            stop += 1;
53            prev_chunk_idx = chunk_idx;
54        }
55
56        if !indices_in_chunk.is_empty() {
57            let indices_in_chunk_array = PrimitiveArray::new(
58                indices_in_chunk.freeze(),
59                Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
60            );
61            chunks.push(take(
62                array.chunk(prev_chunk_idx),
63                indices_in_chunk_array.as_ref(),
64            )?);
65        }
66
67        // SAFETY: take on chunks that all have same DType retains same DType
68        unsafe {
69            Ok(ChunkedArray::new_unchecked(
70                chunks,
71                array.dtype().clone().union_nullability(nullability),
72            )
73            .into_array())
74        }
75    }
76}
77
// Registers the `TakeKernel` implementation for `ChunkedVTable`, wrapped in a
// `TakeKernelAdapter`, through the crate's `register_kernel!` macro.
// NOTE(review): presumably this is what lets the top-level `take` function dispatch to the
// chunked kernel — confirm against the `register_kernel!` definition.
78register_kernel!(TakeKernelAdapter(ChunkedVTable).lift());
79
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::chunked::ChunkedArray;
    use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Builds a chunked array made of three identical `[1, 2, 3]` `i32` chunks and checks its
    /// shape (three chunks, nine elements) before handing it back.
    fn three_identical_chunks() -> ChunkedArray {
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk; 3], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);
        chunked
    }

    /// Indices that repeat and jump between chunks resolve to the right elements.
    #[test]
    fn test_take() {
        let chunked = three_identical_chunks();
        let indices = buffer![0u64, 0, 6, 4].into_array();

        let taken = take(chunked.as_ref(), indices.as_ref()).unwrap();
        let taken = taken.to_primitive().unwrap();

        assert_eq!(taken.as_slice::<i32>(), &[1, 1, 1, 2]);
    }

    /// Nullable indices over non-nullable chunks yield a nullable result with matching validity.
    #[test]
    fn test_take_nullability() {
        let empty_struct =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let chunked =
            ChunkedArray::from_iter(vec![empty_struct.to_array(), empty_struct.to_array()]);

        let indices = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let taken = take(chunked.as_ref(), indices.as_ref()).unwrap();

        let row_validity = BoolArray::from_iter(vec![true, false, true]).to_array();
        let expected = StructArray::try_new(
            FieldNames::default(),
            vec![],
            3,
            Validity::Array(row_validity),
        )
        .unwrap();

        assert_eq!(taken.dtype(), expected.dtype());
        for row in 0..3 {
            assert_eq!(taken.scalar_at(row), expected.scalar_at(row));
        }
    }

    /// Taking with no indices produces an empty array of the source dtype.
    #[test]
    fn test_empty_take() {
        let chunked = three_identical_chunks();
        let no_indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);

        let taken = take(chunked.as_ref(), no_indices.as_ref())
            .unwrap()
            .to_primitive()
            .unwrap();

        assert!(taken.is_empty());
        assert_eq!(taken.dtype(), chunked.dtype());
        assert!(taken.as_slice::<i32>().is_empty());
    }

    /// Runs the shared take conformance suite against several chunked layouts.
    #[test]
    fn test_take_chunked_conformance() {
        // Non-nullable chunks of differing lengths
        let dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
            .dtype()
            .clone();
        let chunks = vec![
            buffer![1i32, 2, 3].into_array(),
            buffer![4i32, 5].into_array(),
        ];
        let chunked = ChunkedArray::try_new(chunks, dtype).unwrap();
        test_take_conformance(chunked.as_ref());

        // Test with nullable chunked array
        let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let second = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
        let dtype = first.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![first.into_array(), second.into_array()], dtype).unwrap();
        test_take_conformance(chunked.as_ref());

        // Test with multiple identical chunks
        let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
        let dtype = chunk.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk; 3], dtype).unwrap();
        test_take_conformance(chunked.as_ref());
    }
}