vortex_array/arrays/varbin/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::PrimInt;
5use vortex_dtype::{DType, NativePType, match_each_integer_ptype};
6use vortex_error::{VortexResult, vortex_err, vortex_panic};
7use vortex_mask::Mask;
8
9use crate::arrays::VarBinVTable;
10use crate::arrays::varbin::VarBinArray;
11use crate::arrays::varbin::builder::VarBinBuilder;
12use crate::compute::{TakeKernel, TakeKernelAdapter};
13use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
14
15impl TakeKernel for VarBinVTable {
16    fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17        let offsets = array.offsets().to_primitive()?;
18        let data = array.bytes();
19        let indices = indices.to_primitive()?;
20        match_each_integer_ptype!(offsets.ptype(), |O| {
21            match_each_integer_ptype!(indices.ptype(), |I| {
22                Ok(take(
23                    array
24                        .dtype()
25                        .clone()
26                        .union_nullability(indices.dtype().nullability()),
27                    offsets.as_slice::<O>(),
28                    data.as_slice(),
29                    indices.as_slice::<I>(),
30                    array.validity_mask()?,
31                    indices.validity_mask()?,
32                )?
33                .into_array())
34            })
35        })
36    }
37}
38
// Registers the VarBin take kernel, wrapped in `TakeKernelAdapter` and lifted, with the kernel
// registry. Presumably this is how the generic `compute::take` entry point finds this
// implementation for VarBin arrays; confirm against `register_kernel!`.
register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
40
41fn take<I: NativePType, O: NativePType + PrimInt>(
42    dtype: DType,
43    offsets: &[O],
44    data: &[u8],
45    indices: &[I],
46    validity_mask: Mask,
47    indices_validity_mask: Mask,
48) -> VortexResult<VarBinArray> {
49    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
50        return Ok(take_nullable(
51            dtype,
52            offsets,
53            data,
54            indices,
55            validity_mask,
56            indices_validity_mask,
57        ));
58    }
59
60    let mut builder = VarBinBuilder::<O>::with_capacity(indices.len());
61    for &idx in indices {
62        let idx = idx
63            .to_usize()
64            .ok_or_else(|| vortex_err!("Failed to convert index to usize: {}", idx))?;
65        let start = offsets[idx]
66            .to_usize()
67            .ok_or_else(|| vortex_err!("Failed to convert offset to usize: {}", offsets[idx]))?;
68        let stop = offsets[idx + 1].to_usize().ok_or_else(|| {
69            vortex_err!("Failed to convert offset to usize: {}", offsets[idx + 1])
70        })?;
71        builder.append_value(&data[start..stop]);
72    }
73    Ok(builder.finish(dtype))
74}
75
76fn take_nullable<I: NativePType, O: NativePType + PrimInt>(
77    dtype: DType,
78    offsets: &[O],
79    data: &[u8],
80    indices: &[I],
81    data_validity: Mask,
82    indices_validity: Mask,
83) -> VarBinArray {
84    let mut builder = VarBinBuilder::<O>::with_capacity(indices.len());
85    for (idx, data_idx) in indices.iter().enumerate() {
86        if !indices_validity.value(idx) {
87            builder.append_null();
88            continue;
89        }
90        let data_idx = data_idx
91            .to_usize()
92            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
93        if data_validity.value(data_idx) {
94            let start = offsets[data_idx].to_usize().unwrap_or_else(|| {
95                vortex_panic!("Failed to convert offset to usize: {}", offsets[data_idx])
96            });
97            let stop = offsets[data_idx + 1].to_usize().unwrap_or_else(|| {
98                vortex_panic!(
99                    "Failed to convert offset to usize: {}",
100                    offsets[data_idx + 1]
101                )
102            });
103            builder.append_value(&data[start..stop]);
104        } else {
105            builder.append_null();
106        }
107    }
108    builder.finish(dtype)
109}
110
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability};

    use crate::Array;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Non-nullable indices preserve the source nullability. Nullable indices widen the
    /// result dtype to nullable, even when every index is valid.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let non_nullable_indices: PrimitiveArray = (0..1).collect();
        let taken = take(strings.as_ref(), non_nullable_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        let nullable_indices = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = take(strings.as_ref(), nullable_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over UTF-8 and binary VarBin arrays.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] input: VarBinArray) {
        test_take_conformance(input.as_ref());
    }
}