vortex_array/arrays/varbin/compute/
take.rs1use vortex_buffer::{BitBufferMut, BufferMut, ByteBufferMut};
5use vortex_dtype::{DType, IntegerPType, match_each_integer_ptype};
6use vortex_error::{VortexExpect, VortexResult, vortex_panic};
7use vortex_mask::Mask;
8
9use crate::arrays::varbin::VarBinArray;
10use crate::arrays::{PrimitiveArray, VarBinVTable};
11use crate::compute::{TakeKernel, TakeKernelAdapter};
12use crate::validity::Validity;
13use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
14
15impl TakeKernel for VarBinVTable {
16 fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17 let offsets = array.offsets().to_primitive();
18 let data = array.bytes();
19 let indices = indices.to_primitive();
20 match_each_integer_ptype!(offsets.ptype(), |O| {
21 match_each_integer_ptype!(indices.ptype(), |I| {
22 Ok(take(
23 array
24 .dtype()
25 .clone()
26 .union_nullability(indices.dtype().nullability()),
27 offsets.as_slice::<O>(),
28 data.as_slice(),
29 indices.as_slice::<I>(),
30 array.validity_mask(),
31 indices.validity_mask(),
32 )?
33 .into_array())
34 })
35 })
36 }
37}
38
39register_kernel!(TakeKernelAdapter(VarBinVTable).lift());
40
41fn take<I: IntegerPType, O: IntegerPType>(
42 dtype: DType,
43 offsets: &[O],
44 data: &[u8],
45 indices: &[I],
46 validity_mask: Mask,
47 indices_validity_mask: Mask,
48) -> VortexResult<VarBinArray> {
49 if !validity_mask.all_true() || !indices_validity_mask.all_true() {
50 return Ok(take_nullable(
51 dtype,
52 offsets,
53 data,
54 indices,
55 validity_mask,
56 indices_validity_mask,
57 ));
58 }
59
60 let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
61 new_offsets.push(O::zero());
62 let mut current_offset = O::zero();
63
64 for &idx in indices {
65 let idx = idx
66 .to_usize()
67 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
68 let start = offsets[idx];
69 let stop = offsets[idx + 1];
70 current_offset += stop - start;
71 new_offsets.push(current_offset);
72 }
73
74 let mut new_data = ByteBufferMut::with_capacity(
75 current_offset
76 .to_usize()
77 .vortex_expect("Failed to cast max offset to usize"),
78 );
79
80 for idx in indices {
81 let idx = idx
82 .to_usize()
83 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
84 let start = offsets[idx]
85 .to_usize()
86 .vortex_expect("Failed to cast max offset to usize");
87 let stop = offsets[idx + 1]
88 .to_usize()
89 .vortex_expect("Failed to cast max offset to usize");
90 new_data.extend_from_slice(&data[start..stop]);
91 }
92
93 let array_validity = Validity::from(dtype.nullability());
94
95 unsafe {
98 Ok(VarBinArray::new_unchecked(
99 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
100 new_data.freeze(),
101 dtype,
102 array_validity,
103 ))
104 }
105}
106
107fn take_nullable<I: IntegerPType, O: IntegerPType>(
108 dtype: DType,
109 offsets: &[O],
110 data: &[u8],
111 indices: &[I],
112 data_validity: Mask,
113 indices_validity: Mask,
114) -> VarBinArray {
115 let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
116 new_offsets.push(O::zero());
117 let mut current_offset = O::zero();
118
119 let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
120
121 let mut valid_indices = Vec::with_capacity(indices.len());
123
124 for (idx, data_idx) in indices.iter().enumerate() {
126 if !indices_validity.value(idx) {
127 validity_buffer.append(false);
128 new_offsets.push(current_offset);
129 continue;
130 }
131 let data_idx_usize = data_idx
132 .to_usize()
133 .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
134 if data_validity.value(data_idx_usize) {
135 validity_buffer.append(true);
136 let start = offsets[data_idx_usize];
137 let stop = offsets[data_idx_usize + 1];
138 current_offset += stop - start;
139 new_offsets.push(current_offset);
140 valid_indices.push(data_idx_usize);
141 } else {
142 validity_buffer.append(false);
143 new_offsets.push(current_offset);
144 }
145 }
146
147 let mut new_data = ByteBufferMut::with_capacity(
148 current_offset
149 .to_usize()
150 .vortex_expect("Failed to cast max offset to usize"),
151 );
152
153 for data_idx in valid_indices {
155 let start = offsets[data_idx]
156 .to_usize()
157 .vortex_expect("Failed to cast max offset to usize");
158 let stop = offsets[data_idx + 1]
159 .to_usize()
160 .vortex_expect("Failed to cast max offset to usize");
161 new_data.extend_from_slice(&data[start..stop]);
162 }
163
164 let array_validity = Validity::from(validity_buffer.freeze());
165
166 unsafe {
169 VarBinArray::new_unchecked(
170 PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
171 new_data.freeze(),
172 dtype,
173 array_validity,
174 )
175 }
176}
177
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability};

    use crate::Array;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Checks how the nullability of the indices affects the dtype of the take result.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Indices collected from a plain range: the result is expected to stay non-nullable.
        let range_indices: PrimitiveArray = (0..1).collect();
        let taken = take(strings.as_ref(), range_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        // Indices built from options: the result is expected to be nullable, even though no
        // index is actually `None`.
        let optional_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = take(strings.as_ref(), optional_indices.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite against several `VarBinArray` shapes:
    /// non-nullable strings, nullable strings with nulls, binary data, and a single element.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] input: VarBinArray) {
        test_take_conformance(input.as_ref());
    }
}