vortex_array/arrays/varbin/compute/
take.rs1use num_traits::PrimInt;
5use vortex_dtype::{DType, NativePType, match_each_integer_ptype};
6use vortex_error::{VortexResult, vortex_err, vortex_panic};
7use vortex_mask::Mask;
8
9use crate::arrays::VarBinVTable;
10use crate::arrays::varbin::VarBinArray;
11use crate::arrays::varbin::builder::VarBinBuilder;
12use crate::compute::{TakeKernel, TakeKernelAdapter};
13use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
14
15impl TakeKernel for VarBinVTable {
16 fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17 let offsets = array.offsets().to_primitive()?;
18 let data = array.bytes();
19 let indices = indices.to_primitive()?;
20 match_each_integer_ptype!(offsets.ptype(), |O| {
21 match_each_integer_ptype!(indices.ptype(), |I| {
22 Ok(take(
23 array
24 .dtype()
25 .clone()
26 .union_nullability(indices.dtype().nullability()),
27 offsets.as_slice::<O>(),
28 data.as_slice(),
29 indices.as_slice::<I>(),
30 array.validity_mask()?,
31 indices.validity_mask()?,
32 )?
33 .into_array())
34 })
35 })
36 }
37}
38
39register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
40
41fn take<I: NativePType, O: NativePType + PrimInt>(
42 dtype: DType,
43 offsets: &[O],
44 data: &[u8],
45 indices: &[I],
46 validity_mask: Mask,
47 indices_validity_mask: Mask,
48) -> VortexResult<VarBinArray> {
49 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
50 return Ok(take_nullable(
51 dtype,
52 offsets,
53 data,
54 indices,
55 validity_mask,
56 indices_validity_mask,
57 ));
58 }
59
60 let mut builder = VarBinBuilder::<O>::with_capacity(indices.len());
61 for &idx in indices {
62 let idx = idx
63 .to_usize()
64 .ok_or_else(|| vortex_err!("Failed to convert index to usize: {}", idx))?;
65 let start = offsets[idx]
66 .to_usize()
67 .ok_or_else(|| vortex_err!("Failed to convert offset to usize: {}", offsets[idx]))?;
68 let stop = offsets[idx + 1].to_usize().ok_or_else(|| {
69 vortex_err!("Failed to convert offset to usize: {}", offsets[idx + 1])
70 })?;
71 builder.append_value(&data[start..stop]);
72 }
73 Ok(builder.finish(dtype))
74}
75
76fn take_nullable<I: NativePType, O: NativePType + PrimInt>(
77 dtype: DType,
78 offsets: &[O],
79 data: &[u8],
80 indices: &[I],
81 data_validity: Mask,
82 indices_validity: Mask,
83) -> VarBinArray {
84 let mut builder = VarBinBuilder::<O>::with_capacity(indices.len());
85 for (idx, data_idx) in indices.iter().enumerate() {
86 if !indices_validity.value(idx) {
87 builder.append_null();
88 continue;
89 }
90 let data_idx = data_idx
91 .to_usize()
92 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
93 if data_validity.value(data_idx) {
94 let start = offsets[data_idx].to_usize().unwrap_or_else(|| {
95 vortex_panic!("Failed to convert offset to usize: {}", offsets[data_idx])
96 });
97 let stop = offsets[data_idx + 1].to_usize().unwrap_or_else(|| {
98 vortex_panic!(
99 "Failed to convert offset to usize: {}",
100 offsets[data_idx + 1]
101 )
102 });
103 builder.append_value(&data[start..stop]);
104 } else {
105 builder.append_null();
106 }
107 }
108 builder.finish(dtype)
109}
110
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability};

    use crate::Array;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Non-nullable indices keep the source dtype; nullable indices widen the result dtype to
    /// nullable even when none of the indices is actually null.
    #[test]
    fn test_null_take() {
        let source = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let plain_indices: PrimitiveArray = (0..1).collect();
        let plain_result = take(source.as_ref(), plain_indices.as_ref()).unwrap();
        assert_eq!(
            plain_result.dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let nullable_result = take(source.as_ref(), nullable_indices.as_ref()).unwrap();
        assert_eq!(
            nullable_result.dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    /// Runs the shared `take` conformance suite over non-nullable UTF-8, UTF-8 with nulls,
    /// non-nullable binary, and a single-element array.
    #[rstest]
    #[case(VarBinArray::from_iter(
        [Some("hello"), Some("world"), Some("test"), Some("data"), Some("array")],
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [
            Some(b"hello".as_slice()),
            Some(b"world".as_slice()),
            Some(b"test".as_slice()),
            Some(b"data".as_slice()),
            Some(b"array".as_slice()),
        ],
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter([Some("single")], DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(array.as_ref());
    }
}