vortex_array/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_mask::{AllOr, Mask};
10
11use super::{ComputeFnVTable, InvocationArgs, Output, cast};
12use crate::builders::{ArrayBuilder, builder_with_capacity};
13use crate::compute::{ComputeFn, Kernel};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef};
16
17/// Performs element-wise conditional selection between two arrays based on a mask.
18///
19/// Returns a new array where `result[i] = if_true[i]` when `mask[i]` is true,
20/// otherwise `result[i] = if_false[i]`.
21pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
22    ZIP_FN
23        .invoke(&InvocationArgs {
24            inputs: &[if_true.into(), if_false.into(), mask.into()],
25            options: &(),
26        })?
27        .unwrap_array()
28}
29
30pub static ZIP_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
31    let compute = ComputeFn::new("zip".into(), ArcRef::new_ref(&Zip));
32    for kernel in inventory::iter::<ZipKernelRef> {
33        compute.register_kernel(kernel.0.clone());
34    }
35    compute
36});
37
38pub(crate) fn warm_up_vtable() -> usize {
39    ZIP_FN.kernels().len()
40}
41
42struct Zip;
43
44impl ComputeFnVTable for Zip {
45    fn invoke(
46        &self,
47        args: &InvocationArgs,
48        kernels: &[ArcRef<dyn Kernel>],
49    ) -> VortexResult<Output> {
50        let ZipArgs {
51            if_true,
52            if_false,
53            mask,
54        } = ZipArgs::try_from(args)?;
55
56        if mask.all_true() {
57            return Ok(cast(if_true, &zip_return_dtype(if_true, if_false))?.into());
58        }
59
60        if mask.all_false() {
61            return Ok(cast(if_false, &zip_return_dtype(if_true, if_false))?.into());
62        }
63
64        // check if if_true supports zip directly
65        for kernel in kernels {
66            if let Some(output) = kernel.invoke(args)? {
67                return Ok(output);
68            }
69        }
70
71        if let Some(output) = if_true.invoke(&ZIP_FN, args)? {
72            return Ok(output);
73        }
74
75        // TODO(os): add invert_mask opt and check if if_false has a kernel like:
76        //           kernel.invoke(Args(if_false, if_true, mask, invert_mask = true))
77
78        Ok(zip_impl(
79            if_true.to_canonical().as_ref(),
80            if_false.to_canonical().as_ref(),
81            mask,
82        )?
83        .into())
84    }
85
86    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
87        let ZipArgs {
88            if_true, if_false, ..
89        } = ZipArgs::try_from(args)?;
90
91        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
92            vortex_bail!("input arrays to zip must have the same dtype");
93        }
94        Ok(zip_return_dtype(if_true, if_false))
95    }
96
97    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
98        let ZipArgs { if_true, mask, .. } = ZipArgs::try_from(args)?;
99        // ComputeFn::invoke asserts if_true.len() == if_false.len(), because zip is elementwise
100        if if_true.len() != mask.len() {
101            vortex_bail!("input arrays must have the same length as the mask");
102        }
103        Ok(if_true.len())
104    }
105
106    fn is_elementwise(&self) -> bool {
107        true
108    }
109}
110
111struct ZipArgs<'a> {
112    if_true: &'a dyn Array,
113    if_false: &'a dyn Array,
114    mask: &'a Mask,
115}
116
117impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
118    type Error = VortexError;
119
120    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
121        if value.inputs.len() != 3 {
122            vortex_bail!("Expected 3 inputs for zip, found {}", value.inputs.len());
123        }
124        let if_true = value.inputs[0]
125            .array()
126            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
127
128        let if_false = value.inputs[1]
129            .array()
130            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
131
132        let mask = value.inputs[2]
133            .mask()
134            .ok_or_else(|| vortex_err!("Expected input 2 to be a mask"))?;
135
136        Ok(Self {
137            if_true,
138            if_false,
139            mask,
140        })
141    }
142}
143
/// Encoding-specific implementation of zip for arrays of a particular [`VTable`].
///
/// Implementations are exposed to [`ZIP_FN`] by wrapping the vtable in a [`ZipKernelAdapter`].
pub trait ZipKernel: VTable {
    /// Zips `if_true` (an array of this vtable's encoding) with `if_false` according to `mask`.
    ///
    /// Returning `Ok(None)` signals that this kernel does not handle the given inputs; the
    /// dispatcher then moves on to the next strategy.
    fn zip(
        &self,
        if_true: &Self::Array,
        if_false: &dyn Array,
        mask: &Mask,
    ) -> VortexResult<Option<ArrayRef>>;
}
152
/// Newtype around a type-erased [`Kernel`] so that zip kernels can be collected via `inventory`.
pub struct ZipKernelRef(pub ArcRef<dyn Kernel>);
// Declares the `inventory` collection that `ZIP_FN` iterates when it is first initialized.
inventory::collect!(ZipKernelRef);
155
156#[derive(Debug)]
157pub struct ZipKernelAdapter<V: VTable>(pub V);
158
159impl<V: VTable + ZipKernel> ZipKernelAdapter<V> {
160    pub const fn lift(&'static self) -> ZipKernelRef {
161        ZipKernelRef(ArcRef::new_ref(self))
162    }
163}
164
165impl<V: VTable + ZipKernel> Kernel for ZipKernelAdapter<V> {
166    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
167        let ZipArgs {
168            if_true,
169            if_false,
170            mask,
171        } = ZipArgs::try_from(args)?;
172        let Some(if_true) = if_true.as_opt::<V>() else {
173            return Ok(None);
174        };
175        Ok(V::zip(&self.0, if_true, if_false, mask)?.map(Into::into))
176    }
177}
178
179pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
180    if_true
181        .dtype()
182        .union_nullability(if_false.dtype().nullability())
183}
184
185fn zip_impl(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
186    // if_true.len() == if_false.len() from ComputeFn::invoke
187    let builder = builder_with_capacity(&zip_return_dtype(if_true, if_false), if_true.len());
188    zip_impl_with_builder(if_true, if_false, mask, builder)
189}
190
191pub(crate) fn zip_impl_with_builder(
192    if_true: &dyn Array,
193    if_false: &dyn Array,
194    mask: &Mask,
195    mut builder: Box<dyn ArrayBuilder>,
196) -> VortexResult<ArrayRef> {
197    match mask.slices() {
198        AllOr::All => Ok(if_true.to_array()),
199        AllOr::None => Ok(if_false.to_array()),
200        AllOr::Some(slices) => {
201            for (start, end) in slices {
202                builder.extend_from_array(&if_false.slice(builder.len()..*start));
203                builder.extend_from_array(&if_true.slice(*start..*end));
204            }
205            if builder.len() < if_false.len() {
206                builder.extend_from_array(&if_false.slice(builder.len()..if_false.len()));
207            }
208            Ok(builder.finish())
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::zip;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    /// A mixed mask interleaves values from both inputs.
    #[test]
    fn test_zip_basic() {
        let selector = Mask::from_iter([true, false, false, true, false]);
        let lhs = buffer![10, 20, 30, 40, 50].into_array();
        let rhs = buffer![1, 2, 3, 4, 5].into_array();

        let zipped = zip(&lhs, &rhs, &selector).unwrap();

        let actual = zipped.to_primitive();
        let expected = buffer![10, 2, 3, 40, 5].into_array().to_primitive();
        assert_eq!(actual.as_slice::<i32>(), expected.as_slice::<i32>());
    }

    /// An all-true mask yields the values of `if_true`, but with the unified (nullable) dtype.
    #[test]
    fn test_zip_all_true() {
        let selector = Mask::new_true(4);
        let lhs = buffer![10, 20, 30, 40].into_array();
        let rhs = PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let zipped = zip(&lhs, &rhs, &selector).unwrap();

        let actual = zipped.to_primitive();
        let expected = lhs.to_primitive();
        assert_eq!(actual.as_slice::<i32>(), expected.as_slice::<i32>());

        // result must be nullable even if_true was not
        assert_eq!(zipped.dtype(), rhs.dtype())
    }

    /// Inputs whose lengths disagree must not zip successfully.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let selector = Mask::new_false(4);
        let lhs = buffer![10, 20, 30].into_array();
        let rhs = buffer![1, 2, 3, 4].into_array();

        zip(&lhs, &rhs, &selector).unwrap();
    }
}
263}