vortex_array/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::VortexError;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_mask::AllOr;
13use vortex_mask::Mask;
14
15use super::ComputeFnVTable;
16use super::InvocationArgs;
17use super::Output;
18use super::cast;
19use crate::Array;
20use crate::ArrayRef;
21use crate::builders::ArrayBuilder;
22use crate::builders::builder_with_capacity;
23use crate::compute::ComputeFn;
24use crate::compute::Kernel;
25use crate::vtable::VTable;
26
27/// Performs element-wise conditional selection between two arrays based on a mask.
28///
29/// Returns a new array where `result[i] = if_true[i]` when `mask[i]` is true,
30/// otherwise `result[i] = if_false[i]`.
31pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
32    ZIP_FN
33        .invoke(&InvocationArgs {
34            inputs: &[if_true.into(), if_false.into(), mask.into()],
35            options: &(),
36        })?
37        .unwrap_array()
38}
39
40pub static ZIP_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
41    let compute = ComputeFn::new("zip".into(), ArcRef::new_ref(&Zip));
42    for kernel in inventory::iter::<ZipKernelRef> {
43        compute.register_kernel(kernel.0.clone());
44    }
45    compute
46});
47
48pub(crate) fn warm_up_vtable() -> usize {
49    ZIP_FN.kernels().len()
50}
51
52struct Zip;
53
54impl ComputeFnVTable for Zip {
55    fn invoke(
56        &self,
57        args: &InvocationArgs,
58        kernels: &[ArcRef<dyn Kernel>],
59    ) -> VortexResult<Output> {
60        let ZipArgs {
61            if_true,
62            if_false,
63            mask,
64        } = ZipArgs::try_from(args)?;
65
66        if mask.all_true() {
67            return Ok(cast(if_true, &zip_return_dtype(if_true, if_false))?.into());
68        }
69
70        if mask.all_false() {
71            return Ok(cast(if_false, &zip_return_dtype(if_true, if_false))?.into());
72        }
73
74        // check if if_true supports zip directly
75        for kernel in kernels {
76            if let Some(output) = kernel.invoke(args)? {
77                return Ok(output);
78            }
79        }
80
81        if let Some(output) = if_true.invoke(&ZIP_FN, args)? {
82            return Ok(output);
83        }
84
85        // TODO(os): add invert_mask opt and check if if_false has a kernel like:
86        //           kernel.invoke(Args(if_false, if_true, mask, invert_mask = true))
87
88        if !if_true.is_canonical() || !if_false.is_canonical() {
89            return zip(
90                if_true.to_canonical().as_ref(),
91                if_false.to_canonical().as_ref(),
92                mask,
93            )
94            .map(Into::into);
95        }
96
97        Ok(zip_impl(
98            if_true.to_canonical().as_ref(),
99            if_false.to_canonical().as_ref(),
100            mask,
101        )?
102        .into())
103    }
104
105    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
106        let ZipArgs {
107            if_true, if_false, ..
108        } = ZipArgs::try_from(args)?;
109
110        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
111            vortex_bail!("input arrays to zip must have the same dtype");
112        }
113        Ok(zip_return_dtype(if_true, if_false))
114    }
115
116    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
117        let ZipArgs { if_true, mask, .. } = ZipArgs::try_from(args)?;
118        // ComputeFn::invoke asserts if_true.len() == if_false.len(), because zip is elementwise
119        if if_true.len() != mask.len() {
120            vortex_bail!("input arrays must have the same length as the mask");
121        }
122        Ok(if_true.len())
123    }
124
125    fn is_elementwise(&self) -> bool {
126        true
127    }
128}
129
130struct ZipArgs<'a> {
131    if_true: &'a dyn Array,
132    if_false: &'a dyn Array,
133    mask: &'a Mask,
134}
135
136impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
137    type Error = VortexError;
138
139    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
140        if value.inputs.len() != 3 {
141            vortex_bail!("Expected 3 inputs for zip, found {}", value.inputs.len());
142        }
143        let if_true = value.inputs[0]
144            .array()
145            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
146
147        let if_false = value.inputs[1]
148            .array()
149            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
150
151        let mask = value.inputs[2]
152            .mask()
153            .ok_or_else(|| vortex_err!("Expected input 2 to be a mask"))?;
154
155        Ok(Self {
156            if_true,
157            if_false,
158            mask,
159        })
160    }
161}
162
163pub trait ZipKernel: VTable {
164    fn zip(
165        &self,
166        if_true: &Self::Array,
167        if_false: &dyn Array,
168        mask: &Mask,
169    ) -> VortexResult<Option<ArrayRef>>;
170}
171
172pub struct ZipKernelRef(pub ArcRef<dyn Kernel>);
173inventory::collect!(ZipKernelRef);
174
175#[derive(Debug)]
176pub struct ZipKernelAdapter<V: VTable>(pub V);
177
178impl<V: VTable + ZipKernel> ZipKernelAdapter<V> {
179    pub const fn lift(&'static self) -> ZipKernelRef {
180        ZipKernelRef(ArcRef::new_ref(self))
181    }
182}
183
184impl<V: VTable + ZipKernel> Kernel for ZipKernelAdapter<V> {
185    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
186        let ZipArgs {
187            if_true,
188            if_false,
189            mask,
190        } = ZipArgs::try_from(args)?;
191        let Some(if_true) = if_true.as_opt::<V>() else {
192            return Ok(None);
193        };
194        Ok(V::zip(&self.0, if_true, if_false, mask)?.map(Into::into))
195    }
196}
197
198pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
199    if_true
200        .dtype()
201        .union_nullability(if_false.dtype().nullability())
202}
203
204fn zip_impl(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
205    assert_eq!(
206        if_true.len(),
207        if_false.len(),
208        "ComputeFn::invoke checks that arrays have the same size"
209    );
210
211    let return_type = zip_return_dtype(if_true, if_false);
212    zip_impl_with_builder(
213        if_true,
214        if_false,
215        mask,
216        builder_with_capacity(&return_type, if_true.len()),
217    )
218}
219
220fn zip_impl_with_builder(
221    if_true: &dyn Array,
222    if_false: &dyn Array,
223    mask: &Mask,
224    mut builder: Box<dyn ArrayBuilder>,
225) -> VortexResult<ArrayRef> {
226    match mask.slices() {
227        AllOr::All => Ok(if_true.to_array()),
228        AllOr::None => Ok(if_false.to_array()),
229        AllOr::Some(slices) => {
230            for (start, end) in slices {
231                builder.extend_from_array(&if_false.slice(builder.len()..*start));
232                builder.extend_from_array(&if_true.slice(*start..*end));
233            }
234            if builder.len() < if_false.len() {
235                builder.extend_from_array(&if_false.slice(builder.len()..if_false.len()));
236            }
237            Ok(builder.finish())
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewVTable;
    use crate::arrow::IntoArrowArray;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::compute::zip;

    /// A mixed mask picks each element from the matching input.
    #[test]
    fn test_zip_basic() {
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();
        let mask = Mask::from_iter([true, false, false, true, false]);

        let zipped = zip(&if_true, &if_false, &mask).unwrap();

        let actual = zipped.to_primitive();
        let expected = buffer![10, 2, 3, 40, 5].into_array().to_primitive();
        assert_eq!(actual.as_slice::<i32>(), expected.as_slice::<i32>());
    }

    /// An all-true mask yields the values of `if_true` with the nullable dtype of `if_false`.
    #[test]
    fn test_zip_all_true() {
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
        let mask = Mask::new_true(4);

        let zipped = zip(&if_true, &if_false, &mask).unwrap();

        let actual = zipped.to_primitive();
        let expected = if_true.to_primitive();
        assert_eq!(actual.as_slice::<i32>(), expected.as_slice::<i32>());

        // result must be nullable even if_true was not
        assert_eq!(zipped.dtype(), if_false.dtype())
    }

    /// Inputs of different lengths must not be accepted.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();
        let mask = Mask::new_false(4);

        zip(&if_true, &if_false, &mask).unwrap();
    }

    /// A rapidly alternating mask over two constant strings must not multiply buffers.
    #[test]
    fn test_fragmentation() {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .to_array();
        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .to_array();

        // Create a mask that alternates frequently to cause fragmentation
        // Pattern: take from const1 at even indices, const2 at odd indices
        let even_indices: Vec<usize> = (0..len).step_by(2).collect();
        let mask = Mask::from_indices(len, even_indices);

        let flat_result = zip(&const1, &const2, &mask).unwrap();

        insta::assert_snapshot!(flat_result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          buffer (align=1): 29 B (1.75%)
          buffer (align=1): 28 B (1.69%)
          buffer (align=16): 1.60 kB (96.56%)
        ");

        // test wrapped in a struct
        let wrapped1 = StructArray::try_from_iter([("nested", const1)])
            .unwrap()
            .to_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])
            .unwrap()
            .to_array();

        let wrapped_result = zip(&wrapped1, &wrapped2, &mask).unwrap();
        insta::assert_snapshot!(wrapped_result.display_tree(), @r"
        root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%)
            metadata: EmptyMetadata
            buffer (align=1): 29 B (1.75%)
            buffer (align=1): 28 B (1.69%)
            buffer (align=16): 1.60 kB (96.56%)
        ");
    }

    /// Zipping two string-view arrays matches Arrow's `zip` and yields two buffers.
    #[test]
    fn test_varbinview_zip() {
        // Builds 200 strings alternating between a short and a long value.
        let build_views = |short: &str, long: &str| {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value(short);
                builder.append_value(long);
            }
            builder.finish()
        };

        let if_true = build_views(
            "Hello",
            "Hello this is a long string that won't be inlined.",
        );
        let if_false = build_views(
            "Hello2",
            "Hello2 this is a long string that won't be inlined.",
        );

        // Set positions are [1,2,4,5,7,8,..] below 100; positions 100..200 are all unset.
        let set_indices = (0..100).filter(|i| i % 3 != 0).collect();
        let mask = Mask::from_indices(200, set_indices);

        let zipped = zip(&if_true, &if_false, &mask).unwrap();
        let zipped = zipped.as_opt::<VarBinViewVTable>().unwrap();
        assert_eq!(zipped.nbuffers(), 2);

        // assert the result is the same as arrow
        let arrow_mask = mask.into_array().into_arrow_preferred().unwrap();
        let arrow_if_true = if_true.into_arrow_preferred().unwrap();
        let arrow_if_false = if_false.into_arrow_preferred().unwrap();
        let expected = arrow_zip(arrow_mask.as_boolean(), &arrow_if_true, &arrow_if_false).unwrap();

        let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}