Skip to main content

vortex_array/scalar_fn/fns/zip/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13use vortex_session::VortexSession;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::arrays::BoolArray;
20use crate::builders::ArrayBuilder;
21use crate::builders::builder_with_capacity;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::DType;
24use crate::expr::Expression;
25use crate::scalar_fn::Arity;
26use crate::scalar_fn::ChildName;
27use crate::scalar_fn::EmptyOptions;
28use crate::scalar_fn::ExecutionArgs;
29use crate::scalar_fn::ScalarFnId;
30use crate::scalar_fn::ScalarFnVTable;
31use crate::scalar_fn::SimplifyCtx;
32use crate::scalar_fn::fns::literal::Literal;
33
/// Conditional selection between two arrays, driven by a boolean mask.
///
/// Element `i` of the output is `if_true[i]` when `mask[i]` is `true`, and `if_false[i]`
/// otherwise — i.e. `result[i] = if mask[i] then if_true[i] else if_false[i]`.
///
/// A null mask entry behaves like `false` and therefore picks `if_false`. This mirrors SQL
/// `CASE` semantics (as in DuckDB and Trino), where a null condition falls through to the
/// `ELSE` branch; Arrow's `if_else`, by contrast, propagates a null condition to the output.
#[derive(Clone)]
pub struct Zip;
43
44impl ScalarFnVTable for Zip {
45    type Options = EmptyOptions;
46
47    fn id(&self) -> ScalarFnId {
48        ScalarFnId::from("vortex.zip")
49    }
50
51    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52        Ok(Some(vec![]))
53    }
54
55    fn deserialize(
56        &self,
57        _metadata: &[u8],
58        _session: &VortexSession,
59    ) -> VortexResult<Self::Options> {
60        Ok(EmptyOptions)
61    }
62
63    fn arity(&self, _options: &Self::Options) -> Arity {
64        Arity::Exact(3)
65    }
66
67    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
68        match child_idx {
69            0 => ChildName::from("if_true"),
70            1 => ChildName::from("if_false"),
71            2 => ChildName::from("mask"),
72            _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
73        }
74    }
75
76    fn fmt_sql(
77        &self,
78        _options: &Self::Options,
79        expr: &Expression,
80        f: &mut Formatter<'_>,
81    ) -> std::fmt::Result {
82        write!(f, "zip(")?;
83        expr.child(0).fmt_sql(f)?;
84        write!(f, ", ")?;
85        expr.child(1).fmt_sql(f)?;
86        write!(f, ", ")?;
87        expr.child(2).fmt_sql(f)?;
88        write!(f, ")")
89    }
90
91    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
92        vortex_ensure!(
93            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
94            "zip requires if_true and if_false to have the same base type, got {} and {}",
95            arg_dtypes[0],
96            arg_dtypes[1]
97        );
98        vortex_ensure!(
99            matches!(arg_dtypes[2], DType::Bool(_)),
100            "zip requires mask to be a boolean type, got {}",
101            arg_dtypes[2]
102        );
103        Ok(arg_dtypes[0]
104            .clone()
105            .union_nullability(arg_dtypes[1].nullability()))
106    }
107
108    fn execute(
109        &self,
110        _options: &Self::Options,
111        args: &dyn ExecutionArgs,
112        ctx: &mut ExecutionCtx,
113    ) -> VortexResult<ArrayRef> {
114        let if_true = args.get(0)?;
115        let if_false = args.get(1)?;
116        let mask_array = args.get(2)?;
117
118        let mask = mask_array.execute::<BoolArray>(ctx)?.to_mask();
119
120        let return_dtype = if_true
121            .dtype()
122            .clone()
123            .union_nullability(if_false.dtype().nullability());
124
125        if mask.all_true() {
126            return if_true.cast(return_dtype)?.execute(ctx);
127        }
128
129        let return_dtype = if_true
130            .dtype()
131            .clone()
132            .union_nullability(if_false.dtype().nullability());
133
134        if mask.all_false() {
135            return if_false.cast(return_dtype)?.execute(ctx);
136        }
137
138        if !if_true.is_canonical() || !if_false.is_canonical() {
139            let if_true = if_true.execute::<ArrayRef>(ctx)?;
140            let if_false = if_false.execute::<ArrayRef>(ctx)?;
141            return mask.into_array().zip(if_true, if_false);
142        }
143
144        zip_impl(&if_true, &if_false, &mask)
145    }
146
147    fn simplify(
148        &self,
149        _options: &Self::Options,
150        expr: &Expression,
151        _ctx: &dyn SimplifyCtx,
152    ) -> VortexResult<Option<Expression>> {
153        let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
154            return Ok(None);
155        };
156
157        if let Some(mask_val) = mask_lit.as_bool().value() {
158            if mask_val {
159                return Ok(Some(expr.child(0).clone()));
160            } else {
161                return Ok(Some(expr.child(1).clone()));
162            }
163        }
164
165        Ok(None)
166    }
167
168    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
169        true
170    }
171
172    fn is_fallible(&self, _options: &Self::Options) -> bool {
173        false
174    }
175}
176
177pub(crate) fn zip_impl(
178    if_true: &ArrayRef,
179    if_false: &ArrayRef,
180    mask: &Mask,
181) -> VortexResult<ArrayRef> {
182    assert_eq!(
183        if_true.len(),
184        if_false.len(),
185        "zip requires arrays to have the same size"
186    );
187
188    let return_type = if_true
189        .dtype()
190        .clone()
191        .union_nullability(if_false.dtype().nullability());
192    zip_impl_with_builder(
193        if_true,
194        if_false,
195        mask,
196        builder_with_capacity(&return_type, if_true.len()),
197    )
198}
199
200fn zip_impl_with_builder(
201    if_true: &ArrayRef,
202    if_false: &ArrayRef,
203    mask: &Mask,
204    mut builder: Box<dyn ArrayBuilder>,
205) -> VortexResult<ArrayRef> {
206    match mask.slices() {
207        AllOr::All => Ok(if_true.to_array()),
208        AllOr::None => Ok(if_false.to_array()),
209        AllOr::Some(slices) => {
210            for (start, end) in slices {
211                builder.extend_from_array(&if_false.slice(builder.len()..*start)?);
212                builder.extend_from_array(&if_true.slice(*start..*end)?);
213            }
214            if builder.len() < if_false.len() {
215                builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?);
216            }
217            Ok(builder.finish())
218        }
219    }
220}
221
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::StructVTable;
    use crate::arrays::VarBinViewArray;
    use crate::arrow::IntoArrowArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::expr::zip_expr;
    use crate::scalar::Scalar;

    /// A nullable UTF-8 constant array repeating `value` `len` times.
    fn utf8_constant(value: &'static str, len: usize) -> ArrayRef {
        ConstantArray::new(Scalar::utf8(value, Nullability::Nullable), len).into_array()
    }

    /// A non-nullable UTF-8 array of `pairs * 2` values alternating `short`, `long`, ...
    fn alternating_utf8(short: &str, long: &str, pairs: usize) -> ArrayRef {
        let mut builder = VarBinViewBuilder::new(
            DType::Utf8(Nullability::NonNullable),
            10,
            Default::default(),
            BufferGrowthStrategy::fixed(64 * 1024),
            0.0,
        );
        for _ in 0..pairs {
            builder.append_value(short);
            builder.append_value(long);
        }
        builder.finish()
    }

    #[test]
    fn dtype() {
        let non_null_i32 = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expression = zip_expr(lit(true), root(), lit(0i32));
        let resolved = expression.return_dtype(&non_null_i32).unwrap();
        assert_eq!(resolved, non_null_i32);
    }

    #[test]
    fn test_display() {
        let rendered = zip_expr(lit(true), root(), lit(0i32)).to_string();
        assert_eq!(rendered, "zip($, 0i32, true)");
    }

    #[test]
    fn test_zip_basic() {
        let selector = Mask::from_iter([true, false, false, true, false]).into_array();
        let when_true = buffer![10, 20, 30, 40, 50].into_array();
        let when_false = buffer![1, 2, 3, 4, 5].into_array();

        let zipped = selector.zip(when_true, when_false).unwrap();
        let expected = buffer![10, 2, 3, 40, 5].into_array();

        assert_arrays_eq!(zipped, expected);
    }

    #[test]
    fn test_zip_all_true() {
        let when_true = buffer![10, 20, 30, 40].into_array();
        let when_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let zipped = Mask::new_true(4)
            .into_array()
            .zip(when_true, when_false.clone())
            .unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();

        assert_arrays_eq!(zipped, expected);

        // `if_true` is non-nullable, but the nullable `if_false` must widen the result.
        assert_eq!(zipped.dtype(), when_false.dtype())
    }

    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let selector = Mask::new_false(4).into_array();
        let short = buffer![10, 20, 30].into_array();
        let long = buffer![1, 2, 3, 4].into_array();

        let _zipped = selector.zip(short, long).unwrap();
    }

    #[test]
    fn test_fragmentation() -> VortexResult<()> {
        const LEN: usize = 100;

        let hello = utf8_constant("hello_this_is_a_longer_string", LEN);
        let world = utf8_constant("world_this_is_another_string", LEN);

        // Even positions come from `hello`, odd positions from `world`.
        let even_positions = Mask::from_indices(LEN, (0..LEN).step_by(2).collect()).into_array();

        let flat = even_positions
            .clone()
            .zip(hello.clone(), world.clone())?
            .execute::<VarBinViewArray>(&mut LEGACY_SESSION.create_execution_ctx())?;

        insta::assert_snapshot!(flat.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
          metadata: EmptyMetadata
          buffer: buffer_0 host 29 B (align=1) (1.75%)
          buffer: buffer_1 host 28 B (align=1) (1.69%)
          buffer: views host 1.60 kB (align=16) (96.56%)
        ");

        // The same inputs wrapped in a single-field struct must still zip to a struct.
        let nested_hello = StructArray::try_from_iter([("nested", hello)])?.into_array();
        let nested_world = StructArray::try_from_iter([("nested", world)])?.into_array();

        let nested = even_positions
            .zip(nested_hello, nested_world)?
            .execute::<ArrayRef>(&mut LEGACY_SESSION.create_execution_ctx())?;
        assert!(nested.is::<StructVTable>());

        Ok(())
    }

    #[test]
    fn test_varbinview_zip() {
        let when_true = alternating_utf8(
            "Hello",
            "Hello this is a long string that won't be inlined.",
            100,
        );
        let when_false = alternating_utf8(
            "Hello2",
            "Hello2 this is a long string that won't be inlined.",
            100,
        );

        // True at 1, 2, 4, 5, 7, 8, ... below 100; false at every position from 100 on.
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());

        let zipped = mask
            .clone()
            .into_array()
            .zip(when_true.clone(), when_false.clone())
            .unwrap()
            .execute::<VarBinViewArray>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        // Expect exactly two data buffers, presumably one carried over from each input.
        assert_eq!(zipped.nbuffers(), 2);

        // Cross-check the values against arrow-rs's `zip` kernel.
        let arrow_mask = mask.into_array().into_arrow_preferred().unwrap();
        let expected = arrow_zip(
            arrow_mask.as_boolean(),
            &when_true.into_arrow_preferred().unwrap(),
            &when_false.into_arrow_preferred().unwrap(),
        )
        .unwrap();

        let actual = zipped.into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}