1use std::any::Any;
2use std::fmt::{Debug, Display};
3use std::sync::Arc;
4
5use dyn_hash::DynHash;
6
7mod binary;
8
9mod between;
10pub mod datafusion;
11mod field;
12pub mod forms;
13mod get_item;
14mod identity;
15mod is_null;
16mod like;
17mod literal;
18mod merge;
19mod not;
20mod operators;
21mod pack;
22pub mod pruning;
23#[cfg(feature = "proto")]
24mod registry;
25mod select;
26pub mod transform;
27pub mod traversal;
28
29pub use between::*;
30pub use binary::*;
31pub use get_item::*;
32pub use identity::*;
33pub use is_null::*;
34pub use like::*;
35pub use literal::*;
36pub use merge::*;
37pub use not::*;
38pub use operators::*;
39pub use pack::*;
40#[cfg(feature = "proto")]
41pub use registry::deserialize_expr;
42pub use select::*;
43use vortex_array::aliases::hash_set::HashSet;
44use vortex_array::{Array, ArrayRef};
45use vortex_dtype::{DType, FieldName};
46use vortex_error::{VortexResult, VortexUnwrap};
47#[cfg(feature = "proto")]
48use vortex_proto::expr;
49#[cfg(feature = "proto")]
50use vortex_proto::expr::{Expr, kind};
51
52use crate::traversal::{Node, ReferenceCollector};
53
54pub type ExprRef = Arc<dyn VortexExpr>;
55
/// Gives an implementor a static string identifier.
///
/// Only compiled when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the identifier as a `'static` string slice.
    fn id(&self) -> &'static str;
}
60
/// Builds an [`ExprRef`] from a serialized expression kind and its child expressions.
///
/// Implementors must also be [`Id`] and `Sync`. Only compiled when the `proto`
/// feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Constructs an expression from `kind`, using `children` as its child expressions.
    ///
    /// # Errors
    ///
    /// Returns an error through [`VortexResult`] when the implementor cannot build
    /// the expression.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
65
/// Serialization hooks every [`VortexExpr`] must provide when the `proto` feature is on.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Static identifier of the expression; `VortexExprExt::serialize` stores it in the
    /// `id` field of the serialized expression.
    fn id(&self) -> &'static str;

    /// Produces the [`kind::Kind`] payload describing this expression.
    ///
    /// # Errors
    ///
    /// Returns an error through [`VortexResult`] when the payload cannot be produced.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}

/// Without the `proto` feature this is an empty marker trait, so it can stay in the
/// supertrait list of [`VortexExpr`] without requiring anything from implementors.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}

/// Blanket implementation: every (sized) type satisfies the empty marker trait.
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
77
78pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable {
80 fn as_any(&self) -> &dyn Any;
82
83 fn evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
86 let result = self.unchecked_evaluate(batch)?;
87 assert_eq!(
88 result.dtype(),
89 &self.return_dtype(batch.dtype())?,
90 "Expression {} returned dtype {} but declared return_dtype of {}",
91 self,
92 result.dtype(),
93 self.return_dtype(batch.dtype())?,
94 );
95 Ok(result)
96 }
97
98 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef>;
104
105 fn children(&self) -> Vec<&ExprRef>;
106
107 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
108
109 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType>;
111}
112
113pub trait VortexExprExt {
114 fn references(&self) -> HashSet<FieldName>;
116
117 #[cfg(feature = "proto")]
118 fn serialize(&self) -> VortexResult<Expr>;
119}
120
121impl VortexExprExt for ExprRef {
122 fn references(&self) -> HashSet<FieldName> {
123 let mut collector = ReferenceCollector::new();
124 self.accept(&mut collector).vortex_unwrap();
126 collector.into_fields()
127 }
128
129 #[cfg(feature = "proto")]
130 fn serialize(&self) -> VortexResult<Expr> {
131 let children = self
132 .children()
133 .iter()
134 .map(|e| e.serialize())
135 .collect::<VortexResult<_>>()?;
136
137 Ok(Expr {
138 id: self.id().to_string(),
139 children,
140 kind: Some(expr::Kind {
141 kind: Some(self.serialize_kind()?),
142 }),
143 })
144 }
145}
146
147pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
149 let mut conjunctions = vec![];
150 split_inner(expr, &mut conjunctions);
151 conjunctions
152}
153
154fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
155 match expr.as_any().downcast_ref::<BinaryExpr>() {
156 Some(bexp) if bexp.op() == Operator::And => {
157 split_inner(bexp.lhs(), exprs);
158 split_inner(bexp.rhs(), exprs);
159 }
160 Some(_) | None => {
161 exprs.push(expr.clone());
162 }
163 }
164}
165
/// Object-safe equality: compares `self` with a value that is only known as [`Any`].
pub trait DynEq {
    /// Returns `true` when `other` is considered equal to `self`.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

/// Every `Eq + Any` type gets [`DynEq`]: values are equal only when `other` has the same
/// concrete type as `self` and compares equal to it.
impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<Self>() {
            Some(same_type) => same_type == self,
            // A different concrete type is never equal.
            None => false,
        }
    }
}
178
179impl PartialEq for dyn VortexExpr {
180 fn eq(&self, other: &Self) -> bool {
181 self.dyn_eq(other.as_any())
182 }
183}
184
185impl Eq for dyn VortexExpr {}
186
187dyn_hash::hash_trait_object!(VortexExpr);
188
/// Helpers for tests; compiled only with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Builds a fixed, non-nullable struct dtype with these fields:
    ///
    /// - `a`: non-nullable `i32`
    /// - `col1`, `col2`: nullable `u16`
    /// - `bool1`, `bool2`: non-nullable `bool`
    pub fn struct_dtype() -> DType {
        // Two dtypes appear twice, so build each once and clone it.
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = DType::Bool(Nullability::NonNullable);

        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16.clone(),
            nullable_u16,
            required_bool.clone(),
            required_bool,
        ];

        let fields = StructDType::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            field_dtypes,
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
218
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use super::*;

    /// Asserts that the `to_string()` rendering of the first argument equals the second.
    macro_rules! assert_renders {
        ($expr:expr, $expected:expr $(,)?) => {
            assert_eq!($expr.to_string(), $expected)
        };
    }

    /// An expression without any `and` stays in one piece.
    #[test]
    fn basic_expr_split_test() {
        let predicate = eq(get_item("col1", ident()), lit(1));
        assert_eq!(split_conjunction(&predicate).len(), 1);
    }

    /// A single `and` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let both = and(get_item("col1", ident()), lit(1));
        let conjunction = split_conjunction(&both);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` output of the expression constructors.
    #[test]
    fn expr_display() {
        // Field access and identity.
        assert_renders!(col("a"), "$.a");
        assert_renders!(Identity, "$");

        // Binary operators.
        let first: ExprRef = col("col1");
        let second: ExprRef = col("col2");
        assert_renders!(and(first.clone(), second.clone()), "($.col1 and $.col2)");
        assert_renders!(or(first.clone(), second.clone()), "($.col1 or $.col2)");
        assert_renders!(eq(first.clone(), second.clone()), "($.col1 = $.col2)");
        assert_renders!(not_eq(first.clone(), second.clone()), "($.col1 != $.col2)");
        assert_renders!(gt(first.clone(), second.clone()), "($.col1 > $.col2)");
        assert_renders!(gt_eq(first.clone(), second.clone()), "($.col1 >= $.col2)");
        assert_renders!(lt(first.clone(), second.clone()), "($.col1 < $.col2)");
        assert_renders!(lt_eq(first.clone(), second.clone()), "($.col1 <= $.col2)");

        // Nested binary expressions keep their own parentheses.
        assert_renders!(
            or(
                lt(first.clone(), second.clone()),
                not_eq(first.clone(), second.clone()),
            ),
            "(($.col1 < $.col2) or ($.col1 != $.col2))",
        );

        assert_renders!(not(first.clone()), "!$.col1");

        // Field selection and exclusion.
        assert_renders!(select(vec![FieldName::from("col1")], ident()), "${col1}");
        assert_renders!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident()
            ),
            "${col1, col2}",
        );
        assert_renders!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident()
            ),
            "$~{col1, col2}",
        );

        // Scalar literals.
        assert_renders!(lit(Scalar::from(0u8)), "0u8");
        assert_renders!(lit(Scalar::from(0.0f32)), "0f32");
        assert_renders!(lit(Scalar::from(i64::MAX)), "9223372036854775807i64");
        assert_renders!(lit(Scalar::from(true)), "true");
        assert_renders!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))),
            "null",
        );

        // Struct literal.
        let pet_dtype = DType::Struct(
            Arc::new(StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let pet_values = vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())];
        assert_renders!(
            lit(Scalar::struct_(pet_dtype, pet_values)),
            "{dog: 32u32, cat: \"rufus\"}",
        );
    }

    #[cfg(feature = "proto")]
    mod tests_proto {

        use crate::{VortexExprExt, deserialize_expr, eq, ident, lit};

        /// Serializing and then deserializing an expression yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let original = eq(ident(), lit(1));
            let encoded = original.serialize().unwrap();
            let decoded = deserialize_expr(&encoded).unwrap();

            assert_eq!(&original, &decoded);
        }
    }
}