vortex_array/arrays/varbin/compute/
take.rs1use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::VarBin;
17use crate::arrays::VarBinArray;
18use crate::arrays::dict::TakeExecute;
19use crate::arrays::varbin::VarBinArrayExt;
20use crate::dtype::DType;
21use crate::dtype::IntegerPType;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24use crate::validity::Validity;
25
26impl TakeExecute for VarBin {
27 fn take(
28 array: ArrayView<'_, VarBin>,
29 indices: &ArrayRef,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let offsets = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
34 let data = array.bytes();
35 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
36 let dtype = array
37 .dtype()
38 .clone()
39 .union_nullability(indices.dtype().nullability());
40 let array_validity = array.varbin_validity().to_mask(array.as_ref().len(), ctx)?;
41 let indices_validity = indices
42 .as_ref()
43 .validity()?
44 .to_mask(indices.as_ref().len(), ctx)?;
45
46 let array = match_each_integer_ptype!(indices.ptype(), |I| {
47 match offsets.ptype() {
50 PType::U8 => take::<I, u8, u32>(
51 dtype,
52 offsets.as_slice::<u8>(),
53 data.as_slice(),
54 indices.as_slice::<I>(),
55 array_validity,
56 indices_validity,
57 ),
58 PType::U16 => take::<I, u16, u32>(
59 dtype,
60 offsets.as_slice::<u16>(),
61 data.as_slice(),
62 indices.as_slice::<I>(),
63 array_validity,
64 indices_validity,
65 ),
66 PType::U32 => take::<I, u32, u32>(
67 dtype,
68 offsets.as_slice::<u32>(),
69 data.as_slice(),
70 indices.as_slice::<I>(),
71 array_validity,
72 indices_validity,
73 ),
74 PType::U64 => take::<I, u64, u64>(
75 dtype,
76 offsets.as_slice::<u64>(),
77 data.as_slice(),
78 indices.as_slice::<I>(),
79 array_validity,
80 indices_validity,
81 ),
82 PType::I8 => take::<I, i8, i32>(
83 dtype,
84 offsets.as_slice::<i8>(),
85 data.as_slice(),
86 indices.as_slice::<I>(),
87 array_validity,
88 indices_validity,
89 ),
90 PType::I16 => take::<I, i16, i32>(
91 dtype,
92 offsets.as_slice::<i16>(),
93 data.as_slice(),
94 indices.as_slice::<I>(),
95 array_validity,
96 indices_validity,
97 ),
98 PType::I32 => take::<I, i32, i32>(
99 dtype,
100 offsets.as_slice::<i32>(),
101 data.as_slice(),
102 indices.as_slice::<I>(),
103 array_validity,
104 indices_validity,
105 ),
106 PType::I64 => take::<I, i64, i64>(
107 dtype,
108 offsets.as_slice::<i64>(),
109 data.as_slice(),
110 indices.as_slice::<I>(),
111 array_validity,
112 indices_validity,
113 ),
114 _ => unreachable!("invalid PType for offsets"),
115 }
116 });
117
118 Ok(Some(array?.into_array()))
119 }
120}
121
122fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
123 dtype: DType,
124 offsets: &[Offset],
125 data: &[u8],
126 indices: &[Index],
127 validity_mask: Mask,
128 indices_validity_mask: Mask,
129) -> VortexResult<VarBinArray> {
130 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
131 return Ok(take_nullable::<Index, Offset, NewOffset>(
132 dtype,
133 offsets,
134 data,
135 indices,
136 validity_mask,
137 indices_validity_mask,
138 ));
139 }
140
141 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
142 new_offsets.push(NewOffset::zero());
143 let mut current_offset = NewOffset::zero();
144
145 for &idx in indices {
146 let idx = idx
147 .to_usize()
148 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
149 let start = offsets[idx];
150 let stop = offsets[idx + 1];
151
152 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
153 new_offsets.push(current_offset);
154 }
155
156 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
157
158 for idx in indices {
159 let idx = idx
160 .to_usize()
161 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
162 let start = offsets[idx]
163 .to_usize()
164 .vortex_expect("Failed to cast max offset to usize");
165 let stop = offsets[idx + 1]
166 .to_usize()
167 .vortex_expect("Failed to cast max offset to usize");
168 new_data.extend_from_slice(&data[start..stop]);
169 }
170
171 let array_validity = Validity::from(dtype.nullability());
172
173 unsafe {
176 Ok(VarBinArray::new_unchecked(
177 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
178 new_data.freeze(),
179 dtype,
180 array_validity,
181 ))
182 }
183}
184
185fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
186 dtype: DType,
187 offsets: &[Offset],
188 data: &[u8],
189 indices: &[Index],
190 data_validity: Mask,
191 indices_validity: Mask,
192) -> VarBinArray {
193 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
194 new_offsets.push(NewOffset::zero());
195 let mut current_offset = NewOffset::zero();
196
197 let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
198
199 let mut valid_indices = Vec::with_capacity(indices.len());
201
202 for (idx, data_idx) in indices.iter().enumerate() {
204 if !indices_validity.value(idx) {
205 validity_buffer.append(false);
206 new_offsets.push(current_offset);
207 continue;
208 }
209 let data_idx_usize = data_idx
210 .to_usize()
211 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
212 if data_validity.value(data_idx_usize) {
213 validity_buffer.append(true);
214 let start = offsets[data_idx_usize];
215 let stop = offsets[data_idx_usize + 1];
216 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
217 new_offsets.push(current_offset);
218 valid_indices.push(data_idx_usize);
219 } else {
220 validity_buffer.append(false);
221 new_offsets.push(current_offset);
222 }
223 }
224
225 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
226
227 for data_idx in valid_indices {
229 let start = offsets[data_idx]
230 .to_usize()
231 .vortex_expect("Failed to cast max offset to usize");
232 let stop = offsets[data_idx + 1]
233 .to_usize()
234 .vortex_expect("Failed to cast max offset to usize");
235 new_data.extend_from_slice(&data[start..stop]);
236 }
237
238 let array_validity = Validity::from(validity_buffer.freeze());
239
240 unsafe {
243 VarBinArray::new_unchecked(
244 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
245 new_data.freeze(),
246 dtype,
247 array_validity,
248 )
249 }
250}
251
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::varbin::compute::take::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// The nullability of the result follows the indices: non-nullable indices
    /// keep a non-nullable dtype, nullable indices make the result nullable.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken_plain = strings.take(plain_indices.into_array()).unwrap();
        assert_eq!(
            taken_plain.dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        let optional_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken_optional = strings.take(optional_indices.into_array()).unwrap();
        assert_eq!(
            taken_optional.dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    /// Runs the shared take conformance suite over several `VarBinArray` shapes:
    /// non-nullable strings, nullable strings, binary data and a single element.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(&array.into_array());
    }

    /// A single 128-byte element stored with `u8` offsets is taken three times;
    /// the 384 output bytes no longer fit in `u8` offsets.
    #[test]
    fn test_take_overflow() {
        let payload = "a".repeat(128);
        let source_bytes = ByteBuffer::copy_from(payload.as_bytes());
        let source_offsets = buffer![0u8, 128u8].into_array();

        let source = VarBinArray::new(
            source_offsets,
            source_bytes,
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let repeated_first = buffer![0u32; 3].into_array();
        let taken = source.take(repeated_first).unwrap();

        let expected = VarBinViewArray::from_iter(
            vec![Some(payload); 3],
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}