vortex_array/expr/exprs/zip/
mod.rs1mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_dtype::DType;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::builtins::ArrayBuiltins;
17use crate::compute::zip_impl;
18use crate::compute::zip_return_dtype;
19use crate::expr::Arity;
20use crate::expr::ChildName;
21use crate::expr::EmptyOptions;
22use crate::expr::ExecutionArgs;
23use crate::expr::ExprId;
24use crate::expr::Expression;
25use crate::expr::Literal;
26use crate::expr::SimplifyCtx;
27use crate::expr::VTable;
28use crate::expr::VTableExt;
29
/// Expression that selects between two inputs according to a boolean mask.
///
/// Children, in order: `if_true`, `if_false`, `mask`. The two value children must share a
/// base type (nullability may differ); the result takes that base type with the union of
/// both children's nullability. When executed, null mask entries are treated as `false`,
/// i.e. they select `if_false`.
pub struct Zip;
38
39impl VTable for Zip {
40 type Options = EmptyOptions;
41
42 fn id(&self) -> ExprId {
43 ExprId::from("vortex.zip")
44 }
45
46 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
47 Ok(Some(vec![]))
48 }
49
50 fn deserialize(
51 &self,
52 _metadata: &[u8],
53 _session: &VortexSession,
54 ) -> VortexResult<Self::Options> {
55 Ok(EmptyOptions)
56 }
57
58 fn arity(&self, _options: &Self::Options) -> Arity {
59 Arity::Exact(3)
60 }
61
62 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
63 match child_idx {
64 0 => ChildName::from("if_true"),
65 1 => ChildName::from("if_false"),
66 2 => ChildName::from("mask"),
67 _ => unreachable!("Invalid child index {} for Zip expression", child_idx),
68 }
69 }
70
71 fn fmt_sql(
72 &self,
73 _options: &Self::Options,
74 expr: &Expression,
75 f: &mut Formatter<'_>,
76 ) -> std::fmt::Result {
77 write!(f, "zip(")?;
78 expr.child(0).fmt_sql(f)?;
79 write!(f, ", ")?;
80 expr.child(1).fmt_sql(f)?;
81 write!(f, ", ")?;
82 expr.child(2).fmt_sql(f)?;
83 write!(f, ")")
84 }
85
86 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
87 vortex_ensure!(
88 arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
89 "zip requires if_true and if_false to have the same base type, got {} and {}",
90 arg_dtypes[0],
91 arg_dtypes[1]
92 );
93 vortex_ensure!(
94 matches!(arg_dtypes[2], DType::Bool(_)),
95 "zip requires mask to be a boolean type, got {}",
96 arg_dtypes[2]
97 );
98 Ok(arg_dtypes[0]
99 .clone()
100 .union_nullability(arg_dtypes[1].nullability()))
101 }
102
103 fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
104 let [if_true, if_false, mask_array]: [ArrayRef; _] = args
105 .inputs
106 .try_into()
107 .map_err(|_| vortex_err!("Wrong arg count"))?;
108
109 let mask = mask_array.try_to_mask_fill_null_false()?;
110
111 if mask.all_true() {
112 return if_true
113 .cast(zip_return_dtype(&if_true, &if_false))?
114 .execute(args.ctx);
115 }
116
117 if mask.all_false() {
118 return if_false
119 .cast(zip_return_dtype(&if_true, &if_false))?
120 .execute(args.ctx);
121 }
122
123 if !if_true.is_canonical() || !if_false.is_canonical() {
124 let if_true = if_true.execute::<ArrayRef>(args.ctx)?;
125 let if_false = if_false.execute::<ArrayRef>(args.ctx)?;
126 return crate::compute::zip(&if_true, &if_false, &mask);
127 }
128
129 zip_impl(&if_true, &if_false, &mask)
130 }
131
132 fn simplify(
133 &self,
134 _options: &Self::Options,
135 expr: &Expression,
136 _ctx: &dyn SimplifyCtx,
137 ) -> VortexResult<Option<Expression>> {
138 let Some(mask_lit) = expr.child(2).as_opt::<Literal>() else {
139 return Ok(None);
140 };
141
142 if let Some(mask_val) = mask_lit.as_bool().value() {
143 if mask_val {
144 return Ok(Some(expr.child(0).clone()));
145 } else {
146 return Ok(Some(expr.child(1).clone()));
147 }
148 }
149
150 Ok(None)
151 }
152
153 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
154 true
155 }
156
157 fn is_fallible(&self, _options: &Self::Options) -> bool {
158 false
159 }
160}
161
162pub fn zip_expr(if_true: Expression, if_false: Expression, mask: Expression) -> Expression {
169 Zip.new_expr(EmptyOptions, [if_true, if_false, mask])
170}
171
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::zip_expr;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    /// Non-nullable `i32`: used as the scope dtype and as the expected result dtype.
    fn non_nullable_i32() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    #[test]
    fn dtype() {
        let expr = zip_expr(root(), lit(0i32), lit(true));
        let result_dtype = expr.return_dtype(&non_nullable_i32()).unwrap();
        assert_eq!(result_dtype, non_nullable_i32());
    }

    /// `return_dtype` must reject a mask whose dtype is not boolean.
    #[test]
    fn mask_must_be_boolean() {
        let expr = zip_expr(root(), lit(0i32), lit(1i32));
        assert!(expr.return_dtype(&non_nullable_i32()).is_err());
    }

    /// `return_dtype` must reject branches whose base types differ.
    #[test]
    fn branches_must_share_base_type() {
        let expr = zip_expr(root(), lit(0i64), lit(true));
        assert!(expr.return_dtype(&non_nullable_i32()).is_err());
    }

    #[test]
    fn test_display() {
        let expr = zip_expr(root(), lit(0i32), lit(true));
        assert_eq!(expr.to_string(), "zip($, 0i32, true)");
    }
}