1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::VortexError;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_mask::AllOr;
13use vortex_mask::Mask;
14
15use super::ComputeFnVTable;
16use super::InvocationArgs;
17use super::Output;
18use super::cast;
19use crate::Array;
20use crate::ArrayRef;
21use crate::builders::ArrayBuilder;
22use crate::builders::builder_with_capacity;
23use crate::compute::ComputeFn;
24use crate::compute::Kernel;
25use crate::vtable::VTable;
26
27pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
32 ZIP_FN
33 .invoke(&InvocationArgs {
34 inputs: &[if_true.into(), if_false.into(), mask.into()],
35 options: &(),
36 })?
37 .unwrap_array()
38}
39
40pub static ZIP_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
41 let compute = ComputeFn::new("zip".into(), ArcRef::new_ref(&Zip));
42 for kernel in inventory::iter::<ZipKernelRef> {
43 compute.register_kernel(kernel.0.clone());
44 }
45 compute
46});
47
48pub(crate) fn warm_up_vtable() -> usize {
49 ZIP_FN.kernels().len()
50}
51
52struct Zip;
53
54impl ComputeFnVTable for Zip {
55 fn invoke(
56 &self,
57 args: &InvocationArgs,
58 kernels: &[ArcRef<dyn Kernel>],
59 ) -> VortexResult<Output> {
60 let ZipArgs {
61 if_true,
62 if_false,
63 mask,
64 } = ZipArgs::try_from(args)?;
65
66 if mask.all_true() {
67 return Ok(cast(if_true, &zip_return_dtype(if_true, if_false))?.into());
68 }
69
70 if mask.all_false() {
71 return Ok(cast(if_false, &zip_return_dtype(if_true, if_false))?.into());
72 }
73
74 for kernel in kernels {
76 if let Some(output) = kernel.invoke(args)? {
77 return Ok(output);
78 }
79 }
80
81 if let Some(output) = if_true.invoke(&ZIP_FN, args)? {
82 return Ok(output);
83 }
84
85 if !if_true.is_canonical() || !if_false.is_canonical() {
89 return zip(
90 if_true.to_canonical().as_ref(),
91 if_false.to_canonical().as_ref(),
92 mask,
93 )
94 .map(Into::into);
95 }
96
97 Ok(zip_impl(
98 if_true.to_canonical().as_ref(),
99 if_false.to_canonical().as_ref(),
100 mask,
101 )?
102 .into())
103 }
104
105 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
106 let ZipArgs {
107 if_true, if_false, ..
108 } = ZipArgs::try_from(args)?;
109
110 if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
111 vortex_bail!("input arrays to zip must have the same dtype");
112 }
113 Ok(zip_return_dtype(if_true, if_false))
114 }
115
116 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
117 let ZipArgs { if_true, mask, .. } = ZipArgs::try_from(args)?;
118 if if_true.len() != mask.len() {
120 vortex_bail!("input arrays must have the same length as the mask");
121 }
122 Ok(if_true.len())
123 }
124
125 fn is_elementwise(&self) -> bool {
126 true
127 }
128}
129
130struct ZipArgs<'a> {
131 if_true: &'a dyn Array,
132 if_false: &'a dyn Array,
133 mask: &'a Mask,
134}
135
136impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
137 type Error = VortexError;
138
139 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
140 if value.inputs.len() != 3 {
141 vortex_bail!("Expected 3 inputs for zip, found {}", value.inputs.len());
142 }
143 let if_true = value.inputs[0]
144 .array()
145 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
146
147 let if_false = value.inputs[1]
148 .array()
149 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
150
151 let mask = value.inputs[2]
152 .mask()
153 .ok_or_else(|| vortex_err!("Expected input 2 to be a mask"))?;
154
155 Ok(Self {
156 if_true,
157 if_false,
158 mask,
159 })
160 }
161}
162
/// Encoding-specific implementation of the `zip` compute function.
///
/// Wrapped by [`ZipKernelAdapter`], it is dispatched when `if_true` has this encoding.
pub trait ZipKernel: VTable {
    /// Zips `if_true` (already downcast to this encoding) with `if_false` under `mask`.
    ///
    /// Rows come from `if_true` where `mask` is set and from `if_false` elsewhere.
    /// Return `Ok(None)` to decline; the `zip` dispatcher then moves on to its next strategy.
    fn zip(
        &self,
        if_true: &Self::Array,
        if_false: &dyn Array,
        mask: &Mask,
    ) -> VortexResult<Option<ArrayRef>>;
}
171
/// A type-erased zip kernel.
///
/// Instances are collected through `inventory` and registered on [`ZIP_FN`] when it is first
/// initialised.
pub struct ZipKernelRef(pub ArcRef<dyn Kernel>);
inventory::collect!(ZipKernelRef);
174
/// Adapts an encoding's [`ZipKernel`] implementation into a generic [`Kernel`].
#[derive(Debug)]
pub struct ZipKernelAdapter<V: VTable>(pub V);

impl<V: VTable + ZipKernel> ZipKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`ZipKernelRef`] by reference.
    ///
    /// Being a `const fn`, it can be evaluated in const/static contexts.
    pub const fn lift(&'static self) -> ZipKernelRef {
        ZipKernelRef(ArcRef::new_ref(self))
    }
}
183
184impl<V: VTable + ZipKernel> Kernel for ZipKernelAdapter<V> {
185 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
186 let ZipArgs {
187 if_true,
188 if_false,
189 mask,
190 } = ZipArgs::try_from(args)?;
191 let Some(if_true) = if_true.as_opt::<V>() else {
192 return Ok(None);
193 };
194 Ok(V::zip(&self.0, if_true, if_false, mask)?.map(Into::into))
195 }
196}
197
198pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
199 if_true
200 .dtype()
201 .union_nullability(if_false.dtype().nullability())
202}
203
204fn zip_impl(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
205 assert_eq!(
206 if_true.len(),
207 if_false.len(),
208 "ComputeFn::invoke checks that arrays have the same size"
209 );
210
211 let return_type = zip_return_dtype(if_true, if_false);
212 zip_impl_with_builder(
213 if_true,
214 if_false,
215 mask,
216 builder_with_capacity(&return_type, if_true.len()),
217 )
218}
219
220fn zip_impl_with_builder(
221 if_true: &dyn Array,
222 if_false: &dyn Array,
223 mask: &Mask,
224 mut builder: Box<dyn ArrayBuilder>,
225) -> VortexResult<ArrayRef> {
226 match mask.slices() {
227 AllOr::All => Ok(if_true.to_array()),
228 AllOr::None => Ok(if_false.to_array()),
229 AllOr::Some(slices) => {
230 for (start, end) in slices {
231 builder.extend_from_array(&if_false.slice(builder.len()..*start));
232 builder.extend_from_array(&if_true.slice(*start..*end));
233 }
234 if builder.len() < if_false.len() {
235 builder.extend_from_array(&if_false.slice(builder.len()..if_false.len()));
236 }
237 Ok(builder.finish())
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewVTable;
    use crate::arrow::IntoArrowArray;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::compute::zip;

    /// Mixed mask: positions 0 and 3 come from `if_true`, all others from `if_false`.
    #[test]
    fn test_zip_basic() {
        let mask = Mask::from_iter([true, false, false, true, false]);
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();

        let result = zip(&if_true, &if_false, &mask).unwrap();
        let expected = buffer![10, 2, 3, 40, 5].into_array();

        assert_eq!(
            result.to_primitive().as_slice::<i32>(),
            expected.to_primitive().as_slice::<i32>()
        );
    }

    /// An all-true mask yields `if_true`'s values, but the result dtype is still widened to
    /// the nullable dtype of `if_false`.
    #[test]
    fn test_zip_all_true() {
        let mask = Mask::new_true(4);
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let result = zip(&if_true, &if_false, &mask).unwrap();

        assert_eq!(
            result.to_primitive().as_slice::<i32>(),
            if_true.to_primitive().as_slice::<i32>()
        );

        // Nullability is unioned across both inputs even though no `if_false` row was chosen.
        assert_eq!(result.dtype(), if_false.dtype())
    }

    /// Inputs whose lengths disagree are rejected; `unwrap` turns the error into a panic.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let mask = Mask::new_false(4);
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();

        zip(&if_true, &if_false, &mask).unwrap();
    }

    /// Zipping two long-string constants must not copy the string bytes once per row.
    #[test]
    fn test_fragmentation() {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .to_array();

        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .to_array();

        // Even rows come from `const1`, odd rows from `const2`.
        let indices: Vec<usize> = (0..len).step_by(2).collect();
        let mask = Mask::from_indices(len, indices);

        let result = zip(&const1, &const2, &mask).unwrap();

        // The two data buffers hold each string exactly once (29 B and 28 B). The
        // 16-byte-aligned buffer holds the 100 views.
        insta::assert_snapshot!(result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          buffer (align=1): 29 B (1.75%)
          buffer (align=1): 28 B (1.69%)
          buffer (align=16): 1.60 kB (96.56%)
        ");

        // Same check with the constants nested one level deep inside a struct.
        let wrapped1 = StructArray::try_from_iter([("nested", const1)])
            .unwrap()
            .to_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])
            .unwrap()
            .to_array();

        let wrapped_result = zip(&wrapped1, &wrapped2, &mask).unwrap();
        insta::assert_snapshot!(wrapped_result.display_tree(), @r"
        root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%)
            metadata: EmptyMetadata
            buffer (align=1): 29 B (1.75%)
            buffer (align=1): 28 B (1.69%)
            buffer (align=16): 1.60 kB (96.56%)
        ");
    }

    /// Zipping two VarBinView arrays keeps the output at two data buffers and matches Arrow's
    /// `zip` kernel value for value.
    #[test]
    fn test_varbinview_zip() {
        // 200 rows alternating an inlinable short string and a long, non-inlined string.
        let if_true = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello");
                builder.append_value("Hello this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        let if_false = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello2");
                builder.append_value("Hello2 this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        // Rows below 100 that are not multiples of 3 come from `if_true`; all others come
        // from `if_false`.
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());

        let zipped = zip(&if_true, &if_false, &mask).unwrap();
        let zipped = zipped.as_opt::<VarBinViewVTable>().unwrap();
        // NOTE(review): presumably one data buffer contributed by each input — confirm.
        assert_eq!(zipped.nbuffers(), 2);

        let expected = arrow_zip(
            // Arrow takes the mask as a `BooleanArray`, followed by the truthy and falsy sides.
            mask.into_array()
                .into_arrow_preferred()
                .unwrap()
                .as_boolean(),
            &if_true.into_arrow_preferred().unwrap(),
            &if_false.into_arrow_preferred().unwrap(),
        )
        .unwrap();

        let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}