Skip to main content

vortex_array/scalar_fn/fns/zip/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexExpect as _;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_mask::Mask;
13use vortex_mask::MaskValues;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::BoolArray;
21use crate::arrays::bool::BoolArrayExt;
22use crate::builders::ArrayBuilder;
23use crate::builders::builder_with_capacity;
24use crate::builtins::ArrayBuiltins;
25use crate::dtype::DType;
26use crate::expr::Expression;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::EmptyOptions;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::literal::Literal;
35use crate::validity::Validity;
36
/// An expression that conditionally selects between two arrays based on a boolean mask.
///
/// For each position `i`, `result[i] = if mask[i] then if_true[i] else if_false[i]`.
///
/// Null values in the mask are treated as false (selecting `if_false`). This follows
/// SQL semantics (DuckDB, Trino) where a null condition falls through to the ELSE branch,
/// rather than Arrow's `if_else` which propagates null conditions to the output.
///
/// The expression's children are ordered `[if_true, if_false, mask]`. `if_true` and
/// `if_false` must share a base dtype (nullability may differ). The result uses their least
/// supertype, so it is nullable whenever either branch is nullable.
#[derive(Clone)]
pub struct Zip;
46
47impl ScalarFnVTable for Zip {
48    type Options = EmptyOptions;
49
50    fn id(&self) -> ScalarFnId {
51        static ID: CachedId = CachedId::new("vortex.zip");
52        *ID
53    }
54
55    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
56        Ok(Some(vec![]))
57    }
58
59    fn deserialize(
60        &self,
61        _metadata: &[u8],
62        _session: &VortexSession,
63    ) -> VortexResult<Self::Options> {
64        Ok(EmptyOptions)
65    }
66
67    fn arity(&self, _options: &Self::Options) -> Arity {
68        Arity::Exact(3)
69    }
70
71    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
72        match child_idx {
73            0 => ChildName::from("if_true"),
74            1 => ChildName::from("if_false"),
75            2 => ChildName::from("mask"),
76            _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
77        }
78    }
79
80    fn fmt_sql(
81        &self,
82        _options: &Self::Options,
83        expr: &Expression,
84        f: &mut Formatter<'_>,
85    ) -> std::fmt::Result {
86        write!(f, "zip(")?;
87        expr.child(0).fmt_sql(f)?;
88        write!(f, ", ")?;
89        expr.child(1).fmt_sql(f)?;
90        write!(f, ", ")?;
91        expr.child(2).fmt_sql(f)?;
92        write!(f, ")")
93    }
94
95    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
96        vortex_ensure!(
97            matches!(arg_dtypes[2], DType::Bool(_)),
98            "zip requires mask to be a boolean type, got {}",
99            arg_dtypes[2]
100        );
101        zip_return_dtype(&arg_dtypes[0], &arg_dtypes[1])
102    }
103
104    fn execute(
105        &self,
106        _options: &Self::Options,
107        args: &dyn ExecutionArgs,
108        ctx: &mut ExecutionCtx,
109    ) -> VortexResult<ArrayRef> {
110        let if_true = args.get(0)?;
111        let if_false = args.get(1)?;
112        let mask_array = args.get(2)?;
113
114        let mask = mask_array
115            .execute::<BoolArray>(ctx)?
116            .to_mask_fill_null_false(ctx);
117
118        let return_dtype = zip_return_dtype(if_true.dtype(), if_false.dtype())?;
119
120        if mask.all_true() {
121            return if_true.cast(return_dtype)?.execute(ctx);
122        }
123
124        if mask.all_false() {
125            return if_false.cast(return_dtype)?.execute(ctx);
126        }
127
128        if !if_true.is_canonical() || !if_false.is_canonical() {
129            let if_true = if_true.execute::<ArrayRef>(ctx)?;
130            let if_false = if_false.execute::<ArrayRef>(ctx)?;
131            return mask.into_array().zip(if_true, if_false);
132        }
133
134        zip_impl(&if_true, &if_false, &mask, ctx)
135    }
136
137    fn simplify(
138        &self,
139        _options: &Self::Options,
140        expr: &Expression,
141        _ctx: &dyn SimplifyCtx,
142    ) -> VortexResult<Option<Expression>> {
143        let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
144            return Ok(None);
145        };
146
147        if let Some(mask_val) = mask_lit.as_bool().value() {
148            if mask_val {
149                return Ok(Some(expr.child(0).clone()));
150            } else {
151                return Ok(Some(expr.child(1).clone()));
152            }
153        }
154
155        Ok(None)
156    }
157
158    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
159        true
160    }
161
162    fn is_fallible(&self, _options: &Self::Options) -> bool {
163        false
164    }
165}
166
167pub(crate) fn zip_impl(
168    if_true: &ArrayRef,
169    if_false: &ArrayRef,
170    mask: &Mask,
171    ctx: &mut ExecutionCtx,
172) -> VortexResult<ArrayRef> {
173    assert_eq!(
174        if_true.len(),
175        if_false.len(),
176        "zip requires arrays to have the same size"
177    );
178
179    let return_type = zip_return_dtype(if_true.dtype(), if_false.dtype())?;
180
181    if mask.all_true() {
182        return if_true.cast(return_type);
183    }
184    if mask.all_false() {
185        return if_false.cast(return_type);
186    }
187
188    // `append_to_builder` requires exact dtype equality, so normalize branch
189    // nullability to the output dtype before appending slices into the builder.
190    let if_true = if_true.cast(return_type.clone())?;
191    let if_false = if_false.cast(return_type.clone())?;
192
193    zip_impl_with_builder(
194        &if_true,
195        &if_false,
196        mask.values()
197            .vortex_expect("zip_impl_with_builder: mask is not all-true or all-false"),
198        builder_with_capacity(&return_type, if_true.len()),
199        ctx,
200    )
201}
202
203fn zip_return_dtype(if_true: &DType, if_false: &DType) -> VortexResult<DType> {
204    vortex_ensure!(
205        if_true.eq_ignore_nullability(if_false),
206        "zip requires if_true and if_false to have the same base type, got {} and {}",
207        if_true,
208        if_false
209    );
210    Ok(if_true
211        .least_supertype(if_false)
212        .vortex_expect("zip inputs with the same base type must have a common dtype"))
213}
214
215fn zip_impl_with_builder(
216    if_true: &ArrayRef,
217    if_false: &ArrayRef,
218    mask: &MaskValues,
219    mut builder: Box<dyn ArrayBuilder>,
220    ctx: &mut ExecutionCtx,
221) -> VortexResult<ArrayRef> {
222    for (start, end) in mask.slices() {
223        if builder.len() < *start {
224            if_false
225                .slice(builder.len()..*start)?
226                .append_to_builder(builder.as_mut(), ctx)?;
227        }
228        if_true
229            .slice(*start..*end)?
230            .append_to_builder(builder.as_mut(), ctx)?;
231    }
232    if builder.len() < if_false.len() {
233        if_false
234            .slice(builder.len()..if_false.len())?
235            .append_to_builder(builder.as_mut(), ctx)?;
236    }
237    Ok(builder.finish())
238}
239
240/// Combine two validities for a row-wise zip: take `if_true`'s validity where `mask` is set and
241/// `if_false`'s where it is not.
242///
243/// That selection is itself a zip over the two boolean validity bitmaps, so it is built as a (lazy)
244/// zip array — reusing the zip machinery rather than re-deriving the mask algebra. Trivial cases
245/// where both sides' validity already agrees skip the zip. `mask` must already be null-filled so the
246/// selection matches the accompanying value selection. Shared by the per-encoding zip kernels (e.g.
247/// `Bool`, `Primitive`) that build their result directly.
248pub(crate) fn zip_validity(
249    if_true: Validity,
250    if_false: Validity,
251    mask: &Mask,
252) -> VortexResult<Validity> {
253    match (&if_true, &if_false) {
254        (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
255        (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
256        (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
257        _ => {}
258    }
259
260    let len = mask.len();
261    let validity = mask
262        .clone()
263        .into_array()
264        .zip(if_true.to_array(len), if_false.to_array(len))?;
265    Ok(Validity::Array(validity))
266}
267
#[cfg(test)]
mod tests {
    // Unit tests for the `Zip` scalar function: dtype inference, display, the row-wise
    // selection paths (via `ArrayBuiltins::zip` and the crate-internal `zip_impl`),
    // nullability widening, and buffer reuse for var-bin-view outputs.
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use super::zip_impl;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::Struct;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinView;
    use crate::arrow::ArrowSessionExt;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::builtins::ArrayBuiltins;
    use crate::columnar::Columnar;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::expr::zip_expr;
    use crate::scalar::Scalar;

    // The output dtype follows the (non-nullable) value branches, not the bool mask.
    #[test]
    fn dtype() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expr = zip_expr(lit(true), root(), lit(0i32));
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    // `zip_expr` takes the mask first, but the expression displays in child order
    // (if_true, if_false, mask).
    #[test]
    fn test_display() {
        let expr = zip_expr(lit(true), root(), lit(0i32));
        assert_eq!(expr.to_string(), "zip($, 0i32, true)");
    }

    // Mixed mask: rows 0 and 3 come from `if_true`, the rest from `if_false`.
    #[test]
    fn test_zip_basic() {
        let mut ctx = array_session().create_execution_ctx();
        let mask = Mask::from_iter([true, false, false, true, false]);
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();

        let result = mask.into_array().zip(if_true, if_false).unwrap();
        let expected = buffer![10, 2, 3, 40, 5].into_array();

        assert_arrays_eq!(result, expected, &mut ctx);
    }

    // An all-true mask returns `if_true` values but with the nullable dtype of `if_false`.
    #[test]
    fn test_zip_all_true() {
        let mut ctx = array_session().create_execution_ctx();
        let mask = Mask::new_true(4);
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let result = mask.into_array().zip(if_true, if_false.clone()).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();

        assert_arrays_eq!(result, expected, &mut ctx);
        assert_eq!(result.dtype(), if_false.dtype())
    }

    // An all-false mask returns `if_false` values but with the nullable dtype of `if_true`.
    #[test]
    fn test_zip_all_false_widens_nullability() {
        let mut ctx = array_session().create_execution_ctx();
        let mask = Mask::new_false(4);
        let if_true =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
        let if_false = buffer![1i32, 2, 3, 4].into_array();

        let result = mask.into_array().zip(if_true.clone(), if_false).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array();

        assert_arrays_eq!(result, expected, &mut ctx);
        assert_eq!(result.dtype(), if_true.dtype());
    }

    // Same as `test_zip_all_true`, but exercising `zip_impl` directly.
    #[test]
    fn test_zip_impl_all_true_widens_nullability() -> VortexResult<()> {
        let mask = Mask::new_true(4);
        let if_true = buffer![10i32, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let mut ctx = array_session().create_execution_ctx();
        let result = zip_impl(&if_true, &if_false, &mask, &mut ctx)?;
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)])
                .into_array(),
            &mut ctx
        );
        assert_eq!(result.dtype(), if_false.dtype());
        Ok(())
    }

    // Same as `test_zip_all_false_widens_nullability`, but exercising `zip_impl` directly.
    #[test]
    fn test_zip_impl_all_false_widens_nullability() -> VortexResult<()> {
        let mask = Mask::new_false(4);
        let if_true =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
        let if_false = buffer![1i32, 2, 3, 4].into_array();

        let mut ctx = array_session().create_execution_ctx();
        let result = zip_impl(&if_true, &if_false, &mask, &mut ctx)?;
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array(),
            &mut ctx
        );
        assert_eq!(result.dtype(), if_true.dtype());
        Ok(())
    }

    // Branches of different lengths must be rejected (panic via the unwrapped error or assert).
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let mask = Mask::new_false(4);
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();

        let _result = mask.into_array().zip(if_true, if_false).unwrap();
    }

    // Alternating mask over two constant strings: the snapshot pins the output to a single
    // var-bin-view with two small data buffers (one per distinct string) rather than one
    // buffer per row. The struct-wrapped case checks that the result stays a struct.
    #[test]
    fn test_fragmentation() -> VortexResult<()> {
        let len = 100;

        let const1 = ConstantArray::new(
            Scalar::utf8("hello_this_is_a_longer_string", Nullability::Nullable),
            len,
        )
        .into_array();

        let const2 = ConstantArray::new(
            Scalar::utf8("world_this_is_another_string", Nullability::Nullable),
            len,
        )
        .into_array();

        let indices: Vec<usize> = (0..len).step_by(2).collect();
        let mask = Mask::from_indices(len, indices);
        let mask_array = mask.into_array();

        let mut ctx = array_session().create_execution_ctx();
        let result = mask_array
            .zip(const1.clone(), const2.clone())?
            .execute::<Columnar>(&mut ctx)?
            .into_array();

        insta::assert_snapshot!(result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
          metadata: 
          buffer: buffer_0 host 29 B (align=1) (1.75%)
          buffer: buffer_1 host 28 B (align=1) (1.69%)
          buffer: views host 1.60 kB (align=16) (96.56%)
        ");

        let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array();

        let wrapped_result = mask_array
            .zip(wrapped1, wrapped2)?
            .execute::<ArrayRef>(&mut ctx)?;
        assert!(wrapped_result.is::<Struct>());

        Ok(())
    }

    // Zips two var-bin-view arrays (mixing inlined and out-of-line strings), checks that the
    // output reuses exactly two data buffers, and compares the values against Arrow's `zip`.
    #[test]
    fn test_varbinview_zip() {
        let if_true = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello");
                builder.append_value("Hello this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        let if_false = {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value("Hello2");
                builder.append_value("Hello2 this is a long string that won't be inlined.");
            }
            builder.finish()
        };

        // Only rows in 0..100 not divisible by 3 are set; rows 100..200 are all unset.
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0));
        let mask_array = mask.clone().into_array();

        let mut ctx = array_session().create_execution_ctx();
        let zipped = mask_array
            .zip(if_true.clone(), if_false.clone())
            .unwrap()
            .execute::<ArrayRef>(&mut ctx)
            .unwrap();
        let zipped = zipped.as_opt::<VarBinView>().unwrap();
        assert_eq!(zipped.data_buffers().len(), 2);

        let mut arrow_ctx = array_session().create_execution_ctx();
        let expected = arrow_zip(
            array_session()
                .arrow()
                .execute_arrow(mask.into_array(), None, &mut arrow_ctx)
                .unwrap()
                .as_boolean(),
            &array_session()
                .arrow()
                .execute_arrow(if_true, None, &mut arrow_ctx)
                .unwrap(),
            &array_session()
                .arrow()
                .execute_arrow(if_false, None, &mut arrow_ctx)
                .unwrap(),
        )
        .unwrap();

        let actual = array_session()
            .arrow()
            .execute_arrow(zipped.array().clone(), None, &mut arrow_ctx)
            .unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}