vortex_array/arrays/chunked/compute/
take.rs1use vortex_buffer::BufferMut;
5use vortex_dtype::DType;
6use vortex_dtype::PType;
7use vortex_error::VortexResult;
8
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::ToCanonical;
13use crate::arrays::ChunkedVTable;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::TakeExecute;
16use crate::arrays::chunked::ChunkedArray;
17use crate::builtins::ArrayBuiltins;
18use crate::executor::ExecutionCtx;
19use crate::validity::Validity;
20
21fn take_chunked(array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22 let indices = indices
23 .to_array()
24 .cast(DType::Primitive(PType::U64, indices.dtype().nullability()))?
25 .to_primitive();
26
27 let nullability = indices.dtype().nullability();
29 let indices_mask = indices.validity_mask()?;
30 let indices = indices.as_slice::<u64>();
31
32 let mut chunks = Vec::new();
33 let mut indices_in_chunk = BufferMut::<u64>::empty();
34 let mut start = 0;
35 let mut stop = 0;
36 let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?)?.0;
38 for idx in indices {
39 let idx = usize::try_from(*idx)?;
40 let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx)?;
41
42 if chunk_idx != prev_chunk_idx {
43 let indices_in_chunk_array = PrimitiveArray::new(
45 indices_in_chunk.clone().freeze(),
46 Validity::from_mask(indices_mask.slice(start..stop), nullability),
47 );
48 chunks.push(
49 array
50 .chunk(prev_chunk_idx)
51 .take(indices_in_chunk_array.into_array())?,
52 );
53 indices_in_chunk.clear();
54 start = stop;
55 }
56
57 indices_in_chunk.push(idx_in_chunk as u64);
58 stop += 1;
59 prev_chunk_idx = chunk_idx;
60 }
61
62 if !indices_in_chunk.is_empty() {
63 let indices_in_chunk_array = PrimitiveArray::new(
64 indices_in_chunk.freeze(),
65 Validity::from_mask(indices_mask.slice(start..stop), nullability),
66 );
67 chunks.push(
68 array
69 .chunk(prev_chunk_idx)
70 .take(indices_in_chunk_array.into_array())?,
71 );
72 }
73
74 unsafe {
76 Ok(ChunkedArray::new_unchecked(
77 chunks,
78 array.dtype().clone().union_nullability(nullability),
79 )
80 .into_array())
81 }
82}
83
84impl TakeExecute for ChunkedVTable {
85 fn take(
86 array: &ChunkedArray,
87 indices: &dyn Array,
88 _ctx: &mut ExecutionCtx,
89 ) -> VortexResult<Option<ArrayRef>> {
90 take_chunked(array, indices).map(Some)
91 }
92}
93
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::chunked::ChunkedArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Builds a chunked array of three distinct chunks: `[1, 2, 3]`, `[4, 5, 6]`, `[7, 8, 9]`,
    /// so every logical index maps to a unique value.
    fn three_distinct_chunks() -> ChunkedArray {
        let chunks = vec![
            buffer![1i32, 2, 3].into_array(),
            buffer![4i32, 5, 6].into_array(),
            buffer![7i32, 8, 9].into_array(),
        ];
        let dtype = chunks[0].dtype().clone();
        ChunkedArray::try_new(chunks, dtype).unwrap()
    }

    #[test]
    fn test_take() {
        let a = buffer![1i32, 2, 3].into_array();
        let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
            .unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);
        let indices = buffer![0u64, 0, 6, 4].into_array();

        let result = arr.take(indices).unwrap();
        assert_arrays_eq!(result, PrimitiveArray::from_iter([1i32, 1, 1, 2]));
    }

    #[test]
    fn test_take_revisits_chunks() {
        // Unsorted indices force a new per-chunk take every time the target chunk changes,
        // including jumps back to a chunk that was already visited.
        let arr = three_distinct_chunks();
        let indices = buffer![7u64, 0, 8, 1, 4].into_array();

        let result = arr.take(indices).unwrap();
        assert_arrays_eq!(result, PrimitiveArray::from_iter([8i32, 1, 9, 2, 5]));
    }

    #[test]
    fn test_take_nullable_indices_across_chunks() {
        // Validity of the indices must follow each value into whichever chunk it lands in.
        let arr = three_distinct_chunks();
        let indices = PrimitiveArray::from_option_iter([Some(8u64), None, Some(0), Some(5)]);

        let result = arr.take(indices.to_array()).unwrap();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(9i32), None, Some(1), Some(6)])
        );
    }

    #[test]
    fn test_take_nullability() {
        let struct_array =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();

        let arr = ChunkedArray::from_iter(vec![struct_array.to_array(), struct_array.to_array()]);

        let result = arr
            .take(PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]).to_array())
            .unwrap();

        let expect = StructArray::try_new(
            FieldNames::default(),
            vec![],
            3,
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array()),
        )
        .unwrap();
        assert_arrays_eq!(result, expect);
    }

    #[test]
    fn test_empty_take() {
        let a = buffer![1i32, 2, 3].into_array();
        let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
            .unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);

        let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
        let result = arr.take(indices.to_array()).unwrap();

        assert!(result.is_empty());
        assert_eq!(result.dtype(), arr.dtype());
        assert_arrays_eq!(
            result,
            PrimitiveArray::empty::<i32>(Nullability::NonNullable)
        );
    }

    #[test]
    fn test_take_chunked_conformance() {
        let a = buffer![1i32, 2, 3].into_array();
        let b = buffer![4i32, 5].into_array();
        let arr = ChunkedArray::try_new(
            vec![a, b],
            PrimitiveArray::empty::<i32>(Nullability::NonNullable)
                .dtype()
                .clone(),
        )
        .unwrap();
        test_take_conformance(arr.as_ref());

        let a = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let b = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
        let dtype = a.dtype().clone();
        let arr = ChunkedArray::try_new(vec![a.into_array(), b.into_array()], dtype).unwrap();
        test_take_conformance(arr.as_ref());

        let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
        let arr = ChunkedArray::try_new(
            vec![chunk.clone(), chunk.clone(), chunk.clone()],
            chunk.dtype().clone(),
        )
        .unwrap();
        test_take_conformance(arr.as_ref());
    }
}