vortex_array/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_mask::{AllOr, Mask};
10
11use super::{ComputeFnVTable, InvocationArgs, Output, cast};
12use crate::builders::{ArrayBuilder, VarBinViewBuilder, builder_with_capacity};
13use crate::compute::{ComputeFn, Kernel};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef};
16
17/// Performs element-wise conditional selection between two arrays based on a mask.
18///
19/// Returns a new array where `result[i] = if_true[i]` when `mask[i]` is true,
20/// otherwise `result[i] = if_false[i]`.
21pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
22    ZIP_FN
23        .invoke(&InvocationArgs {
24            inputs: &[if_true.into(), if_false.into(), mask.into()],
25            options: &(),
26        })?
27        .unwrap_array()
28}
29
30pub static ZIP_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
31    let compute = ComputeFn::new("zip".into(), ArcRef::new_ref(&Zip));
32    for kernel in inventory::iter::<ZipKernelRef> {
33        compute.register_kernel(kernel.0.clone());
34    }
35    compute
36});
37
38pub(crate) fn warm_up_vtable() -> usize {
39    ZIP_FN.kernels().len()
40}
41
42struct Zip;
43
44impl ComputeFnVTable for Zip {
45    fn invoke(
46        &self,
47        args: &InvocationArgs,
48        kernels: &[ArcRef<dyn Kernel>],
49    ) -> VortexResult<Output> {
50        let ZipArgs {
51            if_true,
52            if_false,
53            mask,
54        } = ZipArgs::try_from(args)?;
55
56        if mask.all_true() {
57            return Ok(cast(if_true, &zip_return_dtype(if_true, if_false))?.into());
58        }
59
60        if mask.all_false() {
61            return Ok(cast(if_false, &zip_return_dtype(if_true, if_false))?.into());
62        }
63
64        // check if if_true supports zip directly
65        for kernel in kernels {
66            if let Some(output) = kernel.invoke(args)? {
67                return Ok(output);
68            }
69        }
70
71        if let Some(output) = if_true.invoke(&ZIP_FN, args)? {
72            return Ok(output);
73        }
74
75        // TODO(os): add invert_mask opt and check if if_false has a kernel like:
76        //           kernel.invoke(Args(if_false, if_true, mask, invert_mask = true))
77
78        if !if_true.is_canonical() || !if_false.is_canonical() {
79            return zip(
80                if_true.to_canonical().as_ref(),
81                if_false.to_canonical().as_ref(),
82                mask,
83            )
84            .map(Into::into);
85        }
86
87        Ok(zip_impl(
88            if_true.to_canonical().as_ref(),
89            if_false.to_canonical().as_ref(),
90            mask,
91        )?
92        .into())
93    }
94
95    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
96        let ZipArgs {
97            if_true, if_false, ..
98        } = ZipArgs::try_from(args)?;
99
100        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
101            vortex_bail!("input arrays to zip must have the same dtype");
102        }
103        Ok(zip_return_dtype(if_true, if_false))
104    }
105
106    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
107        let ZipArgs { if_true, mask, .. } = ZipArgs::try_from(args)?;
108        // ComputeFn::invoke asserts if_true.len() == if_false.len(), because zip is elementwise
109        if if_true.len() != mask.len() {
110            vortex_bail!("input arrays must have the same length as the mask");
111        }
112        Ok(if_true.len())
113    }
114
115    fn is_elementwise(&self) -> bool {
116        true
117    }
118}
119
120struct ZipArgs<'a> {
121    if_true: &'a dyn Array,
122    if_false: &'a dyn Array,
123    mask: &'a Mask,
124}
125
126impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
127    type Error = VortexError;
128
129    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
130        if value.inputs.len() != 3 {
131            vortex_bail!("Expected 3 inputs for zip, found {}", value.inputs.len());
132        }
133        let if_true = value.inputs[0]
134            .array()
135            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
136
137        let if_false = value.inputs[1]
138            .array()
139            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
140
141        let mask = value.inputs[2]
142            .mask()
143            .ok_or_else(|| vortex_err!("Expected input 2 to be a mask"))?;
144
145        Ok(Self {
146            if_true,
147            if_false,
148            mask,
149        })
150    }
151}
152
153pub trait ZipKernel: VTable {
154    fn zip(
155        &self,
156        if_true: &Self::Array,
157        if_false: &dyn Array,
158        mask: &Mask,
159    ) -> VortexResult<Option<ArrayRef>>;
160}
161
162pub struct ZipKernelRef(pub ArcRef<dyn Kernel>);
163inventory::collect!(ZipKernelRef);
164
165#[derive(Debug)]
166pub struct ZipKernelAdapter<V: VTable>(pub V);
167
168impl<V: VTable + ZipKernel> ZipKernelAdapter<V> {
169    pub const fn lift(&'static self) -> ZipKernelRef {
170        ZipKernelRef(ArcRef::new_ref(self))
171    }
172}
173
174impl<V: VTable + ZipKernel> Kernel for ZipKernelAdapter<V> {
175    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
176        let ZipArgs {
177            if_true,
178            if_false,
179            mask,
180        } = ZipArgs::try_from(args)?;
181        let Some(if_true) = if_true.as_opt::<V>() else {
182            return Ok(None);
183        };
184        Ok(V::zip(&self.0, if_true, if_false, mask)?.map(Into::into))
185    }
186}
187
188pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
189    if_true
190        .dtype()
191        .union_nullability(if_false.dtype().nullability())
192}
193
194fn zip_impl(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
195    assert_eq!(
196        if_true.len(),
197        if_false.len(),
198        "ComputeFn::invoke checks that arrays have the same size"
199    );
200
201    let return_type = zip_return_dtype(if_true, if_false);
202    let capacity = if_true.len();
203
204    let builder = match return_type {
205        // TODO(blaginin): once https://github.com/vortex-data/vortex/pull/4695 is merged, we can kill
206        //  these two special cases, but before that we need to manually use deduplicated buffers.
207        //  Otherwise, the same buffer will be appended multiple times causing fragmentation.
208        DType::Utf8(n) => Box::new(VarBinViewBuilder::with_buffer_deduplication(
209            DType::Utf8(n),
210            capacity,
211        )),
212        DType::Binary(n) => Box::new(VarBinViewBuilder::with_buffer_deduplication(
213            DType::Binary(n),
214            capacity,
215        )),
216        _ => builder_with_capacity(&return_type, if_true.len()),
217    };
218
219    zip_impl_with_builder(if_true, if_false, mask, builder)
220}
221
222pub(crate) fn zip_impl_with_builder(
223    if_true: &dyn Array,
224    if_false: &dyn Array,
225    mask: &Mask,
226    mut builder: Box<dyn ArrayBuilder>,
227) -> VortexResult<ArrayRef> {
228    match mask.slices() {
229        AllOr::All => Ok(if_true.to_array()),
230        AllOr::None => Ok(if_false.to_array()),
231        AllOr::Some(slices) => {
232            for (start, end) in slices {
233                builder.extend_from_array(&if_false.slice(builder.len()..*start));
234                builder.extend_from_array(&if_true.slice(*start..*end));
235            }
236            if builder.len() < if_false.len() {
237                builder.extend_from_array(&if_false.slice(builder.len()..if_false.len()));
238            }
239            Ok(builder.finish())
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, PrimitiveArray, StructArray, VarBinViewVTable};
    use crate::arrow::IntoArrowArray;
    use crate::builders::{ArrayBuilder, BufferGrowthStrategy};
    use crate::compute::zip;
    use crate::compute::zip::VarBinViewBuilder;
    use crate::{Array, IntoArray, ToCanonical};

    /// A mixed mask takes each element from the input selected by the mask.
    #[test]
    fn test_zip_basic() {
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();
        let mask = Mask::from_iter([true, false, false, true, false]);

        let actual = zip(&if_true, &if_false, &mask).unwrap().to_primitive();
        let expected = buffer![10, 2, 3, 40, 5].into_array().to_primitive();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice::<i32>());
    }

    /// An all-true mask yields the `if_true` values, with the combined nullability.
    #[test]
    fn test_zip_all_true() {
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
        let mask = Mask::new_true(4);

        let result = zip(&if_true, &if_false, &mask).unwrap();

        let actual_values = result.to_primitive();
        let expected_values = if_true.to_primitive();
        assert_eq!(
            actual_values.as_slice::<i32>(),
            expected_values.as_slice::<i32>()
        );

        // result must be nullable even if_true was not
        assert_eq!(result.dtype(), if_false.dtype());
    }

    /// Inputs whose lengths disagree must not produce a result.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();
        let mask = Mask::new_false(4);

        zip(&if_true, &if_false, &mask).unwrap();
    }

    /// Zipping two constant string arrays must not duplicate their data buffers.
    #[test]
    fn test_fragmentation() {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .to_array();

        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .to_array();

        // Alternate on every element to provoke fragmentation:
        // even indices come from const1, odd indices from const2.
        let mask = Mask::from_indices(len, (0..len).step_by(2).collect());

        let flat_result = zip(&const1, &const2, &mask).unwrap();

        insta::assert_snapshot!(flat_result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          buffer (align=1): 29 B (1.75%)
          buffer (align=1): 28 B (1.69%)
          buffer (align=16): 1.60 kB (96.56%)
        ");

        // Same scenario, with the strings nested inside a struct.
        let wrapped1 = StructArray::try_from_iter([("nested", const1)])
            .unwrap()
            .to_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])
            .unwrap()
            .to_array();

        let nested_result = zip(&wrapped1, &wrapped2, &mask).unwrap();
        insta::assert_snapshot!(nested_result.display_tree(), @r"
        root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%)
            metadata: EmptyMetadata
            buffer (align=1): 29 B (1.75%)
            buffer (align=1): 28 B (1.69%)
            buffer (align=16): 1.60 kB (96.56%)
        ");
    }

    /// Zipping two string arrays keeps the buffer count down and agrees with Arrow's `zip`.
    #[test]
    fn test_varbinview_zip() {
        // Builds a 200-element array alternating between a short and a long string.
        let build_strings = |short: &str, long: &str| {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value(short);
                builder.append_value(long);
            }
            builder.finish()
        };

        let if_true = build_strings(
            "Hello",
            "Hello this is a long string that won't be inlined.",
        );
        let if_false = build_strings(
            "Hello2",
            "Hello2 this is a long string that won't be inlined.",
        );

        // [1,2,4,5,7,8,..]
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());

        let zipped = zip(&if_true, &if_false, &mask).unwrap();
        let zipped = zipped.as_opt::<VarBinViewVTable>().unwrap();
        assert_eq!(zipped.nbuffers(), 2);

        // The result must match what Arrow's own zip produces.
        let arrow_mask = mask.into_array().into_arrow_preferred().unwrap();
        let arrow_if_true = if_true.into_arrow_preferred().unwrap();
        let arrow_if_false = if_false.into_arrow_preferred().unwrap();
        let expected = arrow_zip(arrow_mask.as_boolean(), &arrow_if_true, &arrow_if_false).unwrap();

        let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}