Skip to main content

vortex_array/expr/exprs/zip/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_dtype::DType;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::builtins::ArrayBuiltins;
17use crate::compute::zip_impl;
18use crate::compute::zip_return_dtype;
19use crate::expr::Arity;
20use crate::expr::ChildName;
21use crate::expr::EmptyOptions;
22use crate::expr::ExecutionArgs;
23use crate::expr::ExprId;
24use crate::expr::Expression;
25use crate::expr::Literal;
26use crate::expr::SimplifyCtx;
27use crate::expr::VTable;
28use crate::expr::VTableExt;
29
30/// An expression that conditionally selects between two arrays based on a boolean mask.
31///
32/// For each position `i`, `result[i] = if mask[i] then if_true[i] else if_false[i]`.
33///
34/// Null values in the mask are treated as false (selecting `if_false`). This follows
35/// SQL semantics (DuckDB, Trino) where a null condition falls through to the ELSE branch,
36/// rather than Arrow's `if_else` which propagates null conditions to the output.
37pub struct Zip;
38
39impl VTable for Zip {
40    type Options = EmptyOptions;
41
42    fn id(&self) -> ExprId {
43        ExprId::from("vortex.zip")
44    }
45
46    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
47        Ok(Some(vec![]))
48    }
49
50    fn deserialize(
51        &self,
52        _metadata: &[u8],
53        _session: &VortexSession,
54    ) -> VortexResult<Self::Options> {
55        Ok(EmptyOptions)
56    }
57
58    fn arity(&self, _options: &Self::Options) -> Arity {
59        Arity::Exact(3)
60    }
61
62    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
63        match child_idx {
64            0 => ChildName::from("if_true"),
65            1 => ChildName::from("if_false"),
66            2 => ChildName::from("mask"),
67            _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
68        }
69    }
70
71    fn fmt_sql(
72        &self,
73        _options: &Self::Options,
74        expr: &Expression,
75        f: &mut Formatter<'_>,
76    ) -> std::fmt::Result {
77        write!(f, "zip(")?;
78        expr.child(0).fmt_sql(f)?;
79        write!(f, ", ")?;
80        expr.child(1).fmt_sql(f)?;
81        write!(f, ", ")?;
82        expr.child(2).fmt_sql(f)?;
83        write!(f, ")")
84    }
85
86    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
87        vortex_ensure!(
88            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
89            "zip requires if_true and if_false to have the same base type, got {} and {}",
90            arg_dtypes[0],
91            arg_dtypes[1]
92        );
93        vortex_ensure!(
94            matches!(arg_dtypes[2], DType::Bool(_)),
95            "zip requires mask to be a boolean type, got {}",
96            arg_dtypes[2]
97        );
98        Ok(arg_dtypes[0]
99            .clone()
100            .union_nullability(arg_dtypes[1].nullability()))
101    }
102
103    fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
104        let [if_true, if_false, mask_array]: [ArrayRef; _] = args
105            .inputs
106            .try_into()
107            .map_err(|_| vortex_err!("Wrong arg count"))?;
108
109        let mask = mask_array.try_to_mask_fill_null_false()?;
110
111        if mask.all_true() {
112            return if_true
113                .cast(zip_return_dtype(&if_true, &if_false))?
114                .execute(args.ctx);
115        }
116
117        if mask.all_false() {
118            return if_false
119                .cast(zip_return_dtype(&if_true, &if_false))?
120                .execute(args.ctx);
121        }
122
123        if !if_true.is_canonical() || !if_false.is_canonical() {
124            let if_true = if_true.execute::<ArrayRef>(args.ctx)?;
125            let if_false = if_false.execute::<ArrayRef>(args.ctx)?;
126            return crate::compute::zip(&if_true, &if_false, &mask);
127        }
128
129        zip_impl(&if_true, &if_false, &mask)
130    }
131
132    fn simplify(
133        &self,
134        _options: &Self::Options,
135        expr: &Expression,
136        _ctx: &dyn SimplifyCtx,
137    ) -> VortexResult<Option<Expression>> {
138        let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
139            return Ok(None);
140        };
141
142        if let Some(mask_val) = mask_lit.as_bool().value() {
143            if mask_val {
144                return Ok(Some(expr.child(0).clone()));
145            } else {
146                return Ok(Some(expr.child(1).clone()));
147            }
148        }
149
150        Ok(None)
151    }
152
153    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
154        true
155    }
156
157    fn is_fallible(&self, _options: &Self::Options) -> bool {
158        false
159    }
160}
161
162/// Creates a zip expression that conditionally selects between two arrays.
163///
164/// ```rust
165/// # use vortex_array::expr::{zip_expr, root, lit};
166/// let expr = zip_expr(root(), lit(0i32), lit(true));
167/// ```
168pub fn zip_expr(if_true: Expression, if_false: Expression, mask: Expression) -> Expression {
169    Zip.new_expr(EmptyOptions, [if_true, if_false, mask])
170}
171
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::zip_expr;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    #[test]
    fn dtype() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expr = zip_expr(root(), lit(0i32), lit(true));
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    /// The result is nullable whenever either branch is nullable.
    #[test]
    fn dtype_unions_branch_nullability() {
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
        let expected = DType::Primitive(PType::I32, Nullability::Nullable);

        // Nullable `if_true`, non-nullable `if_false`.
        let expr = zip_expr(root(), lit(0i32), lit(true));
        assert_eq!(expr.return_dtype(&nullable_i32).unwrap(), expected);

        // Non-nullable `if_true`, nullable `if_false`.
        let expr = zip_expr(lit(0i32), root(), lit(true));
        assert_eq!(expr.return_dtype(&nullable_i32).unwrap(), expected);
    }

    #[test]
    fn dtype_rejects_mismatched_branches() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expr = zip_expr(root(), lit(0i64), lit(true));
        assert!(expr.return_dtype(&dtype).is_err());
    }

    #[test]
    fn dtype_rejects_non_bool_mask() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expr = zip_expr(root(), lit(0i32), lit(1i32));
        assert!(expr.return_dtype(&dtype).is_err());
    }

    #[test]
    fn test_display() {
        let expr = zip_expr(root(), lit(0i32), lit(true));
        assert_eq!(expr.to_string(), "zip($, 0i32, true)");
    }
}