vortex_array/arrays/varbin/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBufferBuilder;
5use vortex_buffer::{BufferMut, ByteBufferMut};
6use vortex_dtype::{DType, IntegerPType, match_each_integer_ptype};
7use vortex_error::{VortexExpect, VortexResult, vortex_panic};
8use vortex_mask::Mask;
9
10use crate::arrays::varbin::VarBinArray;
11use crate::arrays::{PrimitiveArray, VarBinVTable};
12use crate::compute::{TakeKernel, TakeKernelAdapter};
13use crate::validity::Validity;
14use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
15
16impl TakeKernel for VarBinVTable {
17    fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
18        let offsets = array.offsets().to_primitive();
19        let data = array.bytes();
20        let indices = indices.to_primitive();
21        match_each_integer_ptype!(offsets.ptype(), |O| {
22            match_each_integer_ptype!(indices.ptype(), |I| {
23                Ok(take(
24                    array
25                        .dtype()
26                        .clone()
27                        .union_nullability(indices.dtype().nullability()),
28                    offsets.as_slice::<O>(),
29                    data.as_slice(),
30                    indices.as_slice::<I>(),
31                    array.validity_mask(),
32                    indices.validity_mask(),
33                )?
34                .into_array())
35            })
36        })
37    }
38}
39
40register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
41
42fn take<I: IntegerPType, O: IntegerPType>(
43    dtype: DType,
44    offsets: &[O],
45    data: &[u8],
46    indices: &[I],
47    validity_mask: Mask,
48    indices_validity_mask: Mask,
49) -> VortexResult<VarBinArray> {
50    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
51        return Ok(take_nullable(
52            dtype,
53            offsets,
54            data,
55            indices,
56            validity_mask,
57            indices_validity_mask,
58        ));
59    }
60
61    let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
62    new_offsets.push(O::zero());
63    let mut current_offset = O::zero();
64
65    for &idx in indices {
66        let idx = idx
67            .to_usize()
68            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
69        let start = offsets[idx];
70        let stop = offsets[idx + 1];
71        current_offset += stop - start;
72        new_offsets.push(current_offset);
73    }
74
75    let mut new_data = ByteBufferMut::with_capacity(
76        current_offset
77            .to_usize()
78            .vortex_expect("Failed to cast max offset to usize"),
79    );
80
81    for idx in indices {
82        let idx = idx
83            .to_usize()
84            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
85        let start = offsets[idx]
86            .to_usize()
87            .vortex_expect("Failed to cast max offset to usize");
88        let stop = offsets[idx + 1]
89            .to_usize()
90            .vortex_expect("Failed to cast max offset to usize");
91        new_data.extend_from_slice(&data[start..stop]);
92    }
93
94    let array_validity = Validity::from(dtype.nullability());
95
96    // Safety:
97    // All variants of VarBinArray are satisfied here.
98    unsafe {
99        Ok(VarBinArray::new_unchecked(
100            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
101            new_data.freeze(),
102            dtype,
103            array_validity,
104        ))
105    }
106}
107
108fn take_nullable<I: IntegerPType, O: IntegerPType>(
109    dtype: DType,
110    offsets: &[O],
111    data: &[u8],
112    indices: &[I],
113    data_validity: Mask,
114    indices_validity: Mask,
115) -> VarBinArray {
116    let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
117    new_offsets.push(O::zero());
118    let mut current_offset = O::zero();
119
120    let mut validity_buffer = BooleanBufferBuilder::new(indices.len());
121
122    // Convert indices once and store valid ones with their positions
123    let mut valid_indices = Vec::with_capacity(indices.len());
124
125    // First pass: calculate offsets and validity
126    for (idx, data_idx) in indices.iter().enumerate() {
127        if !indices_validity.value(idx) {
128            validity_buffer.append(false);
129            new_offsets.push(current_offset);
130            continue;
131        }
132        let data_idx_usize = data_idx
133            .to_usize()
134            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
135        if data_validity.value(data_idx_usize) {
136            validity_buffer.append(true);
137            let start = offsets[data_idx_usize];
138            let stop = offsets[data_idx_usize + 1];
139            current_offset += stop - start;
140            new_offsets.push(current_offset);
141            valid_indices.push(data_idx_usize);
142        } else {
143            validity_buffer.append(false);
144            new_offsets.push(current_offset);
145        }
146    }
147
148    let mut new_data = ByteBufferMut::with_capacity(
149        current_offset
150            .to_usize()
151            .vortex_expect("Failed to cast max offset to usize"),
152    );
153
154    // Second pass: copy data for valid indices only
155    for data_idx in valid_indices {
156        let start = offsets[data_idx]
157            .to_usize()
158            .vortex_expect("Failed to cast max offset to usize");
159        let stop = offsets[data_idx + 1]
160            .to_usize()
161            .vortex_expect("Failed to cast max offset to usize");
162        new_data.extend_from_slice(&data[start..stop]);
163    }
164
165    let array_validity = Validity::from(validity_buffer.finish());
166
167    // Safety:
168    // All variants of VarBinArray are satisfied here.
169    unsafe {
170        VarBinArray::new_unchecked(
171            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
172            new_data.freeze(),
173            dtype,
174            array_validity,
175        )
176    }
177}
178
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability};

    use crate::Array;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// The result dtype picks up the nullability of the indices: non-nullable indices keep a
    /// non-nullable array non-nullable, while nullable indices make the result nullable even
    /// when every index is present.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Non-nullable indices.
        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken = take(strings.as_ref(), plain_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        // Nullable indices without any actual null.
        let nullable_indices = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = take(strings.as_ref(), nullable_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over a handful of VarBin arrays: non-nullable
    /// and nullable UTF-8, non-nullable binary, and a single-element array.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] input: VarBinArray) {
        test_take_conformance(input.as_ref());
    }
}