Skip to main content

vortex_array/scalar_fn/fns/zip/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_err;
12use vortex_mask::AllOr;
13use vortex_mask::Mask;
14use vortex_session::VortexSession;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::IntoArray;
19use crate::builders::ArrayBuilder;
20use crate::builders::builder_with_capacity;
21use crate::builtins::ArrayBuiltins;
22use crate::dtype::DType;
23use crate::expr::Expression;
24use crate::scalar_fn::Arity;
25use crate::scalar_fn::ChildName;
26use crate::scalar_fn::EmptyOptions;
27use crate::scalar_fn::ExecutionArgs;
28use crate::scalar_fn::ScalarFnId;
29use crate::scalar_fn::ScalarFnVTable;
30use crate::scalar_fn::SimplifyCtx;
31use crate::scalar_fn::fns::literal::Literal;
32
/// Scalar function that picks, element by element, between two arrays using a boolean mask.
///
/// At every index `i` the output is `if_true[i]` when `mask[i]` is true and `if_false[i]`
/// otherwise.
///
/// A null entry in the mask counts as false, so the value is taken from `if_false`. This
/// matches the SQL behaviour of DuckDB and Trino, where a null condition takes the ELSE
/// branch; Arrow's `if_else`, by contrast, propagates a null condition into the output.
#[derive(Clone)]
pub struct Zip;
42
43impl ScalarFnVTable for Zip {
44    type Options = EmptyOptions;
45
46    fn id(&self) -> ScalarFnId {
47        ScalarFnId::from("vortex.zip")
48    }
49
50    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51        Ok(Some(vec![]))
52    }
53
54    fn deserialize(
55        &self,
56        _metadata: &[u8],
57        _session: &VortexSession,
58    ) -> VortexResult<Self::Options> {
59        Ok(EmptyOptions)
60    }
61
62    fn arity(&self, _options: &Self::Options) -> Arity {
63        Arity::Exact(3)
64    }
65
66    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
67        match child_idx {
68            0 => ChildName::from("if_true"),
69            1 => ChildName::from("if_false"),
70            2 => ChildName::from("mask"),
71            _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
72        }
73    }
74
75    fn fmt_sql(
76        &self,
77        _options: &Self::Options,
78        expr: &Expression,
79        f: &mut Formatter<'_>,
80    ) -> std::fmt::Result {
81        write!(f, "zip(")?;
82        expr.child(0).fmt_sql(f)?;
83        write!(f, ", ")?;
84        expr.child(1).fmt_sql(f)?;
85        write!(f, ", ")?;
86        expr.child(2).fmt_sql(f)?;
87        write!(f, ")")
88    }
89
90    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
91        vortex_ensure!(
92            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
93            "zip requires if_true and if_false to have the same base type, got {} and {}",
94            arg_dtypes[0],
95            arg_dtypes[1]
96        );
97        vortex_ensure!(
98            matches!(arg_dtypes[2], DType::Bool(_)),
99            "zip requires mask to be a boolean type, got {}",
100            arg_dtypes[2]
101        );
102        Ok(arg_dtypes[0]
103            .clone()
104            .union_nullability(arg_dtypes[1].nullability()))
105    }
106
107    fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
108        let [if_true, if_false, mask_array]: [ArrayRef; _] = args
109            .inputs
110            .try_into()
111            .map_err(|_| vortex_err!("Wrong arg count"))?;
112
113        let mask = mask_array.try_to_mask_fill_null_false()?;
114
115        let return_dtype = if_true
116            .dtype()
117            .clone()
118            .union_nullability(if_false.dtype().nullability());
119
120        if mask.all_true() {
121            return if_true.cast(return_dtype)?.execute(args.ctx);
122        }
123
124        let return_dtype = if_true
125            .dtype()
126            .clone()
127            .union_nullability(if_false.dtype().nullability());
128
129        if mask.all_false() {
130            return if_false.cast(return_dtype)?.execute(args.ctx);
131        }
132
133        if !if_true.is_canonical() || !if_false.is_canonical() {
134            let if_true = if_true.execute::<ArrayRef>(args.ctx)?;
135            let if_false = if_false.execute::<ArrayRef>(args.ctx)?;
136            return if_true.zip(if_false, mask.into_array());
137        }
138
139        zip_impl(&if_true, &if_false, &mask)
140    }
141
142    fn simplify(
143        &self,
144        _options: &Self::Options,
145        expr: &Expression,
146        _ctx: &dyn SimplifyCtx,
147    ) -> VortexResult<Option<Expression>> {
148        let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
149            return Ok(None);
150        };
151
152        if let Some(mask_val) = mask_lit.as_bool().value() {
153            if mask_val {
154                return Ok(Some(expr.child(0).clone()));
155            } else {
156                return Ok(Some(expr.child(1).clone()));
157            }
158        }
159
160        Ok(None)
161    }
162
163    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
164        true
165    }
166
167    fn is_fallible(&self, _options: &Self::Options) -> bool {
168        false
169    }
170}
171
172pub(crate) fn zip_impl(
173    if_true: &dyn Array,
174    if_false: &dyn Array,
175    mask: &Mask,
176) -> VortexResult<ArrayRef> {
177    assert_eq!(
178        if_true.len(),
179        if_false.len(),
180        "zip requires arrays to have the same size"
181    );
182
183    let return_type = if_true
184        .dtype()
185        .clone()
186        .union_nullability(if_false.dtype().nullability());
187    zip_impl_with_builder(
188        if_true,
189        if_false,
190        mask,
191        builder_with_capacity(&return_type, if_true.len()),
192    )
193}
194
195fn zip_impl_with_builder(
196    if_true: &dyn Array,
197    if_false: &dyn Array,
198    mask: &Mask,
199    mut builder: Box<dyn ArrayBuilder>,
200) -> VortexResult<ArrayRef> {
201    match mask.slices() {
202        AllOr::All => Ok(if_true.to_array()),
203        AllOr::None => Ok(if_false.to_array()),
204        AllOr::Some(slices) => {
205            for (start, end) in slices {
206                builder.extend_from_array(&if_false.slice(builder.len()..*start)?);
207                builder.extend_from_array(&if_true.slice(*start..*end)?);
208            }
209            if builder.len() < if_false.len() {
210                builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?);
211            }
212            Ok(builder.finish())
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewVTable;
    use crate::arrow::IntoArrowArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::expr::zip_expr;
    use crate::scalar::Scalar;

    /// Zipping two non-nullable i32 inputs yields a non-nullable i32 dtype.
    #[test]
    fn dtype() {
        let scope = DType::Primitive(PType::I32, Nullability::NonNullable);
        let zipped = zip_expr(root(), lit(0i32), lit(true));

        let actual = zipped.return_dtype(&scope).unwrap();
        let expected = DType::Primitive(PType::I32, Nullability::NonNullable);

        assert_eq!(actual, expected);
    }

    /// The expression's string form lists its three children inside `zip(...)`.
    #[test]
    fn test_display() {
        let rendered = zip_expr(root(), lit(0i32), lit(true)).to_string();
        assert_eq!(rendered, "zip($, 0i32, true)");
    }

    /// Set mask positions take values from `if_true`, the rest from `if_false`.
    #[test]
    fn test_zip_basic() {
        let selector = Mask::from_iter([true, false, false, true, false]).into_array();
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();

        let result = if_true.zip(if_false, selector).unwrap();
        let expected = buffer![10, 2, 3, 40, 5].into_array();

        assert_arrays_eq!(result, expected);
    }

    /// An all-true mask returns the `if_true` values with the unioned nullability.
    #[test]
    fn test_zip_all_true() {
        let selector = Mask::new_true(4).into_array();
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();

        let result = if_true.zip(if_false.clone(), selector).unwrap();
        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();

        assert_arrays_eq!(result, expected);

        // result must be nullable even if_true was not
        assert_eq!(result.dtype(), if_false.dtype());
    }

    /// Inputs of different lengths must not zip successfully.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let selector = Mask::new_false(4).into_array();
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();

        let _result = if_true.zip(if_false, selector).unwrap();
    }

    /// Alternating selection between two constant string arrays must not fragment the
    /// resulting buffers, both directly and when nested inside a struct.
    #[test]
    fn test_fragmentation() {
        let len = 100;

        let constant_utf8 = |text: &str| {
            ConstantArray::new(Scalar::utf8(text, Nullability::Nullable), len).to_array()
        };
        let const1 = constant_utf8("hello_this_is_a_longer_string");
        let const2 = constant_utf8("world_this_is_another_string");

        // Select every even position.
        let even_positions: Vec<usize> = (0..len).step_by(2).collect();
        let mask_array = Mask::from_indices(len, even_positions).into_array();

        let result = const1.zip(const2.clone(), mask_array.clone()).unwrap();

        insta::assert_snapshot!(result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
          metadata: EmptyMetadata
          buffer: buffer_0 host 29 B (align=1) (1.75%)
          buffer: buffer_1 host 28 B (align=1) (1.69%)
          buffer: views host 1.60 kB (align=16) (96.56%)
        ");

        // test wrapped in a struct
        let wrapped1 = StructArray::try_from_iter([("nested", const1)])
            .unwrap()
            .to_array();
        let wrapped2 = StructArray::try_from_iter([("nested", const2)])
            .unwrap()
            .to_array();

        let wrapped_result = wrapped1.zip(wrapped2, mask_array).unwrap();
        insta::assert_snapshot!(wrapped_result.display_tree(), @r"
        root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
            metadata: EmptyMetadata
            buffer: buffer_0 host 29 B (align=1) (1.75%)
            buffer: buffer_1 host 28 B (align=1) (1.69%)
            buffer: views host 1.60 kB (align=16) (96.56%)
        ");
    }

    /// Zipping two string-view arrays yields two data buffers and agrees with Arrow's `zip`.
    #[test]
    fn test_varbinview_zip() {
        // Builds 200 strings that alternate between a short and a long value.
        let alternating_strings = |short: &str, long: &str| {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value(short);
                builder.append_value(long);
            }
            builder.finish()
        };

        let if_true = alternating_strings(
            "Hello",
            "Hello this is a long string that won't be inlined.",
        );
        let if_false = alternating_strings(
            "Hello2",
            "Hello2 this is a long string that won't be inlined.",
        );

        // [1,2,4,5,7,8,..]
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());
        let mask_array = mask.clone().into_array();

        let zipped = if_true.zip(if_false.clone(), mask_array).unwrap();
        let zipped = zipped.as_opt::<VarBinViewVTable>().unwrap();
        assert_eq!(zipped.nbuffers(), 2);

        // assert the result is the same as arrow
        let arrow_mask = mask.into_array().into_arrow_preferred().unwrap();
        let arrow_true = if_true.into_arrow_preferred().unwrap();
        let arrow_false = if_false.into_arrow_preferred().unwrap();
        let expected = arrow_zip(arrow_mask.as_boolean(), &arrow_true, &arrow_false).unwrap();

        let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}
404}