1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_mask::{AllOr, Mask};
10
11use super::{ComputeFnVTable, InvocationArgs, Output, cast};
12use crate::builders::{ArrayBuilder, VarBinViewBuilder, builder_with_capacity};
13use crate::compute::{ComputeFn, Kernel};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef};
16
17pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
22 ZIP_FN
23 .invoke(&InvocationArgs {
24 inputs: &[if_true.into(), if_false.into(), mask.into()],
25 options: &(),
26 })?
27 .unwrap_array()
28}
29
30pub static ZIP_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
31 let compute = ComputeFn::new("zip".into(), ArcRef::new_ref(&Zip));
32 for kernel in inventory::iter::<ZipKernelRef> {
33 compute.register_kernel(kernel.0.clone());
34 }
35 compute
36});
37
38pub(crate) fn warm_up_vtable() -> usize {
39 ZIP_FN.kernels().len()
40}
41
42struct Zip;
43
44impl ComputeFnVTable for Zip {
45 fn invoke(
46 &self,
47 args: &InvocationArgs,
48 kernels: &[ArcRef<dyn Kernel>],
49 ) -> VortexResult<Output> {
50 let ZipArgs {
51 if_true,
52 if_false,
53 mask,
54 } = ZipArgs::try_from(args)?;
55
56 if mask.all_true() {
57 return Ok(cast(if_true, &zip_return_dtype(if_true, if_false))?.into());
58 }
59
60 if mask.all_false() {
61 return Ok(cast(if_false, &zip_return_dtype(if_true, if_false))?.into());
62 }
63
64 for kernel in kernels {
66 if let Some(output) = kernel.invoke(args)? {
67 return Ok(output);
68 }
69 }
70
71 if let Some(output) = if_true.invoke(&ZIP_FN, args)? {
72 return Ok(output);
73 }
74
75 if !if_true.is_canonical() || !if_false.is_canonical() {
79 return zip(
80 if_true.to_canonical().as_ref(),
81 if_false.to_canonical().as_ref(),
82 mask,
83 )
84 .map(Into::into);
85 }
86
87 Ok(zip_impl(
88 if_true.to_canonical().as_ref(),
89 if_false.to_canonical().as_ref(),
90 mask,
91 )?
92 .into())
93 }
94
95 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
96 let ZipArgs {
97 if_true, if_false, ..
98 } = ZipArgs::try_from(args)?;
99
100 if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
101 vortex_bail!("input arrays to zip must have the same dtype");
102 }
103 Ok(zip_return_dtype(if_true, if_false))
104 }
105
106 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
107 let ZipArgs { if_true, mask, .. } = ZipArgs::try_from(args)?;
108 if if_true.len() != mask.len() {
110 vortex_bail!("input arrays must have the same length as the mask");
111 }
112 Ok(if_true.len())
113 }
114
115 fn is_elementwise(&self) -> bool {
116 true
117 }
118}
119
120struct ZipArgs<'a> {
121 if_true: &'a dyn Array,
122 if_false: &'a dyn Array,
123 mask: &'a Mask,
124}
125
126impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
127 type Error = VortexError;
128
129 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
130 if value.inputs.len() != 3 {
131 vortex_bail!("Expected 3 inputs for zip, found {}", value.inputs.len());
132 }
133 let if_true = value.inputs[0]
134 .array()
135 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
136
137 let if_false = value.inputs[1]
138 .array()
139 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
140
141 let mask = value.inputs[2]
142 .mask()
143 .ok_or_else(|| vortex_err!("Expected input 2 to be a mask"))?;
144
145 Ok(Self {
146 if_true,
147 if_false,
148 mask,
149 })
150 }
151}
152
153pub trait ZipKernel: VTable {
154 fn zip(
155 &self,
156 if_true: &Self::Array,
157 if_false: &dyn Array,
158 mask: &Mask,
159 ) -> VortexResult<Option<ArrayRef>>;
160}
161
162pub struct ZipKernelRef(pub ArcRef<dyn Kernel>);
163inventory::collect!(ZipKernelRef);
164
165#[derive(Debug)]
166pub struct ZipKernelAdapter<V: VTable>(pub V);
167
168impl<V: VTable + ZipKernel> ZipKernelAdapter<V> {
169 pub const fn lift(&'static self) -> ZipKernelRef {
170 ZipKernelRef(ArcRef::new_ref(self))
171 }
172}
173
174impl<V: VTable + ZipKernel> Kernel for ZipKernelAdapter<V> {
175 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
176 let ZipArgs {
177 if_true,
178 if_false,
179 mask,
180 } = ZipArgs::try_from(args)?;
181 let Some(if_true) = if_true.as_opt::<V>() else {
182 return Ok(None);
183 };
184 Ok(V::zip(&self.0, if_true, if_false, mask)?.map(Into::into))
185 }
186}
187
188pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
189 if_true
190 .dtype()
191 .union_nullability(if_false.dtype().nullability())
192}
193
194fn zip_impl(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
195 assert_eq!(
196 if_true.len(),
197 if_false.len(),
198 "ComputeFn::invoke checks that arrays have the same size"
199 );
200
201 let return_type = zip_return_dtype(if_true, if_false);
202 let capacity = if_true.len();
203
204 let builder = match return_type {
205 DType::Utf8(n) => Box::new(VarBinViewBuilder::with_buffer_deduplication(
209 DType::Utf8(n),
210 capacity,
211 )),
212 DType::Binary(n) => Box::new(VarBinViewBuilder::with_buffer_deduplication(
213 DType::Binary(n),
214 capacity,
215 )),
216 _ => builder_with_capacity(&return_type, if_true.len()),
217 };
218
219 zip_impl_with_builder(if_true, if_false, mask, builder)
220}
221
222pub(crate) fn zip_impl_with_builder(
223 if_true: &dyn Array,
224 if_false: &dyn Array,
225 mask: &Mask,
226 mut builder: Box<dyn ArrayBuilder>,
227) -> VortexResult<ArrayRef> {
228 match mask.slices() {
229 AllOr::All => Ok(if_true.to_array()),
230 AllOr::None => Ok(if_false.to_array()),
231 AllOr::Some(slices) => {
232 for (start, end) in slices {
233 builder.extend_from_array(&if_false.slice(builder.len()..*start));
234 builder.extend_from_array(&if_true.slice(*start..*end));
235 }
236 if builder.len() < if_false.len() {
237 builder.extend_from_array(&if_false.slice(builder.len()..if_false.len()));
238 }
239 Ok(builder.finish())
240 }
241 }
242}
243
244#[cfg(test)]
245mod tests {
246 use arrow_array::cast::AsArray;
247 use arrow_select::zip::zip as arrow_zip;
248 use vortex_buffer::buffer;
249 use vortex_dtype::{DType, Nullability};
250 use vortex_mask::Mask;
251 use vortex_scalar::Scalar;
252
253 use crate::arrays::{ConstantArray, PrimitiveArray, StructArray, VarBinViewVTable};
254 use crate::arrow::IntoArrowArray;
255 use crate::builders::{ArrayBuilder, BufferGrowthStrategy};
256 use crate::compute::zip;
257 use crate::compute::zip::VarBinViewBuilder;
258 use crate::{Array, IntoArray, ToCanonical};
259
260 #[test]
261 fn test_zip_basic() {
262 let mask = Mask::from_iter([true, false, false, true, false]);
263 let if_true = buffer![10, 20, 30, 40, 50].into_array();
264 let if_false = buffer![1, 2, 3, 4, 5].into_array();
265
266 let result = zip(&if_true, &if_false, &mask).unwrap();
267 let expected = buffer![10, 2, 3, 40, 5].into_array();
268
269 assert_eq!(
270 result.to_primitive().as_slice::<i32>(),
271 expected.to_primitive().as_slice::<i32>()
272 );
273 }
274
275 #[test]
276 fn test_zip_all_true() {
277 let mask = Mask::new_true(4);
278 let if_true = buffer![10, 20, 30, 40].into_array();
279 let if_false =
280 PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
281
282 let result = zip(&if_true, &if_false, &mask).unwrap();
283
284 assert_eq!(
285 result.to_primitive().as_slice::<i32>(),
286 if_true.to_primitive().as_slice::<i32>()
287 );
288
289 assert_eq!(result.dtype(), if_false.dtype())
291 }
292
293 #[test]
294 #[should_panic]
295 fn test_invalid_lengths() {
296 let mask = Mask::new_false(4);
297 let if_true = buffer![10, 20, 30].into_array();
298 let if_false = buffer![1, 2, 3, 4].into_array();
299
300 zip(&if_true, &if_false, &mask).unwrap();
301 }
302
 // Zips two constant UTF-8 arrays under an alternating mask and snapshots the
 // display tree of the result, first directly and then with each input wrapped
 // in a single-field struct.
303 #[test]
304 fn test_fragmentation() {
305 let len = 100;
306
307 let const1 = ConstantArray::new(
308 Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
309 len,
310 )
311 .to_array();
312
313 let const2 = ConstantArray::new(
314 Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
315 len,
316 )
317 .to_array();
318
 // Every even index is taken from `const1`; odd indices come from `const2`.
319 let indices: Vec<usize> = (0..len).step_by(2).collect();
322 let mask = Mask::from_indices(len, indices);
323
324 let result = zip(&const1, &const2, &mask).unwrap();
325
 // The 29 B and 28 B buffers match the byte lengths of the two strings, i.e.
 // each string is stored once rather than once per row.
 // NOTE(review): the third (align=16) buffer is presumably the views — confirm.
326 insta::assert_snapshot!(result.display_tree(), @r"
327 root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%)
328 metadata: EmptyMetadata
329 buffer (align=1): 29 B (1.75%)
330 buffer (align=1): 28 B (1.69%)
331 buffer (align=16): 1.60 kB (96.56%)
332 ");
333
 // Same inputs nested inside a struct: the snapshot below expects the same
 // buffer sizes for the `nested` child as for the direct zip above.
334 let wrapped1 = StructArray::try_from_iter([("nested", const1)])
336 .unwrap()
337 .to_array();
338 let wrapped2 = StructArray::try_from_iter([("nested", const2)])
339 .unwrap()
340 .to_array();
341
342 let wrapped_result = zip(&wrapped1, &wrapped2, &mask).unwrap();
343 insta::assert_snapshot!(wrapped_result.display_tree(), @r"
344 root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%)
345 metadata: EmptyMetadata
346 nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%)
347 metadata: EmptyMetadata
348 buffer (align=1): 29 B (1.75%)
349 buffer (align=1): 28 B (1.69%)
350 buffer (align=16): 1.60 kB (96.56%)
351 ");
352 }
353
354 #[test]
355 fn test_varbinview_zip() {
356 let if_true = {
357 let mut builder = VarBinViewBuilder::new(
358 DType::Utf8(Nullability::NonNullable),
359 10,
360 Default::default(),
361 BufferGrowthStrategy::fixed(64 * 1024),
362 0.0,
363 );
364 for _ in 0..100 {
365 builder.append_value("Hello");
366 builder.append_value("Hello this is a long string that won't be inlined.");
367 }
368 builder.finish()
369 };
370
371 let if_false = {
372 let mut builder = VarBinViewBuilder::new(
373 DType::Utf8(Nullability::NonNullable),
374 10,
375 Default::default(),
376 BufferGrowthStrategy::fixed(64 * 1024),
377 0.0,
378 );
379 for _ in 0..100 {
380 builder.append_value("Hello2");
381 builder.append_value("Hello2 this is a long string that won't be inlined.");
382 }
383 builder.finish()
384 };
385
386 let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());
388
389 let zipped = zip(&if_true, &if_false, &mask).unwrap();
390 let zipped = zipped.as_opt::<VarBinViewVTable>().unwrap();
391 assert_eq!(zipped.nbuffers(), 2);
392
393 let expected = arrow_zip(
395 mask.into_array()
396 .into_arrow_preferred()
397 .unwrap()
398 .as_boolean(),
399 &if_true.into_arrow_preferred().unwrap(),
400 &if_false.into_arrow_preferred().unwrap(),
401 )
402 .unwrap();
403
404 let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
405 assert_eq!(actual.as_ref(), expected.as_ref());
406 }
407}