Skip to main content

vortex_array/arrays/varbin/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::VarBin;
17use crate::arrays::VarBinArray;
18use crate::arrays::dict::TakeExecute;
19use crate::arrays::varbin::VarBinArrayExt;
20use crate::dtype::DType;
21use crate::dtype::IntegerPType;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24use crate::validity::Validity;
25
26impl TakeExecute for VarBin {
27    fn take(
28        array: ArrayView<'_, VarBin>,
29        indices: &ArrayRef,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        // TODO(joe): Be lazy with execute
33        let offsets = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
34        let data = array.bytes();
35        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
36        let dtype = array
37            .dtype()
38            .clone()
39            .union_nullability(indices.dtype().nullability());
40        let array_validity = array.varbin_validity().to_mask(array.as_ref().len(), ctx)?;
41        let indices_validity = indices
42            .as_ref()
43            .validity()?
44            .to_mask(indices.as_ref().len(), ctx)?;
45
46        let array = match_each_integer_ptype!(indices.ptype(), |I| {
47            // On take, offsets get widened to either 32- or 64-bit based on the original type,
48            // to avoid overflow issues.
49            match offsets.ptype() {
50                PType::U8 => take::<I, u8, u32>(
51                    dtype,
52                    offsets.as_slice::<u8>(),
53                    data.as_slice(),
54                    indices.as_slice::<I>(),
55                    array_validity,
56                    indices_validity,
57                ),
58                PType::U16 => take::<I, u16, u32>(
59                    dtype,
60                    offsets.as_slice::<u16>(),
61                    data.as_slice(),
62                    indices.as_slice::<I>(),
63                    array_validity,
64                    indices_validity,
65                ),
66                PType::U32 => take::<I, u32, u32>(
67                    dtype,
68                    offsets.as_slice::<u32>(),
69                    data.as_slice(),
70                    indices.as_slice::<I>(),
71                    array_validity,
72                    indices_validity,
73                ),
74                PType::U64 => take::<I, u64, u64>(
75                    dtype,
76                    offsets.as_slice::<u64>(),
77                    data.as_slice(),
78                    indices.as_slice::<I>(),
79                    array_validity,
80                    indices_validity,
81                ),
82                PType::I8 => take::<I, i8, i32>(
83                    dtype,
84                    offsets.as_slice::<i8>(),
85                    data.as_slice(),
86                    indices.as_slice::<I>(),
87                    array_validity,
88                    indices_validity,
89                ),
90                PType::I16 => take::<I, i16, i32>(
91                    dtype,
92                    offsets.as_slice::<i16>(),
93                    data.as_slice(),
94                    indices.as_slice::<I>(),
95                    array_validity,
96                    indices_validity,
97                ),
98                PType::I32 => take::<I, i32, i32>(
99                    dtype,
100                    offsets.as_slice::<i32>(),
101                    data.as_slice(),
102                    indices.as_slice::<I>(),
103                    array_validity,
104                    indices_validity,
105                ),
106                PType::I64 => take::<I, i64, i64>(
107                    dtype,
108                    offsets.as_slice::<i64>(),
109                    data.as_slice(),
110                    indices.as_slice::<I>(),
111                    array_validity,
112                    indices_validity,
113                ),
114                _ => unreachable!("invalid PType for offsets"),
115            }
116        });
117
118        Ok(Some(array?.into_array()))
119    }
120}
121
122fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
123    dtype: DType,
124    offsets: &[Offset],
125    data: &[u8],
126    indices: &[Index],
127    validity_mask: Mask,
128    indices_validity_mask: Mask,
129) -> VortexResult<VarBinArray> {
130    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
131        return Ok(take_nullable::<Index, Offset, NewOffset>(
132            dtype,
133            offsets,
134            data,
135            indices,
136            validity_mask,
137            indices_validity_mask,
138        ));
139    }
140
141    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
142    new_offsets.push(NewOffset::zero());
143    let mut current_offset = NewOffset::zero();
144
145    for &idx in indices {
146        let idx = idx
147            .to_usize()
148            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
149        let start = offsets[idx];
150        let stop = offsets[idx + 1];
151
152        current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
153        new_offsets.push(current_offset);
154    }
155
156    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
157
158    for idx in indices {
159        let idx = idx
160            .to_usize()
161            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
162        let start = offsets[idx]
163            .to_usize()
164            .vortex_expect("Failed to cast max offset to usize");
165        let stop = offsets[idx + 1]
166            .to_usize()
167            .vortex_expect("Failed to cast max offset to usize");
168        new_data.extend_from_slice(&data[start..stop]);
169    }
170
171    let array_validity = Validity::from(dtype.nullability());
172
173    // Safety:
174    // All variants of VarBinArray are satisfied here.
175    unsafe {
176        Ok(VarBinArray::new_unchecked(
177            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
178            new_data.freeze(),
179            dtype,
180            array_validity,
181        ))
182    }
183}
184
185fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
186    dtype: DType,
187    offsets: &[Offset],
188    data: &[u8],
189    indices: &[Index],
190    data_validity: Mask,
191    indices_validity: Mask,
192) -> VarBinArray {
193    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
194    new_offsets.push(NewOffset::zero());
195    let mut current_offset = NewOffset::zero();
196
197    let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
198
199    // Convert indices once and store valid ones with their positions
200    let mut valid_indices = Vec::with_capacity(indices.len());
201
202    // First pass: calculate offsets and validity
203    for (idx, data_idx) in indices.iter().enumerate() {
204        if !indices_validity.value(idx) {
205            validity_buffer.append(false);
206            new_offsets.push(current_offset);
207            continue;
208        }
209        let data_idx_usize = data_idx
210            .to_usize()
211            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
212        if data_validity.value(data_idx_usize) {
213            validity_buffer.append(true);
214            let start = offsets[data_idx_usize];
215            let stop = offsets[data_idx_usize + 1];
216            current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
217            new_offsets.push(current_offset);
218            valid_indices.push(data_idx_usize);
219        } else {
220            validity_buffer.append(false);
221            new_offsets.push(current_offset);
222        }
223    }
224
225    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
226
227    // Second pass: copy data for valid indices only
228    for data_idx in valid_indices {
229        let start = offsets[data_idx]
230            .to_usize()
231            .vortex_expect("Failed to cast max offset to usize");
232        let stop = offsets[data_idx + 1]
233            .to_usize()
234            .vortex_expect("Failed to cast max offset to usize");
235        new_data.extend_from_slice(&data[start..stop]);
236    }
237
238    let array_validity = Validity::from(validity_buffer.freeze());
239
240    // Safety:
241    // All variants of VarBinArray are satisfied here.
242    unsafe {
243        VarBinArray::new_unchecked(
244            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
245            new_data.freeze(),
246            dtype,
247            array_validity,
248        )
249    }
250}
251
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::varbin::compute::take::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// The output dtype is nullable exactly when the indices (or the source) are nullable.
    #[test]
    fn test_null_take() {
        let arr = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let idx1: PrimitiveArray = (0..1).collect();

        assert_eq!(
            arr.take(idx1.into_array()).unwrap().dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        let idx2: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);

        assert_eq!(
            arr.take(idx2.into_array()).unwrap().dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(&array.into_array());
    }

    /// u8 offsets are widened on take, so repeating a 128-byte value three times
    /// (384 bytes total) must not overflow.
    #[test]
    fn test_take_overflow() {
        let scream = std::iter::once("a").cycle().take(128).collect::<String>();
        let bytes = ByteBuffer::copy_from(scream.as_bytes());
        let offsets = buffer![0u8, 128u8].into_array();

        let array = VarBinArray::new(
            offsets,
            bytes,
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let indices = buffer![0u32; 3].into_array();
        let taken = array.take(indices).unwrap();

        let expected = VarBinViewArray::from_iter(
            [Some(scream.clone()), Some(scream.clone()), Some(scream)],
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }

    /// Null-aware path: both a null index and an index pointing at a null value must yield
    /// a null row, while valid rows keep their bytes in output order.
    #[test]
    fn test_take_nullable_indices_and_values() {
        let arr = VarBinArray::from_iter(
            [Some("a"), None, Some("ccc")],
            DType::Utf8(Nullability::Nullable),
        );
        let indices = PrimitiveArray::from_option_iter(vec![Some(2u32), None, Some(1), Some(0)]);

        let taken = arr.take(indices.into_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            [Some("ccc"), None, None, Some("a")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}