Skip to main content

vortex_array/arrays/varbin/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_dtype::DType;
8use vortex_dtype::IntegerPType;
9use vortex_dtype::match_each_integer_ptype;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_panic;
13use vortex_mask::Mask;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::ToCanonical;
19use crate::arrays::PrimitiveArray;
20use crate::arrays::TakeExecute;
21use crate::arrays::VarBinVTable;
22use crate::arrays::varbin::VarBinArray;
23use crate::executor::ExecutionCtx;
24use crate::validity::Validity;
25
26impl TakeExecute for VarBinVTable {
27    fn take(
28        array: &VarBinArray,
29        indices: &dyn Array,
30        _ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        // TODO(joe): Be lazy with execute
33        let offsets = array.offsets().to_primitive();
34        let data = array.bytes();
35        let indices = indices.to_primitive();
36        let dtype = array
37            .dtype()
38            .clone()
39            .union_nullability(indices.dtype().nullability());
40        let array_validity = array.validity_mask()?;
41        let indices_validity = indices.validity_mask()?;
42
43        let array = match_each_integer_ptype!(indices.ptype(), |I| {
44            // On take, offsets get widened to either 32- or 64-bit based on the original type,
45            // to avoid overflow issues.
46            match offsets.ptype() {
47                PType::U8 => take::<I, u8, u32>(
48                    dtype,
49                    offsets.as_slice::<u8>(),
50                    data.as_slice(),
51                    indices.as_slice::<I>(),
52                    array_validity,
53                    indices_validity,
54                ),
55                PType::U16 => take::<I, u16, u32>(
56                    dtype,
57                    offsets.as_slice::<u16>(),
58                    data.as_slice(),
59                    indices.as_slice::<I>(),
60                    array_validity,
61                    indices_validity,
62                ),
63                PType::U32 => take::<I, u32, u32>(
64                    dtype,
65                    offsets.as_slice::<u32>(),
66                    data.as_slice(),
67                    indices.as_slice::<I>(),
68                    array_validity,
69                    indices_validity,
70                ),
71                PType::U64 => take::<I, u64, u64>(
72                    dtype,
73                    offsets.as_slice::<u64>(),
74                    data.as_slice(),
75                    indices.as_slice::<I>(),
76                    array_validity,
77                    indices_validity,
78                ),
79                PType::I8 => take::<I, i8, i32>(
80                    dtype,
81                    offsets.as_slice::<i8>(),
82                    data.as_slice(),
83                    indices.as_slice::<I>(),
84                    array_validity,
85                    indices_validity,
86                ),
87                PType::I16 => take::<I, i16, i32>(
88                    dtype,
89                    offsets.as_slice::<i16>(),
90                    data.as_slice(),
91                    indices.as_slice::<I>(),
92                    array_validity,
93                    indices_validity,
94                ),
95                PType::I32 => take::<I, i32, i32>(
96                    dtype,
97                    offsets.as_slice::<i32>(),
98                    data.as_slice(),
99                    indices.as_slice::<I>(),
100                    array_validity,
101                    indices_validity,
102                ),
103                PType::I64 => take::<I, i64, i64>(
104                    dtype,
105                    offsets.as_slice::<i64>(),
106                    data.as_slice(),
107                    indices.as_slice::<I>(),
108                    array_validity,
109                    indices_validity,
110                ),
111                _ => unreachable!("invalid PType for offsets"),
112            }
113        });
114
115        Ok(Some(array?.into_array()))
116    }
117}
118
119fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
120    dtype: DType,
121    offsets: &[Offset],
122    data: &[u8],
123    indices: &[Index],
124    validity_mask: Mask,
125    indices_validity_mask: Mask,
126) -> VortexResult<VarBinArray> {
127    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
128        return Ok(take_nullable::<Index, Offset, NewOffset>(
129            dtype,
130            offsets,
131            data,
132            indices,
133            validity_mask,
134            indices_validity_mask,
135        ));
136    }
137
138    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
139    new_offsets.push(NewOffset::zero());
140    let mut current_offset = NewOffset::zero();
141
142    for &idx in indices {
143        let idx = idx
144            .to_usize()
145            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
146        let start = offsets[idx];
147        let stop = offsets[idx + 1];
148
149        current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
150        new_offsets.push(current_offset);
151    }
152
153    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
154
155    for idx in indices {
156        let idx = idx
157            .to_usize()
158            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
159        let start = offsets[idx]
160            .to_usize()
161            .vortex_expect("Failed to cast max offset to usize");
162        let stop = offsets[idx + 1]
163            .to_usize()
164            .vortex_expect("Failed to cast max offset to usize");
165        new_data.extend_from_slice(&data[start..stop]);
166    }
167
168    let array_validity = Validity::from(dtype.nullability());
169
170    // Safety:
171    // All variants of VarBinArray are satisfied here.
172    unsafe {
173        Ok(VarBinArray::new_unchecked(
174            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
175            new_data.freeze(),
176            dtype,
177            array_validity,
178        ))
179    }
180}
181
182fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
183    dtype: DType,
184    offsets: &[Offset],
185    data: &[u8],
186    indices: &[Index],
187    data_validity: Mask,
188    indices_validity: Mask,
189) -> VarBinArray {
190    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
191    new_offsets.push(NewOffset::zero());
192    let mut current_offset = NewOffset::zero();
193
194    let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
195
196    // Convert indices once and store valid ones with their positions
197    let mut valid_indices = Vec::with_capacity(indices.len());
198
199    // First pass: calculate offsets and validity
200    for (idx, data_idx) in indices.iter().enumerate() {
201        if !indices_validity.value(idx) {
202            validity_buffer.append(false);
203            new_offsets.push(current_offset);
204            continue;
205        }
206        let data_idx_usize = data_idx
207            .to_usize()
208            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
209        if data_validity.value(data_idx_usize) {
210            validity_buffer.append(true);
211            let start = offsets[data_idx_usize];
212            let stop = offsets[data_idx_usize + 1];
213            current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
214            new_offsets.push(current_offset);
215            valid_indices.push(data_idx_usize);
216        } else {
217            validity_buffer.append(false);
218            new_offsets.push(current_offset);
219        }
220    }
221
222    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
223
224    // Second pass: copy data for valid indices only
225    for data_idx in valid_indices {
226        let start = offsets[data_idx]
227            .to_usize()
228            .vortex_expect("Failed to cast max offset to usize");
229        let stop = offsets[data_idx + 1]
230            .to_usize()
231            .vortex_expect("Failed to cast max offset to usize");
232        new_data.extend_from_slice(&data[start..stop]);
233    }
234
235    let array_validity = Validity::from(validity_buffer.freeze());
236
237    // Safety:
238    // All variants of VarBinArray are satisfied here.
239    unsafe {
240        VarBinArray::new_unchecked(
241            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
242            new_data.freeze(),
243            dtype,
244            array_validity,
245        )
246    }
247}
248
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Taking from a non-nullable array yields a non-nullable result for non-nullable indices
    /// and a nullable result for nullable indices, even when no index is actually null.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken_plain = strings.take(plain_indices.to_array()).unwrap();
        assert_eq!(
            taken_plain.dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        let optional_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken_optional = strings.take(optional_indices.to_array()).unwrap();
        assert_eq!(
            taken_optional.dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    /// Runs the shared take conformance suite over string and binary arrays, with and without
    /// nulls, plus a single-element array.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(array.as_ref());
    }

    /// A single 128-byte value behind `u8` offsets, taken three times, produces 384 bytes of
    /// output, which no longer fits in `u8` offsets.
    #[test]
    fn test_take_overflow() {
        let payload = "a".repeat(128);

        let source = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(payload.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let positions = buffer![0u32; 3].into_array();
        let taken = source.take(positions.to_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            [(); 3].map(|()| Some(payload.clone())),
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}