vortex_array/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_mask::{AllOr, Mask};
10
11use super::{ComputeFnVTable, InvocationArgs, Output, cast};
12use crate::builders::{ArrayBuilder, builder_with_capacity};
13use crate::compute::{ComputeFn, Kernel};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef};
16
17/// Performs element-wise conditional selection between two arrays based on a mask.
18///
19/// Returns a new array where `result[i] = if_true[i]` when `mask[i]` is true,
20/// otherwise `result[i] = if_false[i]`.
21pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
22    ZIP_FN
23        .invoke(&InvocationArgs {
24            inputs: &[if_true.into(), if_false.into(), mask.into()],
25            options: &(),
26        })?
27        .unwrap_array()
28}
29
30pub static ZIP_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
31    let compute = ComputeFn::new("zip".into(), ArcRef::new_ref(&Zip));
32    for kernel in inventory::iter::<ZipKernelRef> {
33        compute.register_kernel(kernel.0.clone());
34    }
35    compute
36});
37
38struct Zip;
39
40impl ComputeFnVTable for Zip {
41    fn invoke(
42        &self,
43        args: &InvocationArgs,
44        kernels: &[ArcRef<dyn Kernel>],
45    ) -> VortexResult<Output> {
46        let ZipArgs {
47            if_true,
48            if_false,
49            mask,
50        } = ZipArgs::try_from(args)?;
51
52        if mask.all_true() {
53            return Ok(cast(if_true, &zip_return_dtype(if_true, if_false))?.into());
54        }
55
56        if mask.all_false() {
57            return Ok(cast(if_false, &zip_return_dtype(if_true, if_false))?.into());
58        }
59
60        // check if if_true supports zip directly
61        for kernel in kernels {
62            if let Some(output) = kernel.invoke(args)? {
63                return Ok(output);
64            }
65        }
66
67        if let Some(output) = if_true.invoke(&ZIP_FN, args)? {
68            return Ok(output);
69        }
70
71        // TODO(os): add invert_mask opt and check if if_false has a kernel like:
72        //           kernel.invoke(Args(if_false, if_true, mask, invert_mask = true))
73
74        Ok(zip_impl(
75            if_true.to_canonical()?.as_ref(),
76            if_false.to_canonical()?.as_ref(),
77            mask,
78        )?
79        .into())
80    }
81
82    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
83        let ZipArgs {
84            if_true, if_false, ..
85        } = ZipArgs::try_from(args)?;
86
87        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
88            vortex_bail!("input arrays to zip must have the same dtype");
89        }
90        Ok(zip_return_dtype(if_true, if_false))
91    }
92
93    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
94        let ZipArgs { if_true, mask, .. } = ZipArgs::try_from(args)?;
95        // ComputeFn::invoke asserts if_true.len() == if_false.len(), because zip is elementwise
96        if if_true.len() != mask.len() {
97            vortex_bail!("input arrays must have the same length as the mask");
98        }
99        Ok(if_true.len())
100    }
101
102    fn is_elementwise(&self) -> bool {
103        true
104    }
105}
106
107struct ZipArgs<'a> {
108    if_true: &'a dyn Array,
109    if_false: &'a dyn Array,
110    mask: &'a Mask,
111}
112
113impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
114    type Error = VortexError;
115
116    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
117        if value.inputs.len() != 3 {
118            vortex_bail!("Expected 3 inputs for zip, found {}", value.inputs.len());
119        }
120        let if_true = value.inputs[0]
121            .array()
122            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
123
124        let if_false = value.inputs[1]
125            .array()
126            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
127
128        let mask = value.inputs[2]
129            .mask()
130            .ok_or_else(|| vortex_err!("Expected input 2 to be a mask"))?;
131
132        Ok(Self {
133            if_true,
134            if_false,
135            mask,
136        })
137    }
138}
139
140pub trait ZipKernel: VTable {
141    fn zip(
142        &self,
143        if_true: &Self::Array,
144        if_false: &dyn Array,
145        mask: &Mask,
146    ) -> VortexResult<Option<ArrayRef>>;
147}
148
149pub struct ZipKernelRef(pub ArcRef<dyn Kernel>);
150inventory::collect!(ZipKernelRef);
151
152#[derive(Debug)]
153pub struct ZipKernelAdapter<V: VTable>(pub V);
154
155impl<V: VTable + ZipKernel> ZipKernelAdapter<V> {
156    pub const fn lift(&'static self) -> ZipKernelRef {
157        ZipKernelRef(ArcRef::new_ref(self))
158    }
159}
160
161impl<V: VTable + ZipKernel> Kernel for ZipKernelAdapter<V> {
162    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
163        let ZipArgs {
164            if_true,
165            if_false,
166            mask,
167        } = ZipArgs::try_from(args)?;
168        let Some(if_true) = if_true.as_opt::<V>() else {
169            return Ok(None);
170        };
171        Ok(V::zip(&self.0, if_true, if_false, mask)?.map(Into::into))
172    }
173}
174
175pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
176    if_true
177        .dtype()
178        .union_nullability(if_false.dtype().nullability())
179}
180
181fn zip_impl(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
182    // if_true.len() == if_false.len() from ComputeFn::invoke
183    let builder = builder_with_capacity(&zip_return_dtype(if_true, if_false), if_true.len());
184    zip_impl_with_builder(if_true, if_false, mask, builder)
185}
186
187pub(crate) fn zip_impl_with_builder(
188    if_true: &dyn Array,
189    if_false: &dyn Array,
190    mask: &Mask,
191    mut builder: Box<dyn ArrayBuilder>,
192) -> VortexResult<ArrayRef> {
193    match mask.slices() {
194        AllOr::All => Ok(if_true.to_array()),
195        AllOr::None => Ok(if_false.to_array()),
196        AllOr::Some(slices) => {
197            for (start, end) in slices {
198                builder.extend_from_array(&if_false.slice(builder.len(), *start))?;
199                builder.extend_from_array(&if_true.slice(*start, *end))?;
200            }
201            if builder.len() < if_false.len() {
202                builder.extend_from_array(&if_false.slice(builder.len(), if_false.len()))?;
203            }
204            Ok(builder.finish())
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, PrimitiveArray};
    use vortex_array::compute::zip;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_mask::Mask;

    #[test]
    fn test_zip_basic() {
        let mask =
            Mask::try_from(&BoolArray::from_iter([true, false, false, true, false])).unwrap();
        let if_true = PrimitiveArray::from_iter([10, 20, 30, 40, 50]).into_array();
        let if_false = PrimitiveArray::from_iter([1, 2, 3, 4, 5]).into_array();

        let result = zip(&if_true, &if_false, &mask).unwrap();
        let expected = PrimitiveArray::from_iter([10, 2, 3, 40, 5]);

        assert_eq!(
            result.to_primitive().unwrap().as_slice::<i32>(),
            expected.as_slice::<i32>()
        );
    }

    #[test]
    fn test_zip_all_true() {
        let mask = Mask::new_true(4);
        let if_true = PrimitiveArray::from_iter([10, 20, 30, 40]).into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let result = zip(&if_true, &if_false, &mask).unwrap();

        assert_eq!(
            result.to_primitive().unwrap().as_slice::<i32>(),
            if_true.to_primitive().unwrap().as_slice::<i32>()
        );

        // The result must be nullable even though `if_true` is not.
        assert_eq!(result.dtype(), if_false.dtype())
    }

    #[test]
    fn test_zip_all_false() {
        let mask = Mask::new_false(4);
        let if_true =
            PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40)]).into_array();
        let if_false = PrimitiveArray::from_iter([1, 2, 3, 4]).into_array();

        let result = zip(&if_true, &if_false, &mask).unwrap();

        assert_eq!(
            result.to_primitive().unwrap().as_slice::<i32>(),
            if_false.to_primitive().unwrap().as_slice::<i32>()
        );

        // The result must be nullable even though `if_false` is not.
        assert_eq!(result.dtype(), if_true.dtype())
    }

    #[test]
    fn test_zip_mismatched_dtypes() {
        let mask = Mask::try_from(&BoolArray::from_iter([true, false, true])).unwrap();
        let if_true = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        let if_false = PrimitiveArray::from_iter([1i64, 2, 3]).into_array();

        assert!(zip(&if_true, &if_false, &mask).is_err());
    }

    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let mask = Mask::new_false(4);
        let if_true = PrimitiveArray::from_iter([10, 20, 30]).into_array();
        let if_false = PrimitiveArray::from_iter([1, 2, 3, 4]).into_array();

        zip(&if_true, &if_false, &mask).unwrap();
    }
}