vortex_array/arrays/varbin/compute/
take.rs1use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::VarBin;
17use crate::arrays::VarBinArray;
18use crate::arrays::dict::TakeExecute;
19use crate::arrays::primitive::PrimitiveArrayExt;
20use crate::arrays::varbin::VarBinArrayExt;
21use crate::dtype::DType;
22use crate::dtype::IntegerPType;
23use crate::dtype::PType;
24use crate::executor::ExecutionCtx;
25use crate::match_each_unsigned_integer_ptype;
26use crate::validity::Validity;
27
28fn taken_offset_ptype(offsets_ptype: PType) -> PType {
31 match offsets_ptype {
32 PType::U8 | PType::U16 | PType::U32 => PType::U32,
33 PType::U64 => PType::U64,
34 PType::I8 | PType::I16 | PType::I32 => PType::I32,
35 PType::I64 => PType::I64,
36 _ => unreachable!("invalid PType for offsets"),
37 }
38}
39
40impl TakeExecute for VarBin {
41 fn take(
42 array: ArrayView<'_, VarBin>,
43 indices: &ArrayRef,
44 ctx: &mut ExecutionCtx,
45 ) -> VortexResult<Option<ArrayRef>> {
46 let offsets = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
48 let data = array.bytes();
49 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
50 let dtype = array
51 .dtype()
52 .clone()
53 .union_nullability(indices.dtype().nullability());
54 let array_validity = array
55 .varbin_validity()
56 .execute_mask(array.as_ref().len(), ctx)?;
57 let indices_validity = indices
58 .as_ref()
59 .validity()?
60 .execute_mask(indices.as_ref().len(), ctx)?;
61
62 let out_offset_ptype = taken_offset_ptype(offsets.ptype());
67 let offsets = offsets.reinterpret_cast(offsets.ptype().to_unsigned());
68 let indices = indices.reinterpret_cast(indices.ptype().to_unsigned());
69
70 let array = match_each_unsigned_integer_ptype!(indices.ptype(), |I| {
71 match offsets.ptype() {
72 PType::U8 => take::<I, u8, u32>(
73 dtype,
74 offsets.as_slice::<u8>(),
75 data.as_slice(),
76 indices.as_slice::<I>(),
77 array_validity,
78 indices_validity,
79 out_offset_ptype,
80 ),
81 PType::U16 => take::<I, u16, u32>(
82 dtype,
83 offsets.as_slice::<u16>(),
84 data.as_slice(),
85 indices.as_slice::<I>(),
86 array_validity,
87 indices_validity,
88 out_offset_ptype,
89 ),
90 PType::U32 => take::<I, u32, u32>(
91 dtype,
92 offsets.as_slice::<u32>(),
93 data.as_slice(),
94 indices.as_slice::<I>(),
95 array_validity,
96 indices_validity,
97 out_offset_ptype,
98 ),
99 PType::U64 => take::<I, u64, u64>(
100 dtype,
101 offsets.as_slice::<u64>(),
102 data.as_slice(),
103 indices.as_slice::<I>(),
104 array_validity,
105 indices_validity,
106 out_offset_ptype,
107 ),
108 _ => unreachable!("invalid PType for offsets"),
109 }
110 });
111
112 Ok(Some(array?.into_array()))
113 }
114}
115
116fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
117 dtype: DType,
118 offsets: &[Offset],
119 data: &[u8],
120 indices: &[Index],
121 validity_mask: Mask,
122 indices_validity_mask: Mask,
123 out_offset_ptype: PType,
124) -> VortexResult<VarBinArray> {
125 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
126 return Ok(take_nullable::<Index, Offset, NewOffset>(
127 dtype,
128 offsets,
129 data,
130 indices,
131 validity_mask,
132 indices_validity_mask,
133 out_offset_ptype,
134 ));
135 }
136
137 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
138 new_offsets.push(NewOffset::zero());
139 let mut current_offset = NewOffset::zero();
140
141 for &idx in indices {
142 let idx = idx
143 .to_usize()
144 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
145 let start = offsets[idx];
146 let stop = offsets[idx + 1];
147
148 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
149 new_offsets.push(current_offset);
150 }
151
152 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
153
154 for idx in indices {
155 let idx = idx
156 .to_usize()
157 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
158 let start = offsets[idx]
159 .to_usize()
160 .vortex_expect("Failed to cast max offset to usize");
161 let stop = offsets[idx + 1]
162 .to_usize()
163 .vortex_expect("Failed to cast max offset to usize");
164 new_data.extend_from_slice(&data[start..stop]);
165 }
166
167 let array_validity = Validity::from(dtype.nullability());
168
169 let new_offsets = PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable)
171 .reinterpret_cast(out_offset_ptype)
172 .into_array();
173
174 unsafe {
177 Ok(VarBinArray::new_unchecked(
178 new_offsets,
179 new_data.freeze(),
180 dtype,
181 array_validity,
182 ))
183 }
184}
185
186fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
187 dtype: DType,
188 offsets: &[Offset],
189 data: &[u8],
190 indices: &[Index],
191 data_validity: Mask,
192 indices_validity: Mask,
193 out_offset_ptype: PType,
194) -> VarBinArray {
195 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
196 new_offsets.push(NewOffset::zero());
197 let mut current_offset = NewOffset::zero();
198
199 let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
200
201 let mut valid_indices = Vec::with_capacity(indices.len());
203
204 for (idx, data_idx) in indices.iter().enumerate() {
206 if !indices_validity.value(idx) {
207 validity_buffer.append(false);
208 new_offsets.push(current_offset);
209 continue;
210 }
211 let data_idx_usize = data_idx
212 .to_usize()
213 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
214 if data_validity.value(data_idx_usize) {
215 validity_buffer.append(true);
216 let start = offsets[data_idx_usize];
217 let stop = offsets[data_idx_usize + 1];
218 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
219 new_offsets.push(current_offset);
220 valid_indices.push(data_idx_usize);
221 } else {
222 validity_buffer.append(false);
223 new_offsets.push(current_offset);
224 }
225 }
226
227 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
228
229 for data_idx in valid_indices {
231 let start = offsets[data_idx]
232 .to_usize()
233 .vortex_expect("Failed to cast max offset to usize");
234 let stop = offsets[data_idx + 1]
235 .to_usize()
236 .vortex_expect("Failed to cast max offset to usize");
237 new_data.extend_from_slice(&data[start..stop]);
238 }
239
240 let array_validity = Validity::from(validity_buffer.freeze());
241
242 let new_offsets = PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable)
244 .reinterpret_cast(out_offset_ptype)
245 .into_array();
246
247 unsafe { VarBinArray::new_unchecked(new_offsets, new_data.freeze(), dtype, array_validity) }
250}
251
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    #[test]
    fn test_null_take() {
        let arr = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let idx1: PrimitiveArray = (0..1).collect();

        assert_eq!(
            arr.take(idx1.into_array()).unwrap().dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        let idx2: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);

        assert_eq!(
            arr.take(idx2.into_array()).unwrap().dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(&array.into_array());
    }

    #[test]
    fn test_take_overflow() {
        let scream = std::iter::once("a").cycle().take(128).collect::<String>();
        let bytes = ByteBuffer::copy_from(scream.as_bytes());
        let offsets = buffer![0u8, 128u8].into_array();

        let array = VarBinArray::new(
            offsets,
            bytes,
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let indices = buffer![0u32; 3].into_array();
        let taken = array.take(indices).unwrap();

        let expected = VarBinViewArray::from_iter(
            [Some(scream.clone()), Some(scream.clone()), Some(scream)],
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }

    /// Null indices and indices selecting null elements must both produce nulls.
    #[test]
    fn test_take_nullable_indices() {
        let arr = VarBinArray::from_iter(
            [Some("a"), None, Some("ccc")],
            DType::Utf8(Nullability::Nullable),
        );
        let indices = PrimitiveArray::from_option_iter([Some(2u32), None, Some(1), Some(0)]);

        let taken = arr.take(indices.into_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            [Some("ccc"), None, None, Some("a")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_arrays_eq!(expected, taken);
    }

    /// Signed offsets go through the unsigned kernels and must round-trip correctly.
    #[test]
    fn test_take_signed_offsets() {
        let bytes = ByteBuffer::copy_from("abcd".as_bytes());
        let offsets = buffer![0i8, 1, 4].into_array();

        let array = VarBinArray::new(
            offsets,
            bytes,
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let taken = array.take(buffer![1u32, 0, 1].into_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            ["bcd", "a", "bcd"].map(Some),
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}