vortex_array/arrays/varbin/compute/take.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_buffer::BitBufferMut;
use vortex_buffer::BufferMut;
use vortex_buffer::ByteBufferMut;
use vortex_dtype::DType;
use vortex_dtype::IntegerPType;
use vortex_dtype::PType;
use vortex_dtype::match_each_integer_ptype;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_panic;
use vortex_mask::Mask;

use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::PrimitiveArray;
use crate::arrays::VarBinVTable;
use crate::arrays::varbin::VarBinArray;
use crate::compute::TakeKernel;
use crate::compute::TakeKernelAdapter;
use crate::register_kernel;
use crate::validity::Validity;

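// Take for VarBin dispatches on the physical type of both the offsets and the indices, widening
// the output offsets to 32 or 64 bits so that gathering repeated values cannot overflow a
// narrower offset type.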
impl TakeKernel for VarBinVTable {
    fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
        let offsets = array.offsets().to_primitive();
        let data = array.bytes();
        let indices = indices.to_primitive();
        let dtype = array
            .dtype()
            .clone()
            .union_nullability(indices.dtype().nullability());
        let array = match_each_integer_ptype!(indices.ptype(), |I| {
            // On take, offsets are widened to 32 or 64 bits (preserving signedness) so that the
            // gathered offsets cannot overflow the original, narrower offset type.
            match offsets.ptype() {
                PType::U8 => take::<I, u8, u32>(
                    dtype,
                    offsets.as_slice::<u8>(),
                    data.as_slice(),
                    indices.as_slice::<I>(),
                    array.validity_mask(),
                    indices.validity_mask(),
                ),
                PType::U16 => take::<I, u16, u32>(
                    dtype,
                    offsets.as_slice::<u16>(),
                    data.as_slice(),
                    indices.as_slice::<I>(),
                    array.validity_mask(),
                    indices.validity_mask(),
                ),
                PType::U32 => take::<I, u32, u32>(
                    dtype,
                    offsets.as_slice::<u32>(),
                    data.as_slice(),
                    indices.as_slice::<I>(),
                    array.validity_mask(),
                    indices.validity_mask(),
                ),
                PType::U64 => take::<I, u64, u64>(
                    dtype,
                    offsets.as_slice::<u64>(),
                    data.as_slice(),
                    indices.as_slice::<I>(),
                    array.validity_mask(),
                    indices.validity_mask(),
                ),
                PType::I8 => take::<I, i8, i32>(
                    dtype,
                    offsets.as_slice::<i8>(),
                    data.as_slice(),
                    indices.as_slice::<I>(),
                    array.validity_mask(),
                    indices.validity_mask(),
                ),
                PType::I16 => take::<I, i16, i32>(
                    dtype,
                    offsets.as_slice::<i16>(),
                    data.as_slice(),
                    indices.as_slice::<I>(),
                    array.validity_mask(),
                    indices.validity_mask(),
                ),
                PType::I32 => take::<I, i32, i32>(
                    dtype,
                    offsets.as_slice::<i32>(),
                    data.as_slice(),
                    indices.as_slice::<I>(),
                    array.validity_mask(),
                    indices.validity_mask(),
                ),
                PType::I64 => take::<I, i64, i64>(
                    dtype,
                    offsets.as_slice::<i64>(),
                    data.as_slice(),
                    indices.as_slice::<I>(),
                    array.validity_mask(),
                    indices.validity_mask(),
                ),
                _ => unreachable!("invalid PType for offsets"),
            }
        });

        Ok(array?.into_array())
    }
}

register_kernel!(TakeKernelAdapter(VarBinVTable).lift());

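/// Gathers `indices` from `offsets`/`data`, rebuilding the offsets in the (possibly wider)
/// `NewOffset` type. Delegates to [`take_nullable`] when either the array or the indices contain
/// nulls.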
fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
    dtype: DType,
    offsets: &[Offset],
    data: &[u8],
    indices: &[Index],
    validity_mask: Mask,
    indices_validity_mask: Mask,
) -> VortexResult<VarBinArray> {
    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
        return Ok(take_nullable::<Index, Offset, NewOffset>(
            dtype,
            offsets,
            data,
            indices,
            validity_mask,
            indices_validity_mask,
        ));
    }

    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
    new_offsets.push(NewOffset::zero());
    let mut current_offset = NewOffset::zero();

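    // First pass: compute the widened output offsets.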
    for &idx in indices {
        let idx = idx
            .to_usize()
            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
        let start = offsets[idx];
        let stop = offsets[idx + 1];

        current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
        new_offsets.push(current_offset);
    }

    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());

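    // Second pass: copy the referenced bytes into the new data buffer.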
    for idx in indices {
        let idx = idx
            .to_usize()
            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
        let start = offsets[idx]
            .to_usize()
            .vortex_expect("Failed to cast offset to usize");
        let stop = offsets[idx + 1]
            .to_usize()
            .vortex_expect("Failed to cast offset to usize");
        new_data.extend_from_slice(&data[start..stop]);
    }

    let array_validity = Validity::from(dtype.nullability());

    // Safety:
    // All invariants of VarBinArray are upheld here.
    unsafe {
        Ok(VarBinArray::new_unchecked(
            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
            new_data.freeze(),
            dtype,
            array_validity,
        ))
    }
}

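/// Take when the array or the indices contain nulls: a null position contributes an empty entry
/// (its offset repeats the running offset) and a `false` validity bit, while valid positions are
/// gathered as in [`take`].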
fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
    dtype: DType,
    offsets: &[Offset],
    data: &[u8],
    indices: &[Index],
    data_validity: Mask,
    indices_validity: Mask,
) -> VarBinArray {
    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
    new_offsets.push(NewOffset::zero());
    let mut current_offset = NewOffset::zero();

    let mut validity_buffer = BitBufferMut::with_capacity(indices.len());

    // Convert indices once and remember the valid ones so the second pass can copy their bytes
    // without re-checking validity.
    let mut valid_indices = Vec::with_capacity(indices.len());

    // First pass: calculate offsets and validity
    for (idx, data_idx) in indices.iter().enumerate() {
        if !indices_validity.value(idx) {
            validity_buffer.append(false);
            new_offsets.push(current_offset);
            continue;
        }
        let data_idx_usize = data_idx
            .to_usize()
            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
        if data_validity.value(data_idx_usize) {
            validity_buffer.append(true);
            let start = offsets[data_idx_usize];
            let stop = offsets[data_idx_usize + 1];
            current_offset += NewOffset::from(stop - start).vortex_expect("offset type overflow");
            new_offsets.push(current_offset);
            valid_indices.push(data_idx_usize);
        } else {
            validity_buffer.append(false);
            new_offsets.push(current_offset);
        }
    }

    let mut new_data = ByteBufferMut::with_capacity(current_offset.as_());

    // Second pass: copy data for valid indices only
    for data_idx in valid_indices {
        let start = offsets[data_idx]
            .to_usize()
            .vortex_expect("Failed to cast offset to usize");
        let stop = offsets[data_idx + 1]
            .to_usize()
            .vortex_expect("Failed to cast offset to usize");
        new_data.extend_from_slice(&data[start..stop]);
    }

    let array_validity = Validity::from(validity_buffer.freeze());

    // Safety:
    // All invariants of VarBinArray are upheld here.
    unsafe {
        VarBinArray::new_unchecked(
            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
            new_data.freeze(),
            dtype,
            array_validity,
        )
    }
}

#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinVTable;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

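    // The result's nullability is the union of the array's and the indices' nullability.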
    #[test]
    fn test_null_take() {
        let arr = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let idx1: PrimitiveArray = (0..1).collect();

        assert_eq!(
            take(arr.as_ref(), idx1.as_ref()).unwrap().dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        let idx2: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);

        assert_eq!(
            take(arr.as_ref(), idx2.as_ref()).unwrap().dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(array.as_ref());
    }

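    // Taking a 128-byte value three times cannot be represented with the original u8 offsets,
    // exercising the widening of the output offsets to u32.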
    #[test]
    fn test_take_overflow() {
        let scream = std::iter::once("a").cycle().take(128).collect::<String>();
        let bytes = ByteBuffer::copy_from(scream.as_bytes());
        let offsets = buffer![0u8, 128u8].into_array();

        let array = VarBinArray::new(
            offsets,
            bytes,
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let indices = buffer![0u32, 0u32, 0u32].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();

        let taken_str = taken.as_::<VarBinVTable>();
        assert_eq!(taken_str.len(), 3);
        assert_eq!(taken_str.bytes_at(0).as_bytes(), scream.as_bytes());
        assert_eq!(taken_str.bytes_at(1).as_bytes(), scream.as_bytes());
        assert_eq!(taken_str.bytes_at(2).as_bytes(), scream.as_bytes());
    }
}