vortex_array/arrays/chunked/compute/
take.rs1use vortex_buffer::BufferMut;
2use vortex_dtype::{DType, PType};
3use vortex_error::VortexResult;
4
5use crate::arrays::chunked::ChunkedArray;
6use crate::arrays::{ChunkedVTable, PrimitiveArray};
7use crate::compute::{TakeKernel, TakeKernelAdapter, cast, take};
8use crate::validity::Validity;
9use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
10
11impl TakeKernel for ChunkedVTable {
12 fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
13 let indices = cast(
14 indices,
15 &DType::Primitive(PType::U64, indices.dtype().nullability()),
16 )?
17 .to_primitive()?;
18
19 let nullability = indices.dtype().nullability();
21 let indices_mask = indices.validity_mask()?;
22 let indices = indices.as_slice::<u64>();
23
24 let mut chunks = Vec::new();
25 let mut indices_in_chunk = BufferMut::<u64>::empty();
26 let mut start = 0;
27 let mut stop = 0;
28 let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
29 for idx in indices {
30 let idx = usize::try_from(*idx)?;
31 let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
32
33 if chunk_idx != prev_chunk_idx {
34 let indices_in_chunk_array = PrimitiveArray::new(
36 indices_in_chunk.clone().freeze(),
37 Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
38 );
39 chunks.push(take(
40 array.chunk(prev_chunk_idx)?,
41 indices_in_chunk_array.as_ref(),
42 )?);
43 indices_in_chunk.clear();
44 start = stop;
45 }
46
47 indices_in_chunk.push(idx_in_chunk as u64);
48 stop += 1;
49 prev_chunk_idx = chunk_idx;
50 }
51
52 if !indices_in_chunk.is_empty() {
53 let indices_in_chunk_array = PrimitiveArray::new(
54 indices_in_chunk.freeze(),
55 Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
56 );
57 chunks.push(take(
58 array.chunk(prev_chunk_idx)?,
59 indices_in_chunk_array.as_ref(),
60 )?);
61 }
62
63 Ok(ChunkedArray::new_unchecked(
64 chunks,
65 array.dtype().clone().union_nullability(nullability),
66 )
67 .into_array())
68 }
69}
70
71register_kernel!(TakeKernelAdapter(ChunkedVTable).lift());
72
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::chunked::ChunkedArray;
    use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
    use crate::canonical::ToCanonical;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Three copies of `[1, 2, 3]` are chunked together. Global indices that cross chunk
    /// boundaries must map back to the right chunk-local values.
    #[test]
    fn test_take() {
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let positions = buffer![0u64, 0, 6, 4].into_array();
        let taken = take(chunked.as_ref(), positions.as_ref())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(taken.as_slice::<i32>(), &[1, 1, 1, 2]);
    }

    /// Nullable indices applied to a non-nullable chunked struct array must produce a nullable
    /// result whose nulls line up with the null indices.
    #[test]
    fn test_take_nullability() {
        let fieldless =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let chunked = ChunkedArray::from_iter(vec![fieldless.to_array(), fieldless.to_array()]);

        // Index 101 falls into the second chunk.
        let positions = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let taken = take(chunked.as_ref(), positions.as_ref()).unwrap();

        let expected = StructArray::try_new(
            FieldNames::default(),
            vec![],
            3,
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array()),
        )
        .unwrap();
        assert_eq!(taken.dtype(), expected.dtype());
        for pos in 0..3 {
            assert_eq!(taken.scalar_at(pos).unwrap(), expected.scalar_at(pos).unwrap());
        }
    }
}