Skip to main content

vortex_array/scalar_fn/fns/zip/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::DynArray;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::arrays::BoolArray;
20use crate::builders::ArrayBuilder;
21use crate::builders::builder_with_capacity;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::DType;
24use crate::expr::Expression;
25use crate::scalar_fn::Arity;
26use crate::scalar_fn::ChildName;
27use crate::scalar_fn::EmptyOptions;
28use crate::scalar_fn::ExecutionArgs;
29use crate::scalar_fn::ScalarFnId;
30use crate::scalar_fn::ScalarFnVTable;
31use crate::scalar_fn::SimplifyCtx;
32use crate::scalar_fn::fns::literal::Literal;
33
/// An expression that conditionally selects between two arrays based on a boolean mask.
///
/// For each position `i`, `result[i] = if mask[i] then if_true[i] else if_false[i]`.
///
/// Null values in the mask are treated as false (selecting `if_false`). This follows
/// SQL semantics (DuckDB, Trino) where a null condition falls through to the ELSE branch,
/// rather than Arrow's `if_else` which propagates null conditions to the output.
///
/// Children, in order: `if_true` (index 0), `if_false` (index 1), `mask` (index 2).
/// `if_true` and `if_false` must share a base type (their nullability may differ), and the
/// mask must be boolean. The result has `if_true`'s type with its nullability widened by
/// `if_false`'s nullability.
#[derive(Clone)]
pub struct Zip;
43
44impl ScalarFnVTable for Zip {
45    type Options = EmptyOptions;
46
47    fn id(&self) -> ScalarFnId {
48        ScalarFnId::from("vortex.zip")
49    }
50
51    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52        Ok(Some(vec![]))
53    }
54
55    fn deserialize(
56        &self,
57        _metadata: &[u8],
58        _session: &VortexSession,
59    ) -> VortexResult<Self::Options> {
60        Ok(EmptyOptions)
61    }
62
63    fn arity(&self, _options: &Self::Options) -> Arity {
64        Arity::Exact(3)
65    }
66
67    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
68        match child_idx {
69            0 => ChildName::from("if_true"),
70            1 => ChildName::from("if_false"),
71            2 => ChildName::from("mask"),
72            _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
73        }
74    }
75
76    fn fmt_sql(
77        &self,
78        _options: &Self::Options,
79        expr: &Expression,
80        f: &mut Formatter<'_>,
81    ) -> std::fmt::Result {
82        write!(f, "zip(")?;
83        expr.child(0).fmt_sql(f)?;
84        write!(f, ", ")?;
85        expr.child(1).fmt_sql(f)?;
86        write!(f, ", ")?;
87        expr.child(2).fmt_sql(f)?;
88        write!(f, ")")
89    }
90
91    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
92        vortex_ensure!(
93            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
94            "zip requires if_true and if_false to have the same base type, got {} and {}",
95            arg_dtypes[0],
96            arg_dtypes[1]
97        );
98        vortex_ensure!(
99            matches!(arg_dtypes[2], DType::Bool(_)),
100            "zip requires mask to be a boolean type, got {}",
101            arg_dtypes[2]
102        );
103        Ok(arg_dtypes[0]
104            .clone()
105            .union_nullability(arg_dtypes[1].nullability()))
106    }
107
108    fn execute(
109        &self,
110        _options: &Self::Options,
111        args: &dyn ExecutionArgs,
112        ctx: &mut ExecutionCtx,
113    ) -> VortexResult<ArrayRef> {
114        let if_true = args.get(0)?;
115        let if_false = args.get(1)?;
116        let mask_array = args.get(2)?;
117
118        let mask = mask_array.execute::<BoolArray>(ctx)?.to_mask();
119
120        let return_dtype = if_true
121            .dtype()
122            .clone()
123            .union_nullability(if_false.dtype().nullability());
124
125        if mask.all_true() {
126            return if_true.cast(return_dtype)?.execute(ctx);
127        }
128
129        let return_dtype = if_true
130            .dtype()
131            .clone()
132            .union_nullability(if_false.dtype().nullability());
133
134        if mask.all_false() {
135            return if_false.cast(return_dtype)?.execute(ctx);
136        }
137
138        if !if_true.is_canonical() || !if_false.is_canonical() {
139            let if_true = if_true.execute::<ArrayRef>(ctx)?;
140            let if_false = if_false.execute::<ArrayRef>(ctx)?;
141            return mask.into_array().zip(if_true, if_false);
142        }
143
144        zip_impl(&if_true, &if_false, &mask)
145    }
146
147    fn simplify(
148        &self,
149        _options: &Self::Options,
150        expr: &Expression,
151        _ctx: &dyn SimplifyCtx,
152    ) -> VortexResult<Option<Expression>> {
153        let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
154            return Ok(None);
155        };
156
157        if let Some(mask_val) = mask_lit.as_bool().value() {
158            if mask_val {
159                return Ok(Some(expr.child(0).clone()));
160            } else {
161                return Ok(Some(expr.child(1).clone()));
162            }
163        }
164
165        Ok(None)
166    }
167
168    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
169        true
170    }
171
172    fn is_fallible(&self, _options: &Self::Options) -> bool {
173        false
174    }
175}
176
177pub(crate) fn zip_impl(
178    if_true: &ArrayRef,
179    if_false: &ArrayRef,
180    mask: &Mask,
181) -> VortexResult<ArrayRef> {
182    assert_eq!(
183        if_true.len(),
184        if_false.len(),
185        "zip requires arrays to have the same size"
186    );
187
188    let return_type = if_true
189        .dtype()
190        .clone()
191        .union_nullability(if_false.dtype().nullability());
192    zip_impl_with_builder(
193        if_true,
194        if_false,
195        mask,
196        builder_with_capacity(&return_type, if_true.len()),
197    )
198}
199
200fn zip_impl_with_builder(
201    if_true: &ArrayRef,
202    if_false: &ArrayRef,
203    mask: &Mask,
204    mut builder: Box<dyn ArrayBuilder>,
205) -> VortexResult<ArrayRef> {
206    match mask.slices() {
207        AllOr::All => Ok(if_true.to_array()),
208        AllOr::None => Ok(if_false.to_array()),
209        AllOr::Some(slices) => {
210            for (start, end) in slices {
211                builder.extend_from_array(&if_false.slice(builder.len()..*start)?);
212                builder.extend_from_array(&if_true.slice(*start..*end)?);
213            }
214            if builder.len() < if_false.len() {
215                builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?);
216            }
217            Ok(builder.finish())
218        }
219    }
220}
221
// Unit tests for the `Zip` scalar function: dtype derivation, SQL display, value selection,
// nullability widening, length validation, and buffer reuse for var-bin-view inputs.
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::StructVTable;
    use crate::arrays::VarBinViewArray;
    use crate::arrow::IntoArrowArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::expr::zip_expr;
    use crate::scalar::Scalar;

    /// Two non-nullable `i32` branches yield a non-nullable `i32` result.
    #[test]
    fn dtype() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expr = zip_expr(lit(true), root(), lit(0i32));
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    /// `zip_expr` takes the mask first, but the rendered form lists it last, matching the
    /// child order `if_true, if_false, mask`.
    #[test]
    fn test_display() {
        let expr = zip_expr(lit(true), root(), lit(0i32));
        assert_eq!(expr.to_string(), "zip($, 0i32, true)");
    }

    /// A mixed mask picks each position from the matching branch.
    #[test]
    fn test_zip_basic() {
        let mask = Mask::from_iter([true, false, false, true, false]);
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();

        let result = mask.into_array().zip(if_true, if_false).unwrap();
        let expected = buffer![10, 2, 3, 40, 5].into_array();

        assert_arrays_eq!(result, expected);
    }

    /// An all-true mask returns `if_true`'s values, widened to `if_false`'s nullability.
    #[test]
    fn test_zip_all_true() {
        let mask = Mask::new_true(4);
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let result = mask.into_array().zip(if_true, if_false.clone()).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();

        assert_arrays_eq!(result, expected);

        // result must be nullable even though if_true was not
        assert_eq!(result.dtype(), if_false.dtype())
    }

    /// Branches of different lengths must be rejected.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let mask = Mask::new_false(4);
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();

        let _result = mask.into_array().zip(if_true, if_false).unwrap();
    }

    /// Zipping two constant strings under an alternating mask must not fragment the output.
    /// Per the snapshot, each string's bytes are stored once (29 B and 28 B), not once per row.
    #[test]
    fn test_fragmentation() -> VortexResult<()> {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .into_array();

        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .into_array();

        let indices: Vec<usize> = (0..len).step_by(2).collect();
        let mask = Mask::from_indices(len, indices);
        let mask_array = mask.into_array();

        let result = mask_array
            .clone()
            .zip(const1.clone(), const2.clone())?
            .execute::<VarBinViewArray>(&mut LEGACY_SESSION.create_execution_ctx())?;

        insta::assert_snapshot!(result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
          metadata: EmptyMetadata
          buffer: buffer_0 host 29 B (align=1) (1.75%)
          buffer: buffer_1 host 28 B (align=1) (1.69%)
          buffer: views host 1.60 kB (align=16) (96.56%)
        ");

        // test wrapped in a struct
        let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array();

        let wrapped_result = mask_array
            .zip(wrapped1, wrapped2)?
            .execute::<ArrayRef>(&mut LEGACY_SESSION.create_execution_ctx())?;
        assert!(wrapped_result.is::<StructVTable>());

        Ok(())
    }

    /// Zips two var-bin-view arrays and checks two things: the output references exactly two
    /// data buffers (presumably one per input), and the values match Arrow's `zip` kernel.
    #[test]
    fn test_varbinview_zip() {
        let if_true = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello");
                builder.append_value("Hello this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        let if_false = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello2");
                builder.append_value("Hello2 this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        // [1,2,4,5,7,8,..]
        // Only indices below 100 can be set; positions 100..200 all select `if_false`.
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());
        let mask_array = mask.clone().into_array();

        let zipped = mask_array
            .zip(if_true.clone(), if_false.clone())
            .unwrap()
            .execute::<VarBinViewArray>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        assert_eq!(zipped.nbuffers(), 2);

        // assert the result is the same as arrow
        let expected = arrow_zip(
            mask.into_array()
                .into_arrow_preferred()
                .unwrap()
                .as_boolean(),
            &if_true.into_arrow_preferred().unwrap(),
            &if_false.into_arrow_preferred().unwrap(),
        )
        .unwrap();

        let actual = zipped.into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}