vortex_array/arrays/varbin/compute/
take.rs1use std::iter::Sum;
5
6use num_traits::PrimInt;
7use vortex_dtype::{DType, NativePType, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_err, vortex_panic};
9use vortex_mask::Mask;
10
11use crate::arrays::VarBinVTable;
12use crate::arrays::varbin::VarBinArray;
13use crate::arrays::varbin::builder::VarBinBuilder;
14use crate::compute::{TakeKernel, TakeKernelAdapter};
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl TakeKernel for VarBinVTable {
18 fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19 let offsets = array.offsets().to_primitive()?;
20 let data = array.bytes();
21 let indices = indices.to_primitive()?;
22 match_each_integer_ptype!(offsets.ptype(), |O| {
23 match_each_integer_ptype!(indices.ptype(), |I| {
24 Ok(take(
25 array
26 .dtype()
27 .clone()
28 .union_nullability(indices.dtype().nullability()),
29 offsets.as_slice::<O>(),
30 data.as_slice(),
31 indices.as_slice::<I>(),
32 array.validity_mask()?,
33 indices.validity_mask()?,
34 )?
35 .into_array())
36 })
37 })
38 }
39}
40
41register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
42
43fn take<I: NativePType, O: NativePType + PrimInt + Sum>(
44 dtype: DType,
45 offsets: &[O],
46 data: &[u8],
47 indices: &[I],
48 validity_mask: Mask,
49 indices_validity_mask: Mask,
50) -> VortexResult<VarBinArray> {
51 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
52 return Ok(take_nullable(
53 dtype,
54 offsets,
55 data,
56 indices,
57 validity_mask,
58 indices_validity_mask,
59 ));
60 }
61
62 let mut builder = VarBinBuilder::<u32>::with_capacity(indices.len());
63 for &idx in indices {
64 let idx = idx
65 .to_usize()
66 .ok_or_else(|| vortex_err!("Failed to convert index to usize: {}", idx))?;
67 let start = offsets[idx]
68 .to_usize()
69 .ok_or_else(|| vortex_err!("Failed to convert offset to usize: {}", offsets[idx]))?;
70 let stop = offsets[idx + 1].to_usize().ok_or_else(|| {
71 vortex_err!("Failed to convert offset to usize: {}", offsets[idx + 1])
72 })?;
73 builder.append_value(&data[start..stop]);
74 }
75 Ok(builder.finish(dtype))
76}
77
78fn take_nullable<I: NativePType, O: NativePType + PrimInt>(
79 dtype: DType,
80 offsets: &[O],
81 data: &[u8],
82 indices: &[I],
83 data_validity: Mask,
84 indices_validity: Mask,
85) -> VarBinArray {
86 let mut builder = VarBinBuilder::<u32>::with_capacity(indices.len());
87 for (idx, data_idx) in indices.iter().enumerate() {
88 if !indices_validity.value(idx) {
89 builder.append_null();
90 continue;
91 }
92 let data_idx = data_idx
93 .to_usize()
94 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
95 if data_validity.value(data_idx) {
96 let start = offsets[data_idx].to_usize().unwrap_or_else(|| {
97 vortex_panic!("Failed to convert offset to usize: {}", offsets[data_idx])
98 });
99 let stop = offsets[data_idx + 1].to_usize().unwrap_or_else(|| {
100 vortex_panic!(
101 "Failed to convert offset to usize: {}",
102 offsets[data_idx + 1]
103 )
104 });
105 builder.append_value(&data[start..stop]);
106 } else {
107 builder.append_null();
108 }
109 }
110 builder.finish(dtype)
111}
112
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability};

    use crate::Array;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// The nullability of the indices must carry over to the dtype of the taken array.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Non-nullable indices leave the result non-nullable.
        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken_plain = take(strings.as_ref(), plain_indices.as_ref()).unwrap();
        assert_eq!(
            taken_plain.dtype(),
            &DType::Utf8(Nullability::NonNullable)
        );

        // Nullable indices make the result nullable, even when no index is actually null.
        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken_nullable = take(strings.as_ref(), nullable_indices.as_ref()).unwrap();
        assert_eq!(
            taken_nullable.dtype(),
            &DType::Utf8(Nullability::Nullable)
        );
    }

    /// Runs the shared take conformance suite over several VarBin arrays: non-nullable
    /// strings, strings with nulls, binary values, and a single-element array.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(array.as_ref());
    }
}