vortex_array/arrays/varbin/compute/
take.rs1use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_dtype::DType;
8use vortex_dtype::IntegerPType;
9use vortex_dtype::match_each_integer_ptype;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_panic;
13use vortex_mask::Mask;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::ToCanonical;
19use crate::arrays::PrimitiveArray;
20use crate::arrays::VarBinVTable;
21use crate::arrays::varbin::VarBinArray;
22use crate::compute::TakeKernel;
23use crate::compute::TakeKernelAdapter;
24use crate::register_kernel;
25use crate::validity::Validity;
26
27impl TakeKernel for VarBinVTable {
28 fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
29 let offsets = array.offsets().to_primitive();
30 let data = array.bytes();
31 let indices = indices.to_primitive();
32 let dtype = array
33 .dtype()
34 .clone()
35 .union_nullability(indices.dtype().nullability());
36 let array = match_each_integer_ptype!(indices.ptype(), |I| {
37 match offsets.ptype() {
40 PType::U8 => take::<I, u8, u32>(
41 dtype,
42 offsets.as_slice::<u8>(),
43 data.as_slice(),
44 indices.as_slice::<I>(),
45 array.validity_mask(),
46 indices.validity_mask(),
47 ),
48 PType::U16 => take::<I, u16, u32>(
49 dtype,
50 offsets.as_slice::<u16>(),
51 data.as_slice(),
52 indices.as_slice::<I>(),
53 array.validity_mask(),
54 indices.validity_mask(),
55 ),
56 PType::U32 => take::<I, u32, u32>(
57 dtype,
58 offsets.as_slice::<u32>(),
59 data.as_slice(),
60 indices.as_slice::<I>(),
61 array.validity_mask(),
62 indices.validity_mask(),
63 ),
64 PType::U64 => take::<I, u64, u64>(
65 dtype,
66 offsets.as_slice::<u64>(),
67 data.as_slice(),
68 indices.as_slice::<I>(),
69 array.validity_mask(),
70 indices.validity_mask(),
71 ),
72 PType::I8 => take::<I, i8, i32>(
73 dtype,
74 offsets.as_slice::<i8>(),
75 data.as_slice(),
76 indices.as_slice::<I>(),
77 array.validity_mask(),
78 indices.validity_mask(),
79 ),
80 PType::I16 => take::<I, i16, i32>(
81 dtype,
82 offsets.as_slice::<i16>(),
83 data.as_slice(),
84 indices.as_slice::<I>(),
85 array.validity_mask(),
86 indices.validity_mask(),
87 ),
88 PType::I32 => take::<I, i32, i32>(
89 dtype,
90 offsets.as_slice::<i32>(),
91 data.as_slice(),
92 indices.as_slice::<I>(),
93 array.validity_mask(),
94 indices.validity_mask(),
95 ),
96 PType::I64 => take::<I, i64, i64>(
97 dtype,
98 offsets.as_slice::<i64>(),
99 data.as_slice(),
100 indices.as_slice::<I>(),
101 array.validity_mask(),
102 indices.validity_mask(),
103 ),
104 _ => unreachable!("invalid PType for offsets"),
105 }
106 });
107
108 Ok(array?.into_array())
109 }
110}
111
112register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
113
114fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
115 dtype: DType,
116 offsets: &[Offset],
117 data: &[u8],
118 indices: &[Index],
119 validity_mask: Mask,
120 indices_validity_mask: Mask,
121) -> VortexResult<VarBinArray> {
122 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
123 return Ok(take_nullable::<Index, Offset, NewOffset>(
124 dtype,
125 offsets,
126 data,
127 indices,
128 validity_mask,
129 indices_validity_mask,
130 ));
131 }
132
133 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
134 new_offsets.push(NewOffset::zero());
135 let mut current_offset = NewOffset::zero();
136
137 for &idx in indices {
138 let idx = idx
139 .to_usize()
140 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
141 let start = offsets[idx];
142 let stop = offsets[idx + 1];
143
144 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
145 new_offsets.push(current_offset);
146 }
147
148 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
149
150 for idx in indices {
151 let idx = idx
152 .to_usize()
153 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
154 let start = offsets[idx]
155 .to_usize()
156 .vortex_expect("Failed to cast max offset to usize");
157 let stop = offsets[idx + 1]
158 .to_usize()
159 .vortex_expect("Failed to cast max offset to usize");
160 new_data.extend_from_slice(&data[start..stop]);
161 }
162
163 let array_validity = Validity::from(dtype.nullability());
164
165 unsafe {
168 Ok(VarBinArray::new_unchecked(
169 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
170 new_data.freeze(),
171 dtype,
172 array_validity,
173 ))
174 }
175}
176
177fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
178 dtype: DType,
179 offsets: &[Offset],
180 data: &[u8],
181 indices: &[Index],
182 data_validity: Mask,
183 indices_validity: Mask,
184) -> VarBinArray {
185 let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
186 new_offsets.push(NewOffset::zero());
187 let mut current_offset = NewOffset::zero();
188
189 let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
190
191 let mut valid_indices = Vec::with_capacity(indices.len());
193
194 for (idx, data_idx) in indices.iter().enumerate() {
196 if !indices_validity.value(idx) {
197 validity_buffer.append(false);
198 new_offsets.push(current_offset);
199 continue;
200 }
201 let data_idx_usize = data_idx
202 .to_usize()
203 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
204 if data_validity.value(data_idx_usize) {
205 validity_buffer.append(true);
206 let start = offsets[data_idx_usize];
207 let stop = offsets[data_idx_usize + 1];
208 current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
209 new_offsets.push(current_offset);
210 valid_indices.push(data_idx_usize);
211 } else {
212 validity_buffer.append(false);
213 new_offsets.push(current_offset);
214 }
215 }
216
217 let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
218
219 for data_idx in valid_indices {
221 let start = offsets[data_idx]
222 .to_usize()
223 .vortex_expect("Failed to cast max offset to usize");
224 let stop = offsets[data_idx + 1]
225 .to_usize()
226 .vortex_expect("Failed to cast max offset to usize");
227 new_data.extend_from_slice(&data[start..stop]);
228 }
229
230 let array_validity = Validity::from(validity_buffer.freeze());
231
232 unsafe {
235 VarBinArray::new_unchecked(
236 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
237 new_data.freeze(),
238 dtype,
239 array_validity,
240 )
241 }
242}
243
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinVTable;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Non-nullable indices keep the non-nullable dtype; nullable indices make the result
    /// nullable.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken = take(strings.as_ref(), plain_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = take(strings.as_ref(), nullable_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over several VarBin layouts.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(array.as_ref());
    }

    /// Taking a 128-byte value three times produces 384 bytes. That exceeds `u8::MAX`, so the
    /// output offsets must be wider than the source `u8` offsets.
    #[test]
    fn test_take_overflow() {
        let scream = "a".repeat(128);
        let array = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(scream.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let indices = buffer![0u32, 0u32, 0u32].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();

        let taken_str = taken.as_::<VarBinVTable>();
        assert_eq!(taken_str.len(), 3);
        for position in 0..3 {
            assert_eq!(taken_str.bytes_at(position).as_bytes(), scream.as_bytes());
        }
    }
}