vortex_array/arrays/varbin/compute/
take.rs1use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_dtype::DType;
8use vortex_dtype::IntegerPType;
9use vortex_dtype::match_each_integer_ptype;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_panic;
13use vortex_mask::Mask;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::ToCanonical;
19use crate::arrays::PrimitiveArray;
20use crate::arrays::TakeExecute;
21use crate::arrays::VarBinVTable;
22use crate::arrays::varbin::VarBinArray;
23use crate::executor::ExecutionCtx;
24use crate::validity::Validity;
25
26impl TakeExecute for VarBinVTable {
27 fn take(
28 array: &VarBinArray,
29 indices: &dyn Array,
30 _ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let offsets = array.offsets().to_primitive();
34 let data = array.bytes();
35 let indices = indices.to_primitive();
36 let dtype = array
37 .dtype()
38 .clone()
39 .union_nullability(indices.dtype().nullability());
40 let array_validity = array.validity_mask()?;
41 let indices_validity = indices.validity_mask()?;
42
43 let array = match_each_integer_ptype!(indices.ptype(), |I| {
44 match offsets.ptype() {
47 PType::U8 => take::<I, u8, u32>(
48 dtype,
49 offsets.as_slice::<u8>(),
50 data.as_slice(),
51 indices.as_slice::<I>(),
52 array_validity,
53 indices_validity,
54 ),
55 PType::U16 => take::<I, u16, u32>(
56 dtype,
57 offsets.as_slice::<u16>(),
58 data.as_slice(),
59 indices.as_slice::<I>(),
60 array_validity,
61 indices_validity,
62 ),
63 PType::U32 => take::<I, u32, u32>(
64 dtype,
65 offsets.as_slice::<u32>(),
66 data.as_slice(),
67 indices.as_slice::<I>(),
68 array_validity,
69 indices_validity,
70 ),
71 PType::U64 => take::<I, u64, u64>(
72 dtype,
73 offsets.as_slice::<u64>(),
74 data.as_slice(),
75 indices.as_slice::<I>(),
76 array_validity,
77 indices_validity,
78 ),
79 PType::I8 => take::<I, i8, i32>(
80 dtype,
81 offsets.as_slice::<i8>(),
82 data.as_slice(),
83 indices.as_slice::<I>(),
84 array_validity,
85 indices_validity,
86 ),
87 PType::I16 => take::<I, i16, i32>(
88 dtype,
89 offsets.as_slice::<i16>(),
90 data.as_slice(),
91 indices.as_slice::<I>(),
92 array_validity,
93 indices_validity,
94 ),
95 PType::I32 => take::<I, i32, i32>(
96 dtype,
97 offsets.as_slice::<i32>(),
98 data.as_slice(),
99 indices.as_slice::<I>(),
100 array_validity,
101 indices_validity,
102 ),
103 PType::I64 => take::<I, i64, i64>(
104 dtype,
105 offsets.as_slice::<i64>(),
106 data.as_slice(),
107 indices.as_slice::<I>(),
108 array_validity,
109 indices_validity,
110 ),
111 _ => unreachable!("invalid PType for offsets"),
112 }
113 });
114
115 Ok(Some(array?.into_array()))
116 }
117}
118
119fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
120 dtype: DType,
121 offsets: &[Offset],
122 data: &[u8],
123 indices: &[Index],
124 validity_mask: Mask,
125 indices_validity_mask: Mask,
126) -> VortexResult<VarBinArray> {
127 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
128 return Ok(take_nullable::<Index, Offset, NewOffset>(
129 dtype,
130 offsets,
131 data,
132 indices,
133 validity_mask,
134 indices_validity_mask,
135 ));
136 }
137
138 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
139 new_offsets.push(NewOffset::zero());
140 let mut current_offset = NewOffset::zero();
141
142 for &idx in indices {
143 let idx = idx
144 .to_usize()
145 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
146 let start = offsets[idx];
147 let stop = offsets[idx + 1];
148
149 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
150 new_offsets.push(current_offset);
151 }
152
153 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
154
155 for idx in indices {
156 let idx = idx
157 .to_usize()
158 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
159 let start = offsets[idx]
160 .to_usize()
161 .vortex_expect("Failed to cast max offset to usize");
162 let stop = offsets[idx + 1]
163 .to_usize()
164 .vortex_expect("Failed to cast max offset to usize");
165 new_data.extend_from_slice(&data[start..stop]);
166 }
167
168 let array_validity = Validity::from(dtype.nullability());
169
170 unsafe {
173 Ok(VarBinArray::new_unchecked(
174 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
175 new_data.freeze(),
176 dtype,
177 array_validity,
178 ))
179 }
180}
181
182fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
183 dtype: DType,
184 offsets: &[Offset],
185 data: &[u8],
186 indices: &[Index],
187 data_validity: Mask,
188 indices_validity: Mask,
189) -> VarBinArray {
190 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
191 new_offsets.push(NewOffset::zero());
192 let mut current_offset = NewOffset::zero();
193
194 let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
195
196 let mut valid_indices = Vec::with_capacity(indices.len());
198
199 for (idx, data_idx) in indices.iter().enumerate() {
201 if !indices_validity.value(idx) {
202 validity_buffer.append(false);
203 new_offsets.push(current_offset);
204 continue;
205 }
206 let data_idx_usize = data_idx
207 .to_usize()
208 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
209 if data_validity.value(data_idx_usize) {
210 validity_buffer.append(true);
211 let start = offsets[data_idx_usize];
212 let stop = offsets[data_idx_usize + 1];
213 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
214 new_offsets.push(current_offset);
215 valid_indices.push(data_idx_usize);
216 } else {
217 validity_buffer.append(false);
218 new_offsets.push(current_offset);
219 }
220 }
221
222 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
223
224 for data_idx in valid_indices {
226 let start = offsets[data_idx]
227 .to_usize()
228 .vortex_expect("Failed to cast max offset to usize");
229 let stop = offsets[data_idx + 1]
230 .to_usize()
231 .vortex_expect("Failed to cast max offset to usize");
232 new_data.extend_from_slice(&data[start..stop]);
233 }
234
235 let array_validity = Validity::from(validity_buffer.freeze());
236
237 unsafe {
240 VarBinArray::new_unchecked(
241 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
242 new_data.freeze(),
243 dtype,
244 array_validity,
245 )
246 }
247}
248
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Taking from a non-nullable array yields a nullable dtype only for nullable indices.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken_plain = strings.take(plain_indices.to_array()).unwrap();
        assert_eq!(
            taken_plain.dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken_nullable = strings.take(nullable_indices.to_array()).unwrap();
        assert_eq!(
            taken_nullable.dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    /// Runs the shared take conformance suite over several VarBin arrays.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(array.as_ref());
    }

    /// Taking a 128-byte value three times needs offsets up to 384, which no longer fit in the
    /// source's `u8` offsets.
    #[test]
    fn test_take_overflow() {
        let text = "a".repeat(128);
        let source = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(text.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let repeated_first = buffer![0u32; 3].into_array();
        let taken = source.take(repeated_first.to_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            [Some(text.clone()), Some(text.clone()), Some(text)],
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}