Skip to main content

vortex_array/arrays/varbin/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBufferMut;
5use vortex_buffer::BufferMut;
6use vortex_buffer::ByteBufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::TakeExecute;
16use crate::arrays::VarBinVTable;
17use crate::arrays::varbin::VarBinArray;
18use crate::dtype::DType;
19use crate::dtype::IntegerPType;
20use crate::executor::ExecutionCtx;
21use crate::match_each_integer_ptype;
22use crate::validity::Validity;
23
24impl TakeExecute for VarBinVTable {
25    fn take(
26        array: &VarBinArray,
27        indices: &ArrayRef,
28        ctx: &mut ExecutionCtx,
29    ) -> VortexResult<Option<ArrayRef>> {
30        // TODO(joe): Be lazy with execute
31        let offsets = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
32        let data = array.bytes();
33        let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;
34        let dtype = array
35            .dtype()
36            .clone()
37            .union_nullability(indices.dtype().nullability());
38        let array_validity = array.validity_mask()?;
39        let indices_validity = indices.validity_mask()?;
40
41        let array = match_each_integer_ptype!(indices.ptype(), |I| {
42            // On take, offsets get widened to either 32- or 64-bit based on the original type,
43            // to avoid overflow issues.
44            match offsets.ptype() {
45                PType::U8 => take::<I, u8, u32>(
46                    dtype,
47                    offsets.as_slice::<u8>(),
48                    data.as_slice(),
49                    indices.as_slice::<I>(),
50                    array_validity,
51                    indices_validity,
52                ),
53                PType::U16 => take::<I, u16, u32>(
54                    dtype,
55                    offsets.as_slice::<u16>(),
56                    data.as_slice(),
57                    indices.as_slice::<I>(),
58                    array_validity,
59                    indices_validity,
60                ),
61                PType::U32 => take::<I, u32, u32>(
62                    dtype,
63                    offsets.as_slice::<u32>(),
64                    data.as_slice(),
65                    indices.as_slice::<I>(),
66                    array_validity,
67                    indices_validity,
68                ),
69                PType::U64 => take::<I, u64, u64>(
70                    dtype,
71                    offsets.as_slice::<u64>(),
72                    data.as_slice(),
73                    indices.as_slice::<I>(),
74                    array_validity,
75                    indices_validity,
76                ),
77                PType::I8 => take::<I, i8, i32>(
78                    dtype,
79                    offsets.as_slice::<i8>(),
80                    data.as_slice(),
81                    indices.as_slice::<I>(),
82                    array_validity,
83                    indices_validity,
84                ),
85                PType::I16 => take::<I, i16, i32>(
86                    dtype,
87                    offsets.as_slice::<i16>(),
88                    data.as_slice(),
89                    indices.as_slice::<I>(),
90                    array_validity,
91                    indices_validity,
92                ),
93                PType::I32 => take::<I, i32, i32>(
94                    dtype,
95                    offsets.as_slice::<i32>(),
96                    data.as_slice(),
97                    indices.as_slice::<I>(),
98                    array_validity,
99                    indices_validity,
100                ),
101                PType::I64 => take::<I, i64, i64>(
102                    dtype,
103                    offsets.as_slice::<i64>(),
104                    data.as_slice(),
105                    indices.as_slice::<I>(),
106                    array_validity,
107                    indices_validity,
108                ),
109                _ => unreachable!("invalid PType for offsets"),
110            }
111        });
112
113        Ok(Some(array?.into_array()))
114    }
115}
116
117fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
118    dtype: DType,
119    offsets: &[Offset],
120    data: &[u8],
121    indices: &[Index],
122    validity_mask: Mask,
123    indices_validity_mask: Mask,
124) -> VortexResult<VarBinArray> {
125    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
126        return Ok(take_nullable::<Index, Offset, NewOffset>(
127            dtype,
128            offsets,
129            data,
130            indices,
131            validity_mask,
132            indices_validity_mask,
133        ));
134    }
135
136    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
137    new_offsets.push(NewOffset::zero());
138    let mut current_offset = NewOffset::zero();
139
140    for &idx in indices {
141        let idx = idx
142            .to_usize()
143            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
144        let start = offsets[idx];
145        let stop = offsets[idx + 1];
146
147        current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
148        new_offsets.push(current_offset);
149    }
150
151    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
152
153    for idx in indices {
154        let idx = idx
155            .to_usize()
156            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
157        let start = offsets[idx]
158            .to_usize()
159            .vortex_expect("Failed to cast max offset to usize");
160        let stop = offsets[idx + 1]
161            .to_usize()
162            .vortex_expect("Failed to cast max offset to usize");
163        new_data.extend_from_slice(&data[start..stop]);
164    }
165
166    let array_validity = Validity::from(dtype.nullability());
167
168    // Safety:
169    // All variants of VarBinArray are satisfied here.
170    unsafe {
171        Ok(VarBinArray::new_unchecked(
172            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
173            new_data.freeze(),
174            dtype,
175            array_validity,
176        ))
177    }
178}
179
180fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
181    dtype: DType,
182    offsets: &[Offset],
183    data: &[u8],
184    indices: &[Index],
185    data_validity: Mask,
186    indices_validity: Mask,
187) -> VarBinArray {
188    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
189    new_offsets.push(NewOffset::zero());
190    let mut current_offset = NewOffset::zero();
191
192    let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
193
194    // Convert indices once and store valid ones with their positions
195    let mut valid_indices = Vec::with_capacity(indices.len());
196
197    // First pass: calculate offsets and validity
198    for (idx, data_idx) in indices.iter().enumerate() {
199        if !indices_validity.value(idx) {
200            validity_buffer.append(false);
201            new_offsets.push(current_offset);
202            continue;
203        }
204        let data_idx_usize = data_idx
205            .to_usize()
206            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
207        if data_validity.value(data_idx_usize) {
208            validity_buffer.append(true);
209            let start = offsets[data_idx_usize];
210            let stop = offsets[data_idx_usize + 1];
211            current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
212            new_offsets.push(current_offset);
213            valid_indices.push(data_idx_usize);
214        } else {
215            validity_buffer.append(false);
216            new_offsets.push(current_offset);
217        }
218    }
219
220    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());
221
222    // Second pass: copy data for valid indices only
223    for data_idx in valid_indices {
224        let start = offsets[data_idx]
225            .to_usize()
226            .vortex_expect("Failed to cast max offset to usize");
227        let stop = offsets[data_idx + 1]
228            .to_usize()
229            .vortex_expect("Failed to cast max offset to usize");
230        new_data.extend_from_slice(&data[start..stop]);
231    }
232
233    let array_validity = Validity::from(validity_buffer.freeze());
234
235    // Safety:
236    // All variants of VarBinArray are satisfied here.
237    unsafe {
238        VarBinArray::new_unchecked(
239            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
240            new_data.freeze(),
241            dtype,
242            array_validity,
243        )
244    }
245}
246
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Taking from a non-nullable array keeps the dtype non-nullable for non-nullable indices
    /// and makes it nullable for nullable indices, even when no index is actually null.
    #[test]
    fn test_null_take() {
        let source = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let non_nullable_indices: PrimitiveArray = (0..1).collect();
        let taken = source.take(non_nullable_indices.to_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = source.take(nullable_indices.to_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over several VarBin arrays.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(&array.to_array());
    }

    /// A single 128-byte element stored with `u8` offsets is taken three times, so the output
    /// spans 384 bytes — more than a `u8` offset can represent.
    #[test]
    fn test_take_overflow() {
        let payload = "a".repeat(128);

        let source = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(payload.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let repeated_indices = buffer![0u32; 3].into_array();
        let taken = source.take(repeated_indices.to_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            [(); 3].map(|_| Some(payload.clone())),
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}