1use std::any::Any;
2use std::fmt::{Debug, Display};
3use std::sync::Arc;
4
5use dyn_hash::DynHash;
6
7mod binary;
8
9mod between;
10pub mod datafusion;
11mod field;
12pub mod forms;
13mod get_item;
14mod identity;
15mod like;
16mod literal;
17mod merge;
18mod not;
19mod operators;
20mod pack;
21pub mod pruning;
22#[cfg(feature = "proto")]
23mod registry;
24mod select;
25pub mod transform;
26pub mod traversal;
27
28pub use between::*;
29pub use binary::*;
30pub use get_item::*;
31pub use identity::*;
32pub use like::*;
33pub use literal::*;
34pub use merge::*;
35pub use not::*;
36pub use operators::*;
37pub use pack::*;
38#[cfg(feature = "proto")]
39pub use registry::deserialize_expr;
40pub use select::*;
41use vortex_array::aliases::hash_set::HashSet;
42use vortex_array::{Array, ArrayRef};
43use vortex_dtype::{DType, FieldName};
44use vortex_error::{VortexResult, VortexUnwrap};
45#[cfg(feature = "proto")]
46use vortex_proto::expr;
47#[cfg(feature = "proto")]
48use vortex_proto::expr::{Expr, kind};
49
50use crate::traversal::{Node, ReferenceCollector};
51
/// Shared, reference-counted handle to any expression node.
///
/// This is the type used throughout the crate for expression trees: children are returned as
/// `&ExprRef` and rebuilt trees are returned as `ExprRef`.
pub type ExprRef = Arc<dyn VortexExpr>;
53
/// Static string identifier, required of every [`ExprDeserialize`] implementation.
///
/// NOTE(review): presumably this matches the `ExprSerializable::id` written into serialized
/// expressions so the registry can pick a deserializer — confirm against `registry.rs`.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the identifier string.
    fn id(&self) -> &'static str;
}
58
/// Rebuilds an expression from its protobuf payload.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Constructs an expression from the protobuf `kind` payload and its child expressions.
    ///
    /// `children` are already-constructed [`ExprRef`]s, so implementations only decode their own
    /// node-specific data from `kind`.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
63
/// Serialization hooks every expression provides when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Identifier stored in the `id` field of the serialized [`Expr`] by
    /// `VortexExprExt::serialize`.
    fn id(&self) -> &'static str;

    /// Serializes this node's own payload into a protobuf [`kind::Kind`].
    ///
    /// Child expressions are serialized separately by `VortexExprExt::serialize` via
    /// `VortexExpr::children`.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}

/// Without the `proto` feature this is an empty marker trait implemented for every type, so the
/// `VortexExpr` supertrait bound is always satisfied.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
75
76pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable {
78 fn as_any(&self) -> &dyn Any;
80
81 fn evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
84 let result = self.unchecked_evaluate(batch)?;
85 assert_eq!(
86 result.dtype(),
87 &self.return_dtype(batch.dtype())?,
88 "Expression {} returned dtype {} but declared return_dtype of {}",
89 self,
90 result.dtype(),
91 self.return_dtype(batch.dtype())?,
92 );
93 Ok(result)
94 }
95
96 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef>;
102
103 fn children(&self) -> Vec<&ExprRef>;
104
105 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
106
107 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType>;
109}
110
/// Convenience operations on shared expression trees ([`ExprRef`]).
pub trait VortexExprExt {
    /// Returns the field names referenced by the expression tree, as gathered by a
    /// [`ReferenceCollector`] traversal.
    fn references(&self) -> HashSet<FieldName>;

    /// Serializes the expression tree into its protobuf [`Expr`] form, serializing children
    /// recursively.
    #[cfg(feature = "proto")]
    fn serialize(&self) -> VortexResult<Expr>;
}
118
119impl VortexExprExt for ExprRef {
120 fn references(&self) -> HashSet<FieldName> {
121 let mut collector = ReferenceCollector::new();
122 self.accept(&mut collector).vortex_unwrap();
124 collector.into_fields()
125 }
126
127 #[cfg(feature = "proto")]
128 fn serialize(&self) -> VortexResult<Expr> {
129 let children = self
130 .children()
131 .iter()
132 .map(|e| e.serialize())
133 .collect::<VortexResult<_>>()?;
134
135 Ok(Expr {
136 id: self.id().to_string(),
137 children,
138 kind: Some(expr::Kind {
139 kind: Some(self.serialize_kind()?),
140 }),
141 })
142 }
143}
144
145pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
147 let mut conjunctions = vec![];
148 split_inner(expr, &mut conjunctions);
149 conjunctions
150}
151
152fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
153 match expr.as_any().downcast_ref::<BinaryExpr>() {
154 Some(bexp) if bexp.op() == Operator::And => {
155 split_inner(bexp.lhs(), exprs);
156 split_inner(bexp.rhs(), exprs);
157 }
158 Some(_) | None => {
159 exprs.push(expr.clone());
160 }
161 }
162}
163
/// Object-safe equality against a value that has been type-erased to [`Any`].
pub trait DynEq {
    /// Returns `true` when `other` has the same concrete type as `self` and compares equal to it.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means the concrete types differ, which is never equal.
        other
            .downcast_ref::<T>()
            .is_some_and(|same_type| same_type == self)
    }
}
176
177impl PartialEq for dyn VortexExpr {
178 fn eq(&self, other: &Self) -> bool {
179 self.dyn_eq(other.as_any())
180 }
181}
182
183impl Eq for dyn VortexExpr {}
184
185dyn_hash::hash_trait_object!(VortexExpr);
186
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Returns the non-nullable struct dtype shared by expression tests.
    ///
    /// Fields, in order:
    /// `a: i32`, `col1: u16` (nullable), `col2: u16` (nullable), `bool1: bool`, `bool2: bool`.
    pub fn struct_dtype() -> DType {
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = DType::Bool(Nullability::NonNullable);

        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16.clone(),
            nullable_u16,
            required_bool.clone(),
            required_bool,
        ];

        let fields = StructDType::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            field_dtypes,
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
216
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use super::*;

    /// Renders each conjunct with `Display` so order and content can be asserted on.
    fn rendered(exprs: &[ExprRef]) -> Vec<String> {
        exprs.iter().map(|e| e.to_string()).collect()
    }

    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", ident());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", ident());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    #[test]
    fn nested_conjunction_split_preserves_order() {
        // Both operands of the outer AND are themselves ANDs; all four leaves come out in
        // left-to-right order.
        let expr = and(and(col("a"), col("col1")), and(col("col2"), col("bool1")));
        let conjunction = split_conjunction(&expr);
        assert_eq!(
            rendered(&conjunction),
            vec!["$.a", "$.col1", "$.col2", "$.bool1"],
            "Conjunction is {conjunction:?}"
        );
    }

    #[test]
    fn disjunction_inside_conjunction_is_not_split() {
        let expr = and(or(col("a"), col("col1")), col("col2"));
        let conjunction = split_conjunction(&expr);
        assert_eq!(
            rendered(&conjunction),
            vec!["($.a or $.col1)", "$.col2"],
            "Conjunction is {conjunction:?}"
        );
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(Identity.to_string(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        assert_eq!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            )
            .to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1.clone()).to_string(), "!$.col1");

        assert_eq!(
            select(vec![FieldName::from("col1")], ident()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident()
            )
            .to_string(),
            "$~{col1, col2}"
        );

        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    Arc::new(StructDType::new(
                        Arc::from([Arc::from("dog"), Arc::from("cat")]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    )),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }

    #[cfg(feature = "proto")]
    mod tests_proto {

        use crate::{VortexExprExt, deserialize_expr, eq, ident, lit};

        #[test]
        fn round_trip_serde() {
            let expr = eq(ident(), lit(1));
            let res = expr.serialize().unwrap();
            let final_ = deserialize_expr(&res).unwrap();

            assert_eq!(&expr, &final_);
        }
    }
}