Skip to main content

vortex_array/arrays/chunked/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BufferMut;
5use vortex_dtype::DType;
6use vortex_dtype::PType;
7use vortex_error::VortexResult;
8
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::ToCanonical;
13use crate::arrays::ChunkedVTable;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::TakeExecute;
16use crate::arrays::chunked::ChunkedArray;
17use crate::builtins::ArrayBuiltins;
18use crate::executor::ExecutionCtx;
19use crate::validity::Validity;
20
21fn take_chunked(array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22    let indices = indices
23        .to_array()
24        .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
25        .to_primitive();
26
27    // TODO(joe): Should we split this implementation based on indices nullability?
28    let nullability = indices.dtype().nullability();
29    let indices_mask = indices.validity_mask()?;
30    let indices = indices.as_slice::<u64>();
31
32    let mut chunks = Vec::new();
33    let mut indices_in_chunk = BufferMut::<u64>::empty();
34    let mut start = 0;
35    let mut stop = 0;
36    // We assume indices are non-empty as it's handled in the top-level `take` function
37    let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?)?.0;
38    for idx in indices {
39        let idx = usize::try_from(*idx)?;
40        let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx)?;
41
42        if chunk_idx != prev_chunk_idx {
43            // Start a new chunk
44            let indices_in_chunk_array = PrimitiveArray::new(
45                indices_in_chunk.clone().freeze(),
46                Validity::from_mask(indices_mask.slice(start..stop), nullability),
47            );
48            chunks.push(
49                array
50                    .chunk(prev_chunk_idx)
51                    .take(indices_in_chunk_array.into_array())?,
52            );
53            indices_in_chunk.clear();
54            start = stop;
55        }
56
57        indices_in_chunk.push(idx_in_chunk as u64);
58        stop += 1;
59        prev_chunk_idx = chunk_idx;
60    }
61
62    if !indices_in_chunk.is_empty() {
63        let indices_in_chunk_array = PrimitiveArray::new(
64            indices_in_chunk.freeze(),
65            Validity::from_mask(indices_mask.slice(start..stop), nullability),
66        );
67        chunks.push(
68            array
69                .chunk(prev_chunk_idx)
70                .take(indices_in_chunk_array.into_array())?,
71        );
72    }
73
74    // SAFETY: take on chunks that all have same DType retains same DType
75    unsafe {
76        Ok(ChunkedArray::new_unchecked(
77            chunks,
78            array.dtype().clone().union_nullability(nullability),
79        )
80        .into_array())
81    }
82}
83
84impl TakeExecute for ChunkedVTable {
85    fn take(
86        array: &ChunkedArray,
87        indices: &dyn Array,
88        _ctx: &mut ExecutionCtx,
89    ) -> VortexResult<Option<ArrayRef>> {
90        take_chunked(array, indices).map(Some)
91    }
92}
93
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::chunked::ChunkedArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Builds a chunked array of three identical `[1, 2, 3]` chunks (nine elements in total)
    /// and checks its shape before handing it back.
    fn three_small_chunks() -> ChunkedArray {
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);
        chunked
    }

    #[test]
    fn test_take() {
        let chunked = three_small_chunks();

        let taken = chunked.take(buffer![0u64, 0, 6, 4].into_array()).unwrap();

        assert_arrays_eq!(taken, PrimitiveArray::from_iter([1i32, 1, 1, 2]));
    }

    #[test]
    fn test_take_nullability() {
        let fieldless =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let chunked = ChunkedArray::from_iter(vec![fieldless.to_array(), fieldless.to_array()]);

        let nullable_indices = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let taken = chunked.take(nullable_indices.to_array()).unwrap();

        let expected_validity =
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array());
        let expected =
            StructArray::try_new(FieldNames::default(), vec![], 3, expected_validity).unwrap();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn test_empty_take() {
        let chunked = three_small_chunks();

        let no_indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
        let taken = chunked.take(no_indices.to_array()).unwrap();

        assert!(taken.is_empty());
        assert_eq!(taken.dtype(), chunked.dtype());
        assert_arrays_eq!(
            taken,
            PrimitiveArray::empty::<i32>(Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_chunked_conformance() {
        // Non-nullable chunks of unequal length.
        let uneven = ChunkedArray::try_new(
            vec![
                buffer![1i32, 2, 3].into_array(),
                buffer![4i32, 5].into_array(),
            ],
            PrimitiveArray::empty::<i32>(Nullability::NonNullable)
                .dtype()
                .clone(),
        )
        .unwrap();
        test_take_conformance(uneven.as_ref());

        // Nullable chunks, one of which contains a null.
        let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let second = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
        let nullable_dtype = first.dtype().clone();
        let nullable =
            ChunkedArray::try_new(vec![first.into_array(), second.into_array()], nullable_dtype)
                .unwrap();
        test_take_conformance(nullable.as_ref());

        // Several identical chunks.
        let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
        let repeated_dtype = chunk.dtype().clone();
        let repeated =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk], repeated_dtype)
                .unwrap();
        test_take_conformance(repeated.as_ref());
    }
}