vortex_array/arrays/chunked/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BufferMut;
5use vortex_dtype::DType;
6use vortex_dtype::PType;
7use vortex_error::VortexResult;
8
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::ToCanonical;
13use crate::arrays::ChunkedVTable;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::chunked::ChunkedArray;
16use crate::compute::TakeKernel;
17use crate::compute::TakeKernelAdapter;
18use crate::compute::cast;
19use crate::compute::take;
20use crate::register_kernel;
21use crate::validity::Validity;
22
23impl TakeKernel for ChunkedVTable {
24    fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
25        let indices = cast(
26            indices,
27            &DType::Primitive(PType::U64, indices.dtype().nullability()),
28        )?
29        .to_primitive();
30
31        // TODO(joe): Should we split this implementation based on indices nullability?
32        let nullability = indices.dtype().nullability();
33        let indices_mask = indices.validity_mask();
34        let indices = indices.as_slice::<u64>();
35
36        let mut chunks = Vec::new();
37        let mut indices_in_chunk = BufferMut::<u64>::empty();
38        let mut start = 0;
39        let mut stop = 0;
40        // We assume indices are non-empty as it's handled in the top-level `take` function
41        let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
42        for idx in indices {
43            let idx = usize::try_from(*idx)?;
44            let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
45
46            if chunk_idx != prev_chunk_idx {
47                // Start a new chunk
48                let indices_in_chunk_array = PrimitiveArray::new(
49                    indices_in_chunk.clone().freeze(),
50                    Validity::from_mask(indices_mask.slice(start..stop), nullability),
51                );
52                chunks.push(take(
53                    array.chunk(prev_chunk_idx),
54                    indices_in_chunk_array.as_ref(),
55                )?);
56                indices_in_chunk.clear();
57                start = stop;
58            }
59
60            indices_in_chunk.push(idx_in_chunk as u64);
61            stop += 1;
62            prev_chunk_idx = chunk_idx;
63        }
64
65        if !indices_in_chunk.is_empty() {
66            let indices_in_chunk_array = PrimitiveArray::new(
67                indices_in_chunk.freeze(),
68                Validity::from_mask(indices_mask.slice(start..stop), nullability),
69            );
70            chunks.push(take(
71                array.chunk(prev_chunk_idx),
72                indices_in_chunk_array.as_ref(),
73            )?);
74        }
75
76        // SAFETY: take on chunks that all have same DType retains same DType
77        unsafe {
78            Ok(ChunkedArray::new_unchecked(
79                chunks,
80                array.dtype().clone().union_nullability(nullability),
81            )
82            .into_array())
83        }
84    }
85}
86
87register_kernel!(TakeKernelAdapter(ChunkedVTable).lift());
88
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::chunked::ChunkedArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Builds a chunked array made of three copies of the chunk `[1, 2, 3]`.
    fn three_identical_chunks() -> ChunkedArray {
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk], dtype).unwrap()
    }

    /// Indices that revisit a chunk and jump between chunks resolve to the right values.
    #[test]
    fn test_take() {
        let chunked = three_identical_chunks();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let positions = buffer![0u64, 0, 6, 4].into_array();
        let taken = take(chunked.as_ref(), positions.as_ref())
            .unwrap()
            .to_primitive();

        let expected = [1i32, 1, 1, 2];
        assert_eq!(taken.as_slice::<i32>(), &expected);
    }

    /// Nullable indices make the result nullable even when the chunks are not.
    #[test]
    fn test_take_nullability() {
        let empty_struct =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let chunked =
            ChunkedArray::from_iter(vec![empty_struct.to_array(), empty_struct.to_array()]);

        let positions = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let taken = take(chunked.as_ref(), positions.as_ref()).unwrap();

        let row_validity = BoolArray::from_iter(vec![true, false, true]).to_array();
        let expected = StructArray::try_new(
            FieldNames::default(),
            vec![],
            3,
            Validity::Array(row_validity),
        )
        .unwrap();

        assert_eq!(taken.dtype(), expected.dtype());
        for row in 0..3 {
            assert_eq!(taken.scalar_at(row), expected.scalar_at(row));
        }
    }

    /// Taking with no indices yields an empty array of the source dtype.
    #[test]
    fn test_empty_take() {
        let chunked = three_identical_chunks();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let no_positions = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
        let taken = take(chunked.as_ref(), no_positions.as_ref())
            .unwrap()
            .to_primitive();

        assert!(taken.is_empty());
        assert_eq!(taken.dtype(), chunked.dtype());
        assert!(taken.as_slice::<i32>().is_empty());
    }

    /// Runs the shared take conformance suite over several chunk layouts.
    #[test]
    fn test_take_chunked_conformance() {
        // Non-nullable chunks of different lengths.
        {
            let first = buffer![1i32, 2, 3].into_array();
            let second = buffer![4i32, 5].into_array();
            let i32_dtype = PrimitiveArray::empty::<i32>(Nullability::NonNullable)
                .dtype()
                .clone();
            let chunked = ChunkedArray::try_new(vec![first, second], i32_dtype).unwrap();
            test_take_conformance(chunked.as_ref());
        }

        // Test with nullable chunked array
        {
            let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
            let second = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
            let nullable_dtype = first.dtype().clone();
            let chunked = ChunkedArray::try_new(
                vec![first.into_array(), second.into_array()],
                nullable_dtype,
            )
            .unwrap();
            test_take_conformance(chunked.as_ref());
        }

        // Test with multiple identical chunks
        {
            let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
            let chunk_dtype = chunk.dtype().clone();
            let chunked = ChunkedArray::try_new(
                vec![chunk.clone(), chunk.clone(), chunk.clone()],
                chunk_dtype,
            )
            .unwrap();
            test_take_conformance(chunked.as_ref());
        }
    }
}