Skip to main content

vortex_array/arrays/varbin/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::VarBin;
17use crate::arrays::VarBinArray;
18use crate::arrays::dict::TakeExecute;
19use crate::arrays::primitive::PrimitiveArrayExt;
20use crate::arrays::varbin::VarBinArrayExt;
21use crate::dtype::DType;
22use crate::dtype::IntegerPType;
23use crate::dtype::PType;
24use crate::executor::ExecutionCtx;
25use crate::match_each_unsigned_integer_ptype;
26use crate::validity::Validity;
27
28/// The widened offset type used for a taken `VarBinArray`: offsets are widened to at least 32 bits
29/// (to avoid overflow) while preserving signedness, so a signed result stays Arrow-compatible.
30fn taken_offset_ptype(offsets_ptype: PType) -> PType {
31    match offsets_ptype {
32        PType::U8 | PType::U16 | PType::U32 => PType::U32,
33        PType::U64 => PType::U64,
34        PType::I8 | PType::I16 | PType::I32 => PType::I32,
35        PType::I64 => PType::I64,
36        _ => unreachable!("invalid PType for offsets"),
37    }
38}
39
40impl TakeExecute for VarBin {
41    fn take(
42        array: ArrayView<'_, VarBin>,
43        indices: &ArrayRef,
44        ctx: &mut ExecutionCtx,
45    ) -> VortexResult<Option<ArrayRef>> {
46        // TODO(joe): Be lazy with execute
47        let offsets = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
48        let data = array.bytes();
49        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
50        let dtype = array
51            .dtype()
52            .clone()
53            .union_nullability(indices.dtype().nullability());
54        let array_validity = array
55            .varbin_validity()
56            .execute_mask(array.as_ref().len(), ctx)?;
57        let indices_validity = indices
58            .as_ref()
59            .validity()?
60            .execute_mask(indices.as_ref().len(), ctx)?;
61
62        // Offsets and indices are non-negative; read them through their unsigned reinterpretations
63        // so we only monomorphize over the 4 unsigned widths each (4x4 instead of 8x8). On take,
64        // offsets get widened to either 32- or 64-bit (to avoid overflow); the built output offsets
65        // are reinterpreted back to `out_offset_ptype` to preserve the result's offset signedness.
66        let out_offset_ptype = taken_offset_ptype(offsets.ptype());
67        let offsets = offsets.reinterpret_cast(offsets.ptype().to_unsigned());
68        let indices = indices.reinterpret_cast(indices.ptype().to_unsigned());
69
70        let array = match_each_unsigned_integer_ptype!(indices.ptype(), |I| {
71            match offsets.ptype() {
72                PType::U8 => take::<I, u8, u32>(
73                    dtype,
74                    offsets.as_slice::<u8>(),
75                    data.as_slice(),
76                    indices.as_slice::<I>(),
77                    array_validity,
78                    indices_validity,
79                    out_offset_ptype,
80                ),
81                PType::U16 => take::<I, u16, u32>(
82                    dtype,
83                    offsets.as_slice::<u16>(),
84                    data.as_slice(),
85                    indices.as_slice::<I>(),
86                    array_validity,
87                    indices_validity,
88                    out_offset_ptype,
89                ),
90                PType::U32 => take::<I, u32, u32>(
91                    dtype,
92                    offsets.as_slice::<u32>(),
93                    data.as_slice(),
94                    indices.as_slice::<I>(),
95                    array_validity,
96                    indices_validity,
97                    out_offset_ptype,
98                ),
99                PType::U64 => take::<I, u64, u64>(
100                    dtype,
101                    offsets.as_slice::<u64>(),
102                    data.as_slice(),
103                    indices.as_slice::<I>(),
104                    array_validity,
105                    indices_validity,
106                    out_offset_ptype,
107                ),
108                _ => unreachable!("invalid PType for offsets"),
109            }
110        });
111
112        Ok(Some(array?.into_array()))
113    }
114}
115
116fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
117    dtype: DType,
118    offsets: &[Offset],
119    data: &[u8],
120    indices: &[Index],
121    validity_mask: Mask,
122    indices_validity_mask: Mask,
123    out_offset_ptype: PType,
124) -> VortexResult<VarBinArray> {
125    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
126        return Ok(take_nullable::<Index, Offset, NewOffset>(
127            dtype,
128            offsets,
129            data,
130            indices,
131            validity_mask,
132            indices_validity_mask,
133            out_offset_ptype,
134        ));
135    }
136
137    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
138    new_offsets.push(NewOffset::zero());
139    let mut current_offset = NewOffset::zero();
140
141    for &idx in indices {
142        let idx = idx
143            .to_usize()
144            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
145        let start = offsets[idx];
146        let stop = offsets[idx + 1];
147
148        current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
149        new_offsets.push(current_offset);
150    }
151
152    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
153
154    for idx in indices {
155        let idx = idx
156            .to_usize()
157            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
158        let start = offsets[idx]
159            .to_usize()
160            .vortex_expect("Failed to cast max offset to usize");
161        let stop = offsets[idx + 1]
162            .to_usize()
163            .vortex_expect("Failed to cast max offset to usize");
164        new_data.extend_from_slice(&data[start..stop]);
165    }
166
167    let array_validity = Validity::from(dtype.nullability());
168
169    // Built unsigned; reinterpret back to the signedness-preserving result offset type.
170    let new_offsets = PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable)
171        .reinterpret_cast(out_offset_ptype)
172        .into_array();
173
174    // Safety:
175    // All variants of VarBinArray are satisfied here.
176    unsafe {
177        Ok(VarBinArray::new_unchecked(
178            new_offsets,
179            new_data.freeze(),
180            dtype,
181            array_validity,
182        ))
183    }
184}
185
186fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
187    dtype: DType,
188    offsets: &[Offset],
189    data: &[u8],
190    indices: &[Index],
191    data_validity: Mask,
192    indices_validity: Mask,
193    out_offset_ptype: PType,
194) -> VarBinArray {
195    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
196    new_offsets.push(NewOffset::zero());
197    let mut current_offset = NewOffset::zero();
198
199    let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
200
201    // Convert indices once and store valid ones with their positions
202    let mut valid_indices = Vec::with_capacity(indices.len());
203
204    // First pass: calculate offsets and validity
205    for (idx, data_idx) in indices.iter().enumerate() {
206        if !indices_validity.value(idx) {
207            validity_buffer.append(false);
208            new_offsets.push(current_offset);
209            continue;
210        }
211        let data_idx_usize = data_idx
212            .to_usize()
213            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
214        if data_validity.value(data_idx_usize) {
215            validity_buffer.append(true);
216            let start = offsets[data_idx_usize];
217            let stop = offsets[data_idx_usize + 1];
218            current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
219            new_offsets.push(current_offset);
220            valid_indices.push(data_idx_usize);
221        } else {
222            validity_buffer.append(false);
223            new_offsets.push(current_offset);
224        }
225    }
226
227    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
228
229    // Second pass: copy data for valid indices only
230    for data_idx in valid_indices {
231        let start = offsets[data_idx]
232            .to_usize()
233            .vortex_expect("Failed to cast max offset to usize");
234        let stop = offsets[data_idx + 1]
235            .to_usize()
236            .vortex_expect("Failed to cast max offset to usize");
237        new_data.extend_from_slice(&data[start..stop]);
238    }
239
240    let array_validity = Validity::from(validity_buffer.freeze());
241
242    // Built unsigned; reinterpret back to the signedness-preserving result offset type.
243    let new_offsets = PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable)
244        .reinterpret_cast(out_offset_ptype)
245        .into_array();
246
247    // Safety:
248    // All variants of VarBinArray are satisfied here.
249    unsafe { VarBinArray::new_unchecked(new_offsets, new_data.freeze(), dtype, array_validity) }
250}
251
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::varbin::compute::take::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// The nullability of the taken array follows the nullability of the indices.
    #[test]
    fn test_null_take() {
        let source = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Non-nullable indices leave the dtype non-nullable.
        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken = source.take(plain_indices.into_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        // Nullable indices make the result nullable, even when no index is actually null.
        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = source.take(nullable_indices.into_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over a few representative `VarBinArray`s.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        let array = array.into_array();
        test_take_conformance(&array);
    }

    /// Taking the same 128-byte element three times yields 384 bytes, which no longer fits the
    /// `u8` offsets of the source array.
    #[test]
    fn test_take_overflow() {
        let scream = "a".repeat(128);

        let array = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(scream.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let taken = array.take(buffer![0u32; 3].into_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            [scream.clone(), scream.clone(), scream].map(Some),
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}
327}