vortex_array/arrays/varbin/compute/
take.rs1use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::VarBin;
17use crate::arrays::VarBinArray;
18use crate::arrays::dict::TakeExecute;
19use crate::arrays::varbin::VarBinArrayExt;
20use crate::dtype::DType;
21use crate::dtype::IntegerPType;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24use crate::validity::Validity;
25
26impl TakeExecute for VarBin {
27 fn take(
28 array: ArrayView<'_, VarBin>,
29 indices: &ArrayRef,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let offsets = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
34 let data = array.bytes();
35 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
36 let dtype = array
37 .dtype()
38 .clone()
39 .union_nullability(indices.dtype().nullability());
40 let array_validity = array
41 .varbin_validity()
42 .execute_mask(array.as_ref().len(), ctx)?;
43 let indices_validity = indices
44 .as_ref()
45 .validity()?
46 .execute_mask(indices.as_ref().len(), ctx)?;
47
48 let array = match_each_integer_ptype!(indices.ptype(), |I| {
49 match offsets.ptype() {
52 PType::U8 => take::<I, u8, u32>(
53 dtype,
54 offsets.as_slice::<u8>(),
55 data.as_slice(),
56 indices.as_slice::<I>(),
57 array_validity,
58 indices_validity,
59 ),
60 PType::U16 => take::<I, u16, u32>(
61 dtype,
62 offsets.as_slice::<u16>(),
63 data.as_slice(),
64 indices.as_slice::<I>(),
65 array_validity,
66 indices_validity,
67 ),
68 PType::U32 => take::<I, u32, u32>(
69 dtype,
70 offsets.as_slice::<u32>(),
71 data.as_slice(),
72 indices.as_slice::<I>(),
73 array_validity,
74 indices_validity,
75 ),
76 PType::U64 => take::<I, u64, u64>(
77 dtype,
78 offsets.as_slice::<u64>(),
79 data.as_slice(),
80 indices.as_slice::<I>(),
81 array_validity,
82 indices_validity,
83 ),
84 PType::I8 => take::<I, i8, i32>(
85 dtype,
86 offsets.as_slice::<i8>(),
87 data.as_slice(),
88 indices.as_slice::<I>(),
89 array_validity,
90 indices_validity,
91 ),
92 PType::I16 => take::<I, i16, i32>(
93 dtype,
94 offsets.as_slice::<i16>(),
95 data.as_slice(),
96 indices.as_slice::<I>(),
97 array_validity,
98 indices_validity,
99 ),
100 PType::I32 => take::<I, i32, i32>(
101 dtype,
102 offsets.as_slice::<i32>(),
103 data.as_slice(),
104 indices.as_slice::<I>(),
105 array_validity,
106 indices_validity,
107 ),
108 PType::I64 => take::<I, i64, i64>(
109 dtype,
110 offsets.as_slice::<i64>(),
111 data.as_slice(),
112 indices.as_slice::<I>(),
113 array_validity,
114 indices_validity,
115 ),
116 _ => unreachable!("invalid PType for offsets"),
117 }
118 });
119
120 Ok(Some(array?.into_array()))
121 }
122}
123
124fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
125 dtype: DType,
126 offsets: &[Offset],
127 data: &[u8],
128 indices: &[Index],
129 validity_mask: Mask,
130 indices_validity_mask: Mask,
131) -> VortexResult<VarBinArray> {
132 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
133 return Ok(take_nullable::<Index, Offset, NewOffset>(
134 dtype,
135 offsets,
136 data,
137 indices,
138 validity_mask,
139 indices_validity_mask,
140 ));
141 }
142
143 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
144 new_offsets.push(NewOffset::zero());
145 let mut current_offset = NewOffset::zero();
146
147 for &idx in indices {
148 let idx = idx
149 .to_usize()
150 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
151 let start = offsets[idx];
152 let stop = offsets[idx + 1];
153
154 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
155 new_offsets.push(current_offset);
156 }
157
158 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
159
160 for idx in indices {
161 let idx = idx
162 .to_usize()
163 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
164 let start = offsets[idx]
165 .to_usize()
166 .vortex_expect("Failed to cast max offset to usize");
167 let stop = offsets[idx + 1]
168 .to_usize()
169 .vortex_expect("Failed to cast max offset to usize");
170 new_data.extend_from_slice(&data[start..stop]);
171 }
172
173 let array_validity = Validity::from(dtype.nullability());
174
175 unsafe {
178 Ok(VarBinArray::new_unchecked(
179 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
180 new_data.freeze(),
181 dtype,
182 array_validity,
183 ))
184 }
185}
186
187fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
188 dtype: DType,
189 offsets: &[Offset],
190 data: &[u8],
191 indices: &[Index],
192 data_validity: Mask,
193 indices_validity: Mask,
194) -> VarBinArray {
195 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
196 new_offsets.push(NewOffset::zero());
197 let mut current_offset = NewOffset::zero();
198
199 let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
200
201 let mut valid_indices = Vec::with_capacity(indices.len());
203
204 for (idx, data_idx) in indices.iter().enumerate() {
206 if !indices_validity.value(idx) {
207 validity_buffer.append(false);
208 new_offsets.push(current_offset);
209 continue;
210 }
211 let data_idx_usize = data_idx
212 .to_usize()
213 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
214 if data_validity.value(data_idx_usize) {
215 validity_buffer.append(true);
216 let start = offsets[data_idx_usize];
217 let stop = offsets[data_idx_usize + 1];
218 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
219 new_offsets.push(current_offset);
220 valid_indices.push(data_idx_usize);
221 } else {
222 validity_buffer.append(false);
223 new_offsets.push(current_offset);
224 }
225 }
226
227 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
228
229 for data_idx in valid_indices {
231 let start = offsets[data_idx]
232 .to_usize()
233 .vortex_expect("Failed to cast max offset to usize");
234 let stop = offsets[data_idx + 1]
235 .to_usize()
236 .vortex_expect("Failed to cast max offset to usize");
237 new_data.extend_from_slice(&data[start..stop]);
238 }
239
240 let array_validity = Validity::from(validity_buffer.freeze());
241
242 unsafe {
245 VarBinArray::new_unchecked(
246 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
247 new_data.freeze(),
248 dtype,
249 array_validity,
250 )
251 }
252}
253
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::varbin::compute::take::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Taking from a non-nullable array yields a nullable result only when the indices are
    /// nullable.
    #[test]
    fn test_null_take() {
        let source = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Non-nullable indices keep the non-nullable dtype.
        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken = source.take(plain_indices.into_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        // Nullable indices (even with no actual nulls) make the result nullable.
        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = source.take(nullable_indices.into_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over UTF-8 and binary inputs, with and without
    /// nulls.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] varbin: VarBinArray) {
        test_take_conformance(&varbin.into_array());
    }

    /// Three copies of a 128-byte value (384 bytes total) cannot be addressed by the source's
    /// `u8` offsets, so the take must widen the output offsets.
    #[test]
    fn test_take_overflow() {
        let long_value = "a".repeat(128);
        let array = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(long_value.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let taken = array.take(buffer![0u32; 3].into_array()).unwrap();

        let expected_values = std::array::from_fn::<_, 3, _>(|_| Some(long_value.clone()));
        let expected =
            VarBinViewArray::from_iter(expected_values, DType::Utf8(Nullability::NonNullable));
        assert_arrays_eq!(expected, taken);
    }
}