vortex_array/arrays/chunked/compute/
take.rs

1use vortex_buffer::BufferMut;
2use vortex_dtype::{DType, PType};
3use vortex_error::VortexResult;
4
5use crate::arrays::chunked::ChunkedArray;
6use crate::arrays::{ChunkedVTable, PrimitiveArray};
7use crate::compute::{TakeKernel, TakeKernelAdapter, cast, take};
8use crate::validity::Validity;
9use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
10
11impl TakeKernel for ChunkedVTable {
12    fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
13        let indices = cast(
14            indices,
15            &DType::Primitive(PType::U64, indices.dtype().nullability()),
16        )?
17        .to_primitive()?;
18
19        // TODO(joe): Should we split this implementation based on indices nullability?
20        let nullability = indices.dtype().nullability();
21        let indices_mask = indices.validity_mask()?;
22        let indices = indices.as_slice::<u64>();
23
24        let mut chunks = Vec::new();
25        let mut indices_in_chunk = BufferMut::<u64>::empty();
26        let mut start = 0;
27        let mut stop = 0;
28        let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
29        for idx in indices {
30            let idx = usize::try_from(*idx)?;
31            let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
32
33            if chunk_idx != prev_chunk_idx {
34                // Start a new chunk
35                let indices_in_chunk_array = PrimitiveArray::new(
36                    indices_in_chunk.clone().freeze(),
37                    Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
38                );
39                chunks.push(take(
40                    array.chunk(prev_chunk_idx)?,
41                    indices_in_chunk_array.as_ref(),
42                )?);
43                indices_in_chunk.clear();
44                start = stop;
45            }
46
47            indices_in_chunk.push(idx_in_chunk as u64);
48            stop += 1;
49            prev_chunk_idx = chunk_idx;
50        }
51
52        if !indices_in_chunk.is_empty() {
53            let indices_in_chunk_array = PrimitiveArray::new(
54                indices_in_chunk.freeze(),
55                Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
56            );
57            chunks.push(take(
58                array.chunk(prev_chunk_idx)?,
59                indices_in_chunk_array.as_ref(),
60            )?);
61        }
62
63        Ok(ChunkedArray::new_unchecked(
64            chunks,
65            array.dtype().clone().union_nullability(nullability),
66        )
67        .into_array())
68    }
69}
70
71register_kernel!(TakeKernelAdapter(ChunkedVTable).lift());
72
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::chunked::ChunkedArray;
    use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
    use crate::canonical::ToCanonical;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Taking from three identical chunks returns the values at the requested positions,
    /// including repeated indices and indices that jump between chunks.
    #[test]
    fn test_take() {
        let a = buffer![1i32, 2, 3].into_array();
        let dtype = a.dtype().clone();
        let arr = ChunkedArray::try_new(vec![a; 3], dtype).unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);

        let indices = buffer![0u64, 0, 6, 4].into_array();
        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();
        let result = taken.to_primitive().unwrap();

        assert_eq!(result.as_slice::<i32>(), &[1, 1, 1, 2]);
    }

    /// Taking with nullable indices from a non-nullable chunked struct array produces a
    /// nullable result whose scalars match the expected validity.
    #[test]
    fn test_take_nullability() {
        let struct_array =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let arr = ChunkedArray::from_iter(vec![struct_array.to_array(), struct_array.to_array()]);

        let indices = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let result = take(arr.as_ref(), indices.as_ref()).unwrap();

        let expected_validity =
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array());
        let expect =
            StructArray::try_new(FieldNames::default(), vec![], 3, expected_validity).unwrap();

        assert_eq!(result.dtype(), expect.dtype());
        for i in 0..3 {
            assert_eq!(result.scalar_at(i).unwrap(), expect.scalar_at(i).unwrap());
        }
    }
}
128}