vortex_array/arrays/varbin/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter::Sum;
5
6use num_traits::PrimInt;
7use vortex_dtype::{DType, NativePType, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_err, vortex_panic};
9use vortex_mask::Mask;
10
11use crate::arrays::VarBinVTable;
12use crate::arrays::varbin::VarBinArray;
13use crate::arrays::varbin::builder::VarBinBuilder;
14use crate::compute::{TakeKernel, TakeKernelAdapter};
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl TakeKernel for VarBinVTable {
18    fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19        let offsets = array.offsets().to_primitive()?;
20        let data = array.bytes();
21        let indices = indices.to_primitive()?;
22        match_each_integer_ptype!(offsets.ptype(), |O| {
23            match_each_integer_ptype!(indices.ptype(), |I| {
24                Ok(take(
25                    array
26                        .dtype()
27                        .clone()
28                        .union_nullability(indices.dtype().nullability()),
29                    offsets.as_slice::<O>(),
30                    data.as_slice(),
31                    indices.as_slice::<I>(),
32                    array.validity_mask()?,
33                    indices.validity_mask()?,
34                )?
35                .into_array())
36            })
37        })
38    }
39}
40
41register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
42
43fn take<I: NativePType, O: NativePType + PrimInt + Sum>(
44    dtype: DType,
45    offsets: &[O],
46    data: &[u8],
47    indices: &[I],
48    validity_mask: Mask,
49    indices_validity_mask: Mask,
50) -> VortexResult<VarBinArray> {
51    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
52        return Ok(take_nullable(
53            dtype,
54            offsets,
55            data,
56            indices,
57            validity_mask,
58            indices_validity_mask,
59        ));
60    }
61
62    let mut builder = VarBinBuilder::<u32>::with_capacity(indices.len());
63    for &idx in indices {
64        let idx = idx
65            .to_usize()
66            .ok_or_else(|| vortex_err!("Failed to convert index to usize: {}", idx))?;
67        let start = offsets[idx]
68            .to_usize()
69            .ok_or_else(|| vortex_err!("Failed to convert offset to usize: {}", offsets[idx]))?;
70        let stop = offsets[idx + 1].to_usize().ok_or_else(|| {
71            vortex_err!("Failed to convert offset to usize: {}", offsets[idx + 1])
72        })?;
73        builder.append_value(&data[start..stop]);
74    }
75    Ok(builder.finish(dtype))
76}
77
78fn take_nullable<I: NativePType, O: NativePType + PrimInt>(
79    dtype: DType,
80    offsets: &[O],
81    data: &[u8],
82    indices: &[I],
83    data_validity: Mask,
84    indices_validity: Mask,
85) -> VarBinArray {
86    let mut builder = VarBinBuilder::<u32>::with_capacity(indices.len());
87    for (idx, data_idx) in indices.iter().enumerate() {
88        if !indices_validity.value(idx) {
89            builder.append_null();
90            continue;
91        }
92        let data_idx = data_idx
93            .to_usize()
94            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
95        if data_validity.value(data_idx) {
96            let start = offsets[data_idx].to_usize().unwrap_or_else(|| {
97                vortex_panic!("Failed to convert offset to usize: {}", offsets[data_idx])
98            });
99            let stop = offsets[data_idx + 1].to_usize().unwrap_or_else(|| {
100                vortex_panic!(
101                    "Failed to convert offset to usize: {}",
102                    offsets[data_idx + 1]
103                )
104            });
105            builder.append_value(&data[start..stop]);
106        } else {
107            builder.append_null();
108        }
109    }
110    builder.finish(dtype)
111}
112
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability};

    use crate::Array;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// The result dtype follows the nullability of the indices when the array is non-nullable.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Non-nullable indices keep the result non-nullable.
        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken = take(strings.as_ref(), plain_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        // Nullable indices make the result nullable, even with no actual nulls present.
        let optional_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = take(strings.as_ref(), optional_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over a handful of `VarBinArray` shapes.
    #[rstest]
    // Non-nullable UTF-8 strings.
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    // Nullable UTF-8 strings with nulls in the middle and at the end.
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    // Non-nullable binary values.
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    // Single-element array.
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        let input = array;
        test_take_conformance(input.as_ref());
    }
}