// vortex_array/arrays/varbin/compute/take.rs
use arrow_buffer::BooleanBufferBuilder;
use vortex_buffer::{BufferMut, ByteBufferMut};
use vortex_dtype::{DType, IntegerPType, PType, match_each_integer_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_err, vortex_panic};
use vortex_mask::Mask;

use crate::arrays::varbin::VarBinArray;
use crate::arrays::{PrimitiveArray, VarBinVTable};
use crate::compute::{TakeKernel, TakeKernelAdapter};
use crate::validity::Validity;
use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
15
16impl TakeKernel for VarBinVTable {
17 fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
18 let offsets = array.offsets().to_primitive();
19 let data = array.bytes();
20 let indices = indices.to_primitive();
21 match_each_integer_ptype!(offsets.ptype(), |O| {
22 match_each_integer_ptype!(indices.ptype(), |I| {
23 Ok(take(
24 array
25 .dtype()
26 .clone()
27 .union_nullability(indices.dtype().nullability()),
28 offsets.as_slice::<O>(),
29 data.as_slice(),
30 indices.as_slice::<I>(),
31 array.validity_mask(),
32 indices.validity_mask(),
33 )?
34 .into_array())
35 })
36 })
37 }
38}
39
40register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
41
42fn take<I: IntegerPType, O: IntegerPType>(
43 dtype: DType,
44 offsets: &[O],
45 data: &[u8],
46 indices: &[I],
47 validity_mask: Mask,
48 indices_validity_mask: Mask,
49) -> VortexResult<VarBinArray> {
50 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
51 return Ok(take_nullable(
52 dtype,
53 offsets,
54 data,
55 indices,
56 validity_mask,
57 indices_validity_mask,
58 ));
59 }
60
61 let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
62 new_offsets.push(O::zero());
63 let mut current_offset = O::zero();
64
65 for &idx in indices {
66 let idx = idx
67 .to_usize()
68 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
69 let start = offsets[idx];
70 let stop = offsets[idx + 1];
71 current_offset += stop - start;
72 new_offsets.push(current_offset);
73 }
74
75 let mut new_data = ByteBufferMut::with_capacity(
76 current_offset
77 .to_usize()
78 .vortex_expect("Failed to cast max offset to usize"),
79 );
80
81 for idx in indices {
82 let idx = idx
83 .to_usize()
84 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
85 let start = offsets[idx]
86 .to_usize()
87 .vortex_expect("Failed to cast max offset to usize");
88 let stop = offsets[idx + 1]
89 .to_usize()
90 .vortex_expect("Failed to cast max offset to usize");
91 new_data.extend_from_slice(&data[start..stop]);
92 }
93
94 let array_validity = Validity::from(dtype.nullability());
95
96 unsafe {
99 Ok(VarBinArray::new_unchecked(
100 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
101 new_data.freeze(),
102 dtype,
103 array_validity,
104 ))
105 }
106}
107
108fn take_nullable<I: IntegerPType, O: IntegerPType>(
109 dtype: DType,
110 offsets: &[O],
111 data: &[u8],
112 indices: &[I],
113 data_validity: Mask,
114 indices_validity: Mask,
115) -> VarBinArray {
116 let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
117 new_offsets.push(O::zero());
118 let mut current_offset = O::zero();
119
120 let mut validity_buffer = BooleanBufferBuilder::new(indices.len());
121
122 let mut valid_indices = Vec::with_capacity(indices.len());
124
125 for (idx, data_idx) in indices.iter().enumerate() {
127 if !indices_validity.value(idx) {
128 validity_buffer.append(false);
129 new_offsets.push(current_offset);
130 continue;
131 }
132 let data_idx_usize = data_idx
133 .to_usize()
134 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
135 if data_validity.value(data_idx_usize) {
136 validity_buffer.append(true);
137 let start = offsets[data_idx_usize];
138 let stop = offsets[data_idx_usize + 1];
139 current_offset += stop - start;
140 new_offsets.push(current_offset);
141 valid_indices.push(data_idx_usize);
142 } else {
143 validity_buffer.append(false);
144 new_offsets.push(current_offset);
145 }
146 }
147
148 let mut new_data = ByteBufferMut::with_capacity(
149 current_offset
150 .to_usize()
151 .vortex_expect("Failed to cast max offset to usize"),
152 );
153
154 for data_idx in valid_indices {
156 let start = offsets[data_idx]
157 .to_usize()
158 .vortex_expect("Failed to cast max offset to usize");
159 let stop = offsets[data_idx + 1]
160 .to_usize()
161 .vortex_expect("Failed to cast max offset to usize");
162 new_data.extend_from_slice(&data[start..stop]);
163 }
164
165 let array_validity = Validity::from(validity_buffer.finish());
166
167 unsafe {
170 VarBinArray::new_unchecked(
171 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
172 new_data.freeze(),
173 dtype,
174 array_validity,
175 )
176 }
177}
178
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability};

    use crate::Array;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Non-nullable indices keep the array's non-nullable dtype, while nullable indices make the
    /// result nullable.
    #[test]
    fn test_null_take() {
        let array = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let dense_indices: PrimitiveArray = (0..1).collect();
        let dense_taken = take(array.as_ref(), dense_indices.as_ref()).unwrap();
        assert_eq!(dense_taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let nullable_taken = take(array.as_ref(), nullable_indices.as_ref()).unwrap();
        assert_eq!(nullable_taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over Utf8 and Binary arrays, nullable and not.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        test_take_conformance(array.as_ref());
    }
}