vortex_array/arrays/varbin/compute/
take.rs1use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::TakeExecute;
16use crate::arrays::VarBinVTable;
17use crate::arrays::varbin::VarBinArray;
18use crate::dtype::DType;
19use crate::dtype::IntegerPType;
20use crate::executor::ExecutionCtx;
21use crate::match_each_integer_ptype;
22use crate::validity::Validity;
23
24impl TakeExecute for VarBinVTable {
25 fn take(
26 array: &VarBinArray,
27 indices: &ArrayRef,
28 ctx: &mut ExecutionCtx,
29 ) -> VortexResult<Option<ArrayRef>> {
30 let offsets = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
32 let data = array.bytes();
33 let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;
34 let dtype = array
35 .dtype()
36 .clone()
37 .union_nullability(indices.dtype().nullability());
38 let array_validity = array.validity_mask()?;
39 let indices_validity = indices.validity_mask()?;
40
41 let array = match_each_integer_ptype!(indices.ptype(), |I| {
42 match offsets.ptype() {
45 PType::U8 => take::<I, u8, u32>(
46 dtype,
47 offsets.as_slice::<u8>(),
48 data.as_slice(),
49 indices.as_slice::<I>(),
50 array_validity,
51 indices_validity,
52 ),
53 PType::U16 => take::<I, u16, u32>(
54 dtype,
55 offsets.as_slice::<u16>(),
56 data.as_slice(),
57 indices.as_slice::<I>(),
58 array_validity,
59 indices_validity,
60 ),
61 PType::U32 => take::<I, u32, u32>(
62 dtype,
63 offsets.as_slice::<u32>(),
64 data.as_slice(),
65 indices.as_slice::<I>(),
66 array_validity,
67 indices_validity,
68 ),
69 PType::U64 => take::<I, u64, u64>(
70 dtype,
71 offsets.as_slice::<u64>(),
72 data.as_slice(),
73 indices.as_slice::<I>(),
74 array_validity,
75 indices_validity,
76 ),
77 PType::I8 => take::<I, i8, i32>(
78 dtype,
79 offsets.as_slice::<i8>(),
80 data.as_slice(),
81 indices.as_slice::<I>(),
82 array_validity,
83 indices_validity,
84 ),
85 PType::I16 => take::<I, i16, i32>(
86 dtype,
87 offsets.as_slice::<i16>(),
88 data.as_slice(),
89 indices.as_slice::<I>(),
90 array_validity,
91 indices_validity,
92 ),
93 PType::I32 => take::<I, i32, i32>(
94 dtype,
95 offsets.as_slice::<i32>(),
96 data.as_slice(),
97 indices.as_slice::<I>(),
98 array_validity,
99 indices_validity,
100 ),
101 PType::I64 => take::<I, i64, i64>(
102 dtype,
103 offsets.as_slice::<i64>(),
104 data.as_slice(),
105 indices.as_slice::<I>(),
106 array_validity,
107 indices_validity,
108 ),
109 _ => unreachable!("invalid PType for offsets"),
110 }
111 });
112
113 Ok(Some(array?.into_array()))
114 }
115}
116
117fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
118 dtype: DType,
119 offsets: &[Offset],
120 data: &[u8],
121 indices: &[Index],
122 validity_mask: Mask,
123 indices_validity_mask: Mask,
124) -> VortexResult<VarBinArray> {
125 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
126 return Ok(take_nullable::<Index, Offset, NewOffset>(
127 dtype,
128 offsets,
129 data,
130 indices,
131 validity_mask,
132 indices_validity_mask,
133 ));
134 }
135
136 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
137 new_offsets.push(NewOffset::zero());
138 let mut current_offset = NewOffset::zero();
139
140 for &idx in indices {
141 let idx = idx
142 .to_usize()
143 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
144 let start = offsets[idx];
145 let stop = offsets[idx + 1];
146
147 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
148 new_offsets.push(current_offset);
149 }
150
151 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
152
153 for idx in indices {
154 let idx = idx
155 .to_usize()
156 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
157 let start = offsets[idx]
158 .to_usize()
159 .vortex_expect("Failed to cast max offset to usize");
160 let stop = offsets[idx + 1]
161 .to_usize()
162 .vortex_expect("Failed to cast max offset to usize");
163 new_data.extend_from_slice(&data[start..stop]);
164 }
165
166 let array_validity = Validity::from(dtype.nullability());
167
168 unsafe {
171 Ok(VarBinArray::new_unchecked(
172 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
173 new_data.freeze(),
174 dtype,
175 array_validity,
176 ))
177 }
178}
179
180fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
181 dtype: DType,
182 offsets: &[Offset],
183 data: &[u8],
184 indices: &[Index],
185 data_validity: Mask,
186 indices_validity: Mask,
187) -> VarBinArray {
188 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
189 new_offsets.push(NewOffset::zero());
190 let mut current_offset = NewOffset::zero();
191
192 let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
193
194 let mut valid_indices = Vec::with_capacity(indices.len());
196
197 for (idx, data_idx) in indices.iter().enumerate() {
199 if !indices_validity.value(idx) {
200 validity_buffer.append(false);
201 new_offsets.push(current_offset);
202 continue;
203 }
204 let data_idx_usize = data_idx
205 .to_usize()
206 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
207 if data_validity.value(data_idx_usize) {
208 validity_buffer.append(true);
209 let start = offsets[data_idx_usize];
210 let stop = offsets[data_idx_usize + 1];
211 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
212 new_offsets.push(current_offset);
213 valid_indices.push(data_idx_usize);
214 } else {
215 validity_buffer.append(false);
216 new_offsets.push(current_offset);
217 }
218 }
219
220 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
221
222 for data_idx in valid_indices {
224 let start = offsets[data_idx]
225 .to_usize()
226 .vortex_expect("Failed to cast max offset to usize");
227 let stop = offsets[data_idx + 1]
228 .to_usize()
229 .vortex_expect("Failed to cast max offset to usize");
230 new_data.extend_from_slice(&data[start..stop]);
231 }
232
233 let array_validity = Validity::from(validity_buffer.freeze());
234
235 unsafe {
238 VarBinArray::new_unchecked(
239 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
240 new_data.freeze(),
241 dtype,
242 array_validity,
243 )
244 }
245}
246
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// The nullability of the result follows the nullability of the indices: non-nullable
    /// indices keep a non-nullable array non-nullable, nullable indices make it nullable.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken = strings.take(plain_indices.to_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = strings.take(nullable_indices.to_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over non-nullable strings, nullable strings,
    /// binary values and a single-element array.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(&array.to_array());
    }

    /// A 128-byte value addressed with `u8` offsets is taken three times; the 384 bytes of
    /// output cannot be addressed with `u8` offsets, so the take must widen them.
    #[test]
    fn test_take_overflow() {
        let scream = "a".repeat(128);
        let source = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(scream.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let indices = buffer![0u32; 3].into_array();
        let taken = source.take(indices.to_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            [(); 3].map(|()| Some(scream.clone())),
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}