vortex_array/arrays/chunked/compute/take.rs

1use itertools::Itertools;
2use vortex_buffer::BufferMut;
3use vortex_dtype::PType;
4use vortex_error::VortexResult;
5use vortex_scalar::Scalar;
6
7use crate::arrays::ChunkedEncoding;
8use crate::arrays::chunked::ChunkedArray;
9use crate::compute::{
10    SearchSortedSide, TakeFn, scalar_at, search_sorted_usize, slice, sub_scalar, take, try_cast,
11};
12use crate::{Array, ArrayRef, IntoArray, ToCanonical};
13
14impl TakeFn<&ChunkedArray> for ChunkedEncoding {
15    fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
16        // Fast path for strict sorted indices.
17        if indices
18            .statistics()
19            .compute_is_strict_sorted()
20            .unwrap_or(false)
21        {
22            if array.len() == indices.len() {
23                return Ok(array.to_array().into_array());
24            }
25
26            return take_strict_sorted(array, indices);
27        }
28
29        let indices = try_cast(indices, PType::U64.into())?.to_primitive()?;
30
31        // While the chunk idx remains the same, accumulate a list of chunk indices.
32        let mut chunks = Vec::new();
33        let mut indices_in_chunk = BufferMut::<u64>::empty();
34        let mut prev_chunk_idx = array
35            .find_chunk_idx(indices.as_slice::<u64>()[0].try_into()?)
36            .0;
37        for idx in indices.as_slice::<u64>() {
38            let idx = usize::try_from(*idx)?;
39            let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
40
41            if chunk_idx != prev_chunk_idx {
42                // Start a new chunk
43                let indices_in_chunk_array = indices_in_chunk.clone().into_array();
44                chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
45                indices_in_chunk.clear();
46            }
47
48            indices_in_chunk.push(idx_in_chunk as u64);
49            prev_chunk_idx = chunk_idx;
50        }
51
52        if !indices_in_chunk.is_empty() {
53            let indices_in_chunk_array = indices_in_chunk.into_array();
54            chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
55        }
56
57        Ok(ChunkedArray::new_unchecked(chunks, array.dtype().clone()).into_array())
58    }
59}
60
61/// When the indices are non-null and strict-sorted, we can do better
62fn take_strict_sorted(chunked: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
63    let mut indices_by_chunk = vec![None; chunked.nchunks()];
64
65    // Track our position in the indices array
66    let mut pos = 0;
67    while pos < indices.len() {
68        // Locate the chunk index for the current index
69        let idx = usize::try_from(&scalar_at(indices, pos)?)?;
70        let (chunk_idx, _idx_in_chunk) = chunked.find_chunk_idx(idx);
71
72        // Find the end of this chunk, and locate that position in the indices array.
73        let chunk_begin = usize::try_from(chunked.chunk_offsets()[chunk_idx])?;
74        let chunk_end = usize::try_from(chunked.chunk_offsets()[chunk_idx + 1])?;
75        let chunk_end_pos =
76            search_sorted_usize(indices, chunk_end, SearchSortedSide::Left)?.to_index();
77
78        // Now we can say the slice of indices belonging to this chunk is [pos, chunk_end_pos)
79        let chunk_indices = slice(indices, pos, chunk_end_pos)?;
80
81        // Adjust the indices so they're relative to the chunk
82        // Note. Indices might not have a dtype big enough to fit chunk_begin after cast,
83        // if it does cast the scalar otherwise upcast the indices.
84        let chunk_indices = if chunk_begin
85            < PType::try_from(chunk_indices.dtype())?
86                .max_value_as_u64()
87                .try_into()?
88        {
89            sub_scalar(
90                &chunk_indices,
91                Scalar::from(chunk_begin).cast(chunk_indices.dtype())?,
92            )?
93        } else {
94            // Note. this try_cast (memory copy) is unnecessary, could instead upcast in the subtract fn.
95            //  and avoid an extra
96            let u64_chunk_indices = try_cast(&chunk_indices, PType::U64.into())?;
97            sub_scalar(&u64_chunk_indices, chunk_begin.into())?
98        };
99
100        indices_by_chunk[chunk_idx] = Some(chunk_indices);
101
102        pos = chunk_end_pos;
103    }
104
105    // Now we can take the chunks
106    let chunks = indices_by_chunk
107        .into_iter()
108        .enumerate()
109        .filter_map(|(chunk_idx, indices)| indices.map(|i| (chunk_idx, i)))
110        .map(|(chunk_idx, chunk_indices)| take(chunked.chunk(chunk_idx)?, &chunk_indices))
111        .try_collect()?;
112
113    Ok(ChunkedArray::try_new(chunks, chunked.dtype().clone())?.into_array())
114}
115
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::chunked::ChunkedArray;
    use crate::canonical::ToCanonical;
    use crate::compute::take;

    /// Builds a 9-row chunked array made of three `[1, 2, 3]` chunks.
    fn three_chunks() -> ChunkedArray {
        let a = buffer![1i32, 2, 3].into_array();
        ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone()).unwrap()
    }

    #[test]
    fn test_take() {
        let arr = three_chunks();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);
        // Unsorted indices with a duplicate, spanning several chunks.
        let indices = buffer![0u64, 0, 6, 4].into_array();

        let result = take(&arr, &indices).unwrap().to_primitive().unwrap();
        assert_eq!(result.as_slice::<i32>(), &[1, 1, 1, 2]);
    }

    #[test]
    fn test_take_strict_sorted() {
        let arr = three_chunks();
        // One index per chunk, none of them at the start of its chunk, so each run must be
        // rebased relative to its chunk's offset.
        let indices = buffer![1u64, 4, 8].into_array();

        let result = take(&arr, &indices).unwrap().to_primitive().unwrap();
        assert_eq!(result.as_slice::<i32>(), &[2, 2, 3]);
    }

    #[test]
    fn test_take_all_rows() {
        let arr = three_chunks();
        // Selecting every row in order must reproduce the array.
        let indices = buffer![0u64, 1, 2, 3, 4, 5, 6, 7, 8].into_array();

        let result = take(&arr, &indices).unwrap().to_primitive().unwrap();
        assert_eq!(result.as_slice::<i32>(), &[1, 2, 3, 1, 2, 3, 1, 2, 3]);
    }
}