Skip to main content

vortex_array/scalar_fn/fns/zip/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexExpect as _;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_mask::Mask;
13use vortex_mask::MaskValues;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::arrays::BoolArray;
20use crate::arrays::bool::BoolArrayExt;
21use crate::builders::ArrayBuilder;
22use crate::builders::builder_with_capacity;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::expr::Expression;
26use crate::scalar_fn::Arity;
27use crate::scalar_fn::ChildName;
28use crate::scalar_fn::EmptyOptions;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ScalarFnId;
31use crate::scalar_fn::ScalarFnVTable;
32use crate::scalar_fn::SimplifyCtx;
33use crate::scalar_fn::fns::literal::Literal;
34
/// Scalar function that picks, element by element, from one of two arrays according to a
/// boolean mask.
///
/// The output at position `i` is `if_true[i]` when `mask[i]` is true and `if_false[i]`
/// otherwise.
///
/// A null mask entry counts as false, so the `if_false` value is chosen. This mirrors SQL
/// engines such as DuckDB and Trino, where a null condition falls through to the ELSE branch;
/// it intentionally differs from Arrow's `if_else`, which turns a null condition into a null
/// output.
#[derive(Clone)]
pub struct Zip;
44
45impl ScalarFnVTable for Zip {
46    type Options = EmptyOptions;
47
48    fn id(&self) -> ScalarFnId {
49        ScalarFnId::new("vortex.zip")
50    }
51
52    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53        Ok(Some(vec![]))
54    }
55
56    fn deserialize(
57        &self,
58        _metadata: &[u8],
59        _session: &VortexSession,
60    ) -> VortexResult<Self::Options> {
61        Ok(EmptyOptions)
62    }
63
64    fn arity(&self, _options: &Self::Options) -> Arity {
65        Arity::Exact(3)
66    }
67
68    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
69        match child_idx {
70            0 => ChildName::from("if_true"),
71            1 => ChildName::from("if_false"),
72            2 => ChildName::from("mask"),
73            _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
74        }
75    }
76
77    fn fmt_sql(
78        &self,
79        _options: &Self::Options,
80        expr: &Expression,
81        f: &mut Formatter<'_>,
82    ) -> std::fmt::Result {
83        write!(f, "zip(")?;
84        expr.child(0).fmt_sql(f)?;
85        write!(f, ", ")?;
86        expr.child(1).fmt_sql(f)?;
87        write!(f, ", ")?;
88        expr.child(2).fmt_sql(f)?;
89        write!(f, ")")
90    }
91
92    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
93        vortex_ensure!(
94            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
95            "zip requires if_true and if_false to have the same base type, got {} and {}",
96            arg_dtypes[0],
97            arg_dtypes[1]
98        );
99        vortex_ensure!(
100            matches!(arg_dtypes[2], DType::Bool(_)),
101            "zip requires mask to be a boolean type, got {}",
102            arg_dtypes[2]
103        );
104        Ok(arg_dtypes[0]
105            .clone()
106            .union_nullability(arg_dtypes[1].nullability()))
107    }
108
109    fn execute(
110        &self,
111        _options: &Self::Options,
112        args: &dyn ExecutionArgs,
113        ctx: &mut ExecutionCtx,
114    ) -> VortexResult<ArrayRef> {
115        let if_true = args.get(0)?;
116        let if_false = args.get(1)?;
117        let mask_array = args.get(2)?;
118
119        let mask = mask_array
120            .execute::<BoolArray>(ctx)?
121            .to_mask_fill_null_false(ctx);
122
123        let return_dtype = if_true
124            .dtype()
125            .clone()
126            .union_nullability(if_false.dtype().nullability());
127
128        if mask.all_true() {
129            return if_true.cast(return_dtype)?.execute(ctx);
130        }
131
132        if mask.all_false() {
133            return if_false.cast(return_dtype)?.execute(ctx);
134        }
135
136        if !if_true.is_canonical() || !if_false.is_canonical() {
137            let if_true = if_true.execute::<ArrayRef>(ctx)?;
138            let if_false = if_false.execute::<ArrayRef>(ctx)?;
139            return mask.into_array().zip(if_true, if_false);
140        }
141
142        zip_impl(&if_true, &if_false, &mask, ctx)
143    }
144
145    fn simplify(
146        &self,
147        _options: &Self::Options,
148        expr: &Expression,
149        _ctx: &dyn SimplifyCtx,
150    ) -> VortexResult<Option<Expression>> {
151        let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
152            return Ok(None);
153        };
154
155        if let Some(mask_val) = mask_lit.as_bool().value() {
156            if mask_val {
157                return Ok(Some(expr.child(0).clone()));
158            } else {
159                return Ok(Some(expr.child(1).clone()));
160            }
161        }
162
163        Ok(None)
164    }
165
166    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
167        true
168    }
169
170    fn is_fallible(&self, _options: &Self::Options) -> bool {
171        false
172    }
173}
174
175pub(crate) fn zip_impl(
176    if_true: &ArrayRef,
177    if_false: &ArrayRef,
178    mask: &Mask,
179    ctx: &mut ExecutionCtx,
180) -> VortexResult<ArrayRef> {
181    assert_eq!(
182        if_true.len(),
183        if_false.len(),
184        "zip requires arrays to have the same size"
185    );
186
187    let return_type = if_true
188        .dtype()
189        .clone()
190        .union_nullability(if_false.dtype().nullability());
191
192    if mask.all_true() {
193        return if_true.cast(return_type);
194    }
195    if mask.all_false() {
196        return if_false.cast(return_type);
197    }
198
199    // `append_to_builder` requires exact dtype equality, so normalize branch
200    // nullability to the output dtype before appending slices into the builder.
201    let if_true = if_true.cast(return_type.clone())?;
202    let if_false = if_false.cast(return_type.clone())?;
203
204    zip_impl_with_builder(
205        &if_true,
206        &if_false,
207        mask.values()
208            .vortex_expect("zip_impl_with_builder: mask is not all-true or all-false"),
209        builder_with_capacity(&return_type, if_true.len()),
210        ctx,
211    )
212}
213
214fn zip_impl_with_builder(
215    if_true: &ArrayRef,
216    if_false: &ArrayRef,
217    mask: &MaskValues,
218    mut builder: Box<dyn ArrayBuilder>,
219    ctx: &mut ExecutionCtx,
220) -> VortexResult<ArrayRef> {
221    for (start, end) in mask.slices() {
222        if builder.len() < *start {
223            if_false
224                .slice(builder.len()..*start)?
225                .append_to_builder(builder.as_mut(), ctx)?;
226        }
227        if_true
228            .slice(*start..*end)?
229            .append_to_builder(builder.as_mut(), ctx)?;
230    }
231    if builder.len() < if_false.len() {
232        if_false
233            .slice(builder.len()..if_false.len())?
234            .append_to_builder(builder.as_mut(), ctx)?;
235    }
236    Ok(builder.finish())
237}
238
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use super::zip_impl;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::Struct;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinView;
    use crate::arrow::ArrowSessionExt;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::builtins::ArrayBuiltins;
    use crate::columnar::Columnar;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::expr::zip_expr;
    use crate::scalar::Scalar;

    /// Zipping two non-nullable i32 branches reports a non-nullable i32 dtype.
    #[test]
    fn dtype() {
        let scope = DType::Primitive(PType::I32, Nullability::NonNullable);
        let zipped = zip_expr(lit(true), root(), lit(0i32));

        assert_eq!(
            zipped.return_dtype(&scope).unwrap(),
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    /// The expression prints its children in `if_true, if_false, mask` order.
    #[test]
    fn test_display() {
        let zipped = zip_expr(lit(true), root(), lit(0i32));
        assert_eq!(zipped.to_string(), "zip($, 0i32, true)");
    }

    /// A mixed mask picks element-wise from the matching branch.
    #[test]
    fn test_zip_basic() {
        let selector = Mask::from_iter([true, false, false, true, false]);
        let on_true = buffer![10, 20, 30, 40, 50].into_array();
        let on_false = buffer![1, 2, 3, 4, 5].into_array();

        let actual = selector.into_array().zip(on_true, on_false).unwrap();

        assert_arrays_eq!(actual, buffer![10, 2, 3, 40, 5].into_array());
    }

    /// An all-true mask returns `if_true`, widened to the nullable dtype of `if_false`.
    #[test]
    fn test_zip_all_true() {
        let selector = Mask::new_true(4);
        let on_true = buffer![10, 20, 30, 40].into_array();
        let on_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
        let nullable_dtype = on_false.dtype().clone();

        let actual = selector.into_array().zip(on_true, on_false).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();

        assert_arrays_eq!(actual, expected);
        assert_eq!(actual.dtype(), &nullable_dtype);
    }

    /// An all-false mask returns `if_false`, widened to the nullable dtype of `if_true`.
    #[test]
    fn test_zip_all_false_widens_nullability() {
        let selector = Mask::new_false(4);
        let on_true =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
        let on_false = buffer![1i32, 2, 3, 4].into_array();
        let nullable_dtype = on_true.dtype().clone();

        let actual = selector.into_array().zip(on_true, on_false).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array();

        assert_arrays_eq!(actual, expected);
        assert_eq!(actual.dtype(), &nullable_dtype);
    }

    /// `zip_impl` called directly with an all-true mask also widens nullability.
    #[test]
    fn test_zip_impl_all_true_widens_nullability() -> VortexResult<()> {
        let selector = Mask::new_true(4);
        let on_true = buffer![10i32, 20, 30, 40].into_array();
        let on_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let mut exec_ctx = LEGACY_SESSION.create_execution_ctx();
        let actual = zip_impl(&on_true, &on_false, &selector, &mut exec_ctx)?;

        let expected =
            PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)])
                .into_array();
        assert_arrays_eq!(actual, expected);
        assert_eq!(actual.dtype(), on_false.dtype());
        Ok(())
    }

    /// `zip_impl` called directly with an all-false mask also widens nullability.
    #[test]
    fn test_zip_impl_all_false_widens_nullability() -> VortexResult<()> {
        let selector = Mask::new_false(4);
        let on_true =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
        let on_false = buffer![1i32, 2, 3, 4].into_array();

        let mut exec_ctx = LEGACY_SESSION.create_execution_ctx();
        let actual = zip_impl(&on_true, &on_false, &selector, &mut exec_ctx)?;

        let expected =
            PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array();
        assert_arrays_eq!(actual, expected);
        assert_eq!(actual.dtype(), on_true.dtype());
        Ok(())
    }

    /// Branches of different lengths must not zip successfully.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let selector = Mask::new_false(4);
        let on_true = buffer![10, 20, 30].into_array();
        let on_false = buffer![1, 2, 3, 4].into_array();

        let _unused = selector.into_array().zip(on_true, on_false).unwrap();
    }

    /// Alternating between two constant string arrays must not fragment the data buffers,
    /// and zipping struct-wrapped inputs must yield a struct array.
    #[test]
    fn test_fragmentation() -> VortexResult<()> {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .into_array();
        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .into_array();

        // Select every even position from the first constant.
        let even_positions: Vec<usize> = (0..len).step_by(2).collect();
        let mask_array = Mask::from_indices(len, even_positions).into_array();

        let mut exec_ctx = LEGACY_SESSION.create_execution_ctx();
        let zipped = mask_array.zip(const1.clone(), const2.clone())?;
        let result = zipped.execute::<Columnar>(&mut exec_ctx)?.into_array();

        insta::assert_snapshot!(result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
          metadata: 
          buffer: buffer_0 host 29 B (align=1) (1.75%)
          buffer: buffer_1 host 28 B (align=1) (1.69%)
          buffer: views host 1.60 kB (align=16) (96.56%)
        ");

        let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array();

        let wrapped_zipped = mask_array.zip(wrapped1, wrapped2)?;
        let wrapped_result = wrapped_zipped.execute::<ArrayRef>(&mut exec_ctx)?;
        assert!(wrapped_result.is::<Struct>());

        Ok(())
    }

    /// Zipping two VarBinView arrays keeps two data buffers and agrees with Arrow's `zip`.
    #[test]
    fn test_varbinview_zip() {
        // Builds 100 repetitions of a (short, long) string pair as a non-nullable utf8 array.
        let build_strings = |short: &str, long: &str| {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value(short);
                builder.append_value(long);
            }
            builder.finish()
        };

        let if_true = build_strings(
            "Hello",
            "Hello this is a long string that won't be inlined.",
        );
        let if_false = build_strings(
            "Hello2",
            "Hello2 this is a long string that won't be inlined.",
        );

        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0));
        let mask_array = mask.clone().into_array();

        let mut exec_ctx = LEGACY_SESSION.create_execution_ctx();
        let executed = mask_array
            .zip(if_true.clone(), if_false.clone())
            .unwrap()
            .execute::<ArrayRef>(&mut exec_ctx)
            .unwrap();
        let view = executed.as_opt::<VarBinView>().unwrap();
        assert_eq!(view.data_buffers().len(), 2);

        // Cross-check against Arrow's own zip kernel on the same inputs.
        let mut arrow_ctx = LEGACY_SESSION.create_execution_ctx();
        let arrow_mask = LEGACY_SESSION
            .arrow()
            .execute_arrow(mask.into_array(), None, &mut arrow_ctx)
            .unwrap();
        let arrow_true = LEGACY_SESSION
            .arrow()
            .execute_arrow(if_true, None, &mut arrow_ctx)
            .unwrap();
        let arrow_false = LEGACY_SESSION
            .arrow()
            .execute_arrow(if_false, None, &mut arrow_ctx)
            .unwrap();
        let expected = arrow_zip(arrow_mask.as_boolean(), &arrow_true, &arrow_false).unwrap();

        let actual = LEGACY_SESSION
            .arrow()
            .execute_arrow(view.array().clone(), None, &mut arrow_ctx)
            .unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}
489}