vortex_array/arrays/varbin/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::{BitBufferMut, BufferMut, ByteBufferMut};
5use vortex_dtype::{DType, IntegerPType, match_each_integer_ptype};
6use vortex_error::{VortexExpect, VortexResult, vortex_panic};
7use vortex_mask::Mask;
8
9use crate::arrays::varbin::VarBinArray;
10use crate::arrays::{PrimitiveArray, VarBinVTable};
11use crate::compute::{TakeKernel, TakeKernelAdapter};
12use crate::validity::Validity;
13use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
14
15impl TakeKernel for VarBinVTable {
16    fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17        let offsets = array.offsets().to_primitive();
18        let data = array.bytes();
19        let indices = indices.to_primitive();
20        match_each_integer_ptype!(offsets.ptype(), |O| {
21            match_each_integer_ptype!(indices.ptype(), |I| {
22                Ok(take(
23                    array
24                        .dtype()
25                        .clone()
26                        .union_nullability(indices.dtype().nullability()),
27                    offsets.as_slice::<O>(),
28                    data.as_slice(),
29                    indices.as_slice::<I>(),
30                    array.validity_mask(),
31                    indices.validity_mask(),
32                )?
33                .into_array())
34            })
35        })
36    }
37}
38
39register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
40
41fn take<I: IntegerPType, O: IntegerPType>(
42    dtype: DType,
43    offsets: &[O],
44    data: &[u8],
45    indices: &[I],
46    validity_mask: Mask,
47    indices_validity_mask: Mask,
48) -> VortexResult<VarBinArray> {
49    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
50        return Ok(take_nullable(
51            dtype,
52            offsets,
53            data,
54            indices,
55            validity_mask,
56            indices_validity_mask,
57        ));
58    }
59
60    let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
61    new_offsets.push(O::zero());
62    let mut current_offset = O::zero();
63
64    for &idx in indices {
65        let idx = idx
66            .to_usize()
67            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
68        let start = offsets[idx];
69        let stop = offsets[idx + 1];
70        current_offset += stop - start;
71        new_offsets.push(current_offset);
72    }
73
74    let mut new_data = ByteBufferMut::with_capacity(
75        current_offset
76            .to_usize()
77            .vortex_expect("Failed to cast max offset to usize"),
78    );
79
80    for idx in indices {
81        let idx = idx
82            .to_usize()
83            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
84        let start = offsets[idx]
85            .to_usize()
86            .vortex_expect("Failed to cast max offset to usize");
87        let stop = offsets[idx + 1]
88            .to_usize()
89            .vortex_expect("Failed to cast max offset to usize");
90        new_data.extend_from_slice(&data[start..stop]);
91    }
92
93    let array_validity = Validity::from(dtype.nullability());
94
95    // Safety:
96    // All variants of VarBinArray are satisfied here.
97    unsafe {
98        Ok(VarBinArray::new_unchecked(
99            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
100            new_data.freeze(),
101            dtype,
102            array_validity,
103        ))
104    }
105}
106
107fn take_nullable<I: IntegerPType, O: IntegerPType>(
108    dtype: DType,
109    offsets: &[O],
110    data: &[u8],
111    indices: &[I],
112    data_validity: Mask,
113    indices_validity: Mask,
114) -> VarBinArray {
115    let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
116    new_offsets.push(O::zero());
117    let mut current_offset = O::zero();
118
119    let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
120
121    // Convert indices once and store valid ones with their positions
122    let mut valid_indices = Vec::with_capacity(indices.len());
123
124    // First pass: calculate offsets and validity
125    for (idx, data_idx) in indices.iter().enumerate() {
126        if !indices_validity.value(idx) {
127            validity_buffer.append(false);
128            new_offsets.push(current_offset);
129            continue;
130        }
131        let data_idx_usize = data_idx
132            .to_usize()
133            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
134        if data_validity.value(data_idx_usize) {
135            validity_buffer.append(true);
136            let start = offsets[data_idx_usize];
137            let stop = offsets[data_idx_usize + 1];
138            current_offset += stop - start;
139            new_offsets.push(current_offset);
140            valid_indices.push(data_idx_usize);
141        } else {
142            validity_buffer.append(false);
143            new_offsets.push(current_offset);
144        }
145    }
146
147    let mut new_data = ByteBufferMut::with_capacity(
148        current_offset
149            .to_usize()
150            .vortex_expect("Failed to cast max offset to usize"),
151    );
152
153    // Second pass: copy data for valid indices only
154    for data_idx in valid_indices {
155        let start = offsets[data_idx]
156            .to_usize()
157            .vortex_expect("Failed to cast max offset to usize");
158        let stop = offsets[data_idx + 1]
159            .to_usize()
160            .vortex_expect("Failed to cast max offset to usize");
161        new_data.extend_from_slice(&data[start..stop]);
162    }
163
164    let array_validity = Validity::from(validity_buffer.freeze());
165
166    // Safety:
167    // All variants of VarBinArray are satisfied here.
168    unsafe {
169        VarBinArray::new_unchecked(
170            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
171            new_data.freeze(),
172            dtype,
173            array_validity,
174        )
175    }
176}
177
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability};

    use crate::Array;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// The nullability of the indices must carry over to the dtype of the taken array.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Non-nullable indices keep the source dtype as is.
        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken_plain = take(strings.as_ref(), plain_indices.as_ref()).unwrap();
        assert_eq!(
            taken_plain.dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        // Nullable indices make the result nullable, even when every index is valid.
        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken_nullable = take(strings.as_ref(), nullable_indices.as_ref()).unwrap();
        assert_eq!(
            taken_nullable.dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    /// Runs the shared take conformance suite over several VarBin arrays: non-nullable and
    /// nullable UTF-8, non-nullable binary, and a single-element array.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] varbin: VarBinArray) {
        test_take_conformance(varbin.as_ref());
    }
}