Skip to main content

vortex_array/arrays/varbin/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::VarBin;
17use crate::arrays::VarBinArray;
18use crate::arrays::dict::TakeExecute;
19use crate::arrays::varbin::VarBinArrayExt;
20use crate::dtype::DType;
21use crate::dtype::IntegerPType;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24use crate::validity::Validity;
25
26impl TakeExecute for VarBin {
27    fn take(
28        array: ArrayView<'_, VarBin>,
29        indices: &ArrayRef,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        // TODO(joe): Be lazy with execute
33        let offsets = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
34        let data = array.bytes();
35        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
36        let dtype = array
37            .dtype()
38            .clone()
39            .union_nullability(indices.dtype().nullability());
40        let array_validity = array
41            .varbin_validity()
42            .execute_mask(array.as_ref().len(), ctx)?;
43        let indices_validity = indices
44            .as_ref()
45            .validity()?
46            .execute_mask(indices.as_ref().len(), ctx)?;
47
48        let array = match_each_integer_ptype!(indices.ptype(), |I| {
49            // On take, offsets get widened to either 32- or 64-bit based on the original type,
50            // to avoid overflow issues.
51            match offsets.ptype() {
52                PType::U8 => take::<I, u8, u32>(
53                    dtype,
54                    offsets.as_slice::<u8>(),
55                    data.as_slice(),
56                    indices.as_slice::<I>(),
57                    array_validity,
58                    indices_validity,
59                ),
60                PType::U16 => take::<I, u16, u32>(
61                    dtype,
62                    offsets.as_slice::<u16>(),
63                    data.as_slice(),
64                    indices.as_slice::<I>(),
65                    array_validity,
66                    indices_validity,
67                ),
68                PType::U32 => take::<I, u32, u32>(
69                    dtype,
70                    offsets.as_slice::<u32>(),
71                    data.as_slice(),
72                    indices.as_slice::<I>(),
73                    array_validity,
74                    indices_validity,
75                ),
76                PType::U64 => take::<I, u64, u64>(
77                    dtype,
78                    offsets.as_slice::<u64>(),
79                    data.as_slice(),
80                    indices.as_slice::<I>(),
81                    array_validity,
82                    indices_validity,
83                ),
84                PType::I8 => take::<I, i8, i32>(
85                    dtype,
86                    offsets.as_slice::<i8>(),
87                    data.as_slice(),
88                    indices.as_slice::<I>(),
89                    array_validity,
90                    indices_validity,
91                ),
92                PType::I16 => take::<I, i16, i32>(
93                    dtype,
94                    offsets.as_slice::<i16>(),
95                    data.as_slice(),
96                    indices.as_slice::<I>(),
97                    array_validity,
98                    indices_validity,
99                ),
100                PType::I32 => take::<I, i32, i32>(
101                    dtype,
102                    offsets.as_slice::<i32>(),
103                    data.as_slice(),
104                    indices.as_slice::<I>(),
105                    array_validity,
106                    indices_validity,
107                ),
108                PType::I64 => take::<I, i64, i64>(
109                    dtype,
110                    offsets.as_slice::<i64>(),
111                    data.as_slice(),
112                    indices.as_slice::<I>(),
113                    array_validity,
114                    indices_validity,
115                ),
116                _ => unreachable!("invalid PType for offsets"),
117            }
118        });
119
120        Ok(Some(array?.into_array()))
121    }
122}
123
124fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
125    dtype: DType,
126    offsets: &[Offset],
127    data: &[u8],
128    indices: &[Index],
129    validity_mask: Mask,
130    indices_validity_mask: Mask,
131) -> VortexResult<VarBinArray> {
132    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
133        return Ok(take_nullable::<Index, Offset, NewOffset>(
134            dtype,
135            offsets,
136            data,
137            indices,
138            validity_mask,
139            indices_validity_mask,
140        ));
141    }
142
143    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
144    new_offsets.push(NewOffset::zero());
145    let mut current_offset = NewOffset::zero();
146
147    for &idx in indices {
148        let idx = idx
149            .to_usize()
150            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
151        let start = offsets[idx];
152        let stop = offsets[idx + 1];
153
154        current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
155        new_offsets.push(current_offset);
156    }
157
158    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
159
160    for idx in indices {
161        let idx = idx
162            .to_usize()
163            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
164        let start = offsets[idx]
165            .to_usize()
166            .vortex_expect("Failed to cast max offset to usize");
167        let stop = offsets[idx + 1]
168            .to_usize()
169            .vortex_expect("Failed to cast max offset to usize");
170        new_data.extend_from_slice(&data[start..stop]);
171    }
172
173    let array_validity = Validity::from(dtype.nullability());
174
175    // Safety:
176    // All variants of VarBinArray are satisfied here.
177    unsafe {
178        Ok(VarBinArray::new_unchecked(
179            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
180            new_data.freeze(),
181            dtype,
182            array_validity,
183        ))
184    }
185}
186
187fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
188    dtype: DType,
189    offsets: &[Offset],
190    data: &[u8],
191    indices: &[Index],
192    data_validity: Mask,
193    indices_validity: Mask,
194) -> VarBinArray {
195    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
196    new_offsets.push(NewOffset::zero());
197    let mut current_offset = NewOffset::zero();
198
199    let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
200
201    // Convert indices once and store valid ones with their positions
202    let mut valid_indices = Vec::with_capacity(indices.len());
203
204    // First pass: calculate offsets and validity
205    for (idx, data_idx) in indices.iter().enumerate() {
206        if !indices_validity.value(idx) {
207            validity_buffer.append(false);
208            new_offsets.push(current_offset);
209            continue;
210        }
211        let data_idx_usize = data_idx
212            .to_usize()
213            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
214        if data_validity.value(data_idx_usize) {
215            validity_buffer.append(true);
216            let start = offsets[data_idx_usize];
217            let stop = offsets[data_idx_usize + 1];
218            current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
219            new_offsets.push(current_offset);
220            valid_indices.push(data_idx_usize);
221        } else {
222            validity_buffer.append(false);
223            new_offsets.push(current_offset);
224        }
225    }
226
227    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
228
229    // Second pass: copy data for valid indices only
230    for data_idx in valid_indices {
231        let start = offsets[data_idx]
232            .to_usize()
233            .vortex_expect("Failed to cast max offset to usize");
234        let stop = offsets[data_idx + 1]
235            .to_usize()
236            .vortex_expect("Failed to cast max offset to usize");
237        new_data.extend_from_slice(&data[start..stop]);
238    }
239
240    let array_validity = Validity::from(validity_buffer.freeze());
241
242    // Safety:
243    // All variants of VarBinArray are satisfied here.
244    unsafe {
245        VarBinArray::new_unchecked(
246            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
247            new_data.freeze(),
248            dtype,
249            array_validity,
250        )
251    }
252}
253
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::varbin::compute::take::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// The nullability of the indices must be reflected in the dtype of the taken array.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Non-nullable indices keep the source dtype.
        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken = strings.take(plain_indices.into_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        // Nullable indices make the result nullable, even when no index is actually null.
        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = strings.take(nullable_indices.into_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over a handful of `VarBinArray` shapes.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        let array = array.into_array();
        test_take_conformance(&array);
    }

    /// Taking a 128-byte value three times yields 384 bytes, which no longer fits in the `u8`
    /// offsets of the source array.
    #[test]
    fn test_take_overflow() {
        let scream = "a".repeat(128);

        let array = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(scream.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let taken = array.take(buffer![0u32; 3].into_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            [scream.clone(), scream.clone(), scream].map(Some),
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}