1extern crate core;
2
3use std::any::Any;
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6
7use dyn_hash::DynHash;
8
9mod binary;
10
11mod between;
12pub mod datafusion;
13mod field;
14pub mod forms;
15mod get_item;
16mod identity;
17mod like;
18mod literal;
19mod merge;
20mod not;
21mod operators;
22mod pack;
23pub mod pruning;
24#[cfg(feature = "proto")]
25mod registry;
26mod select;
27pub mod transform;
28pub mod traversal;
29
30pub use binary::*;
31pub use get_item::*;
32pub use identity::*;
33pub use like::*;
34pub use literal::*;
35pub use merge::*;
36pub use not::*;
37pub use operators::*;
38pub use pack::*;
39#[cfg(feature = "proto")]
40pub use registry::deserialize_expr;
41pub use select::*;
42use vortex_array::aliases::hash_set::HashSet;
43use vortex_array::{Array, ArrayRef};
44use vortex_dtype::{DType, FieldName};
45use vortex_error::{VortexResult, VortexUnwrap};
46#[cfg(feature = "proto")]
47use vortex_proto::expr;
48#[cfg(feature = "proto")]
49use vortex_proto::expr::{Expr, kind};
50
51use crate::traversal::{Node, ReferenceCollector};
52
/// Shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
54
/// Provides a static string identifier.
///
/// Required as a supertrait of [`ExprDeserialize`]. Presumably the id is what the
/// deserialization registry matches against a serialized expression's `id` field.
/// NOTE(review): confirm against the `registry` module.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns this item's identifier.
    fn id(&self) -> &'static str;
}
59
/// Reconstructs an expression node from its serialized protobuf form.
///
/// Implementors carry an [`Id`] and must be `Sync`.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Builds an expression from the node's serialized `kind` payload and its child
    /// expressions.
    ///
    /// # Errors
    /// Returns an error if the payload cannot be turned into an expression.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
64
/// Serialization hooks that every [`VortexExpr`] must provide when the `proto` feature is on.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Identifier written into the serialized `Expr::id` field.
    fn id(&self) -> &'static str;

    /// Serializes this node's own payload into a protobuf `Kind`.
    ///
    /// Children are serialized separately by [`VortexExprExt::serialize`].
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}

/// Without the `proto` feature this is an empty marker trait.
///
/// It is implemented for every type below, so the `VortexExpr` supertrait bound is always
/// satisfied.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
76
77pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable {
79 fn as_any(&self) -> &dyn Any;
81
82 fn evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
85 let result = self.unchecked_evaluate(batch)?;
86 assert_eq!(
87 result.dtype(),
88 &self.return_dtype(batch.dtype())?,
89 "Expression {} returned dtype {} but declared return_dtype of {}",
90 self,
91 result.dtype(),
92 self.return_dtype(batch.dtype())?,
93 );
94 Ok(result)
95 }
96
97 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef>;
103
104 fn children(&self) -> Vec<&ExprRef>;
105
106 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
107
108 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType>;
110}
111
/// Whole-tree operations on an [`ExprRef`].
pub trait VortexExprExt {
    /// Returns the set of field names referenced by the expression tree.
    ///
    /// The names are those gathered by the crate's `ReferenceCollector`.
    fn references(&self) -> HashSet<FieldName>;

    /// Serializes the whole expression tree, including children, into its protobuf form.
    #[cfg(feature = "proto")]
    fn serialize(&self) -> VortexResult<Expr>;
}
119
120impl VortexExprExt for ExprRef {
121 fn references(&self) -> HashSet<FieldName> {
122 let mut collector = ReferenceCollector::new();
123 self.accept(&mut collector).vortex_unwrap();
125 collector.into_fields()
126 }
127
128 #[cfg(feature = "proto")]
129 fn serialize(&self) -> VortexResult<Expr> {
130 let children = self
131 .children()
132 .iter()
133 .map(|e| e.serialize())
134 .collect::<VortexResult<_>>()?;
135
136 Ok(Expr {
137 id: self.id().to_string(),
138 children,
139 kind: Some(expr::Kind {
140 kind: Some(self.serialize_kind()?),
141 }),
142 })
143 }
144}
145
146pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
148 let mut conjunctions = vec![];
149 split_inner(expr, &mut conjunctions);
150 conjunctions
151}
152
153fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
154 match expr.as_any().downcast_ref::<BinaryExpr>() {
155 Some(bexp) if bexp.op() == Operator::And => {
156 split_inner(bexp.lhs(), exprs);
157 split_inner(bexp.rhs(), exprs);
158 }
159 Some(_) | None => {
160 exprs.push(expr.clone());
161 }
162 }
163}
164
/// Object-safe equality, usable through trait objects.
///
/// Every `Eq + Any` type gets this via the blanket implementation that follows.
pub trait DynEq {
    /// Returns `true` when `other` has the same concrete type as `self` and compares equal.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means the concrete types differ, which is never equal.
        match other.downcast_ref::<T>() {
            Some(candidate) => candidate == self,
            None => false,
        }
    }
}
177
/// Equality between expression trait objects.
///
/// Calls [`DynEq::dyn_eq`] on `self`, passing `other` as `&dyn Any`. With the blanket
/// `DynEq` impl, two expressions are equal only if they have the same concrete type and that
/// type's `Eq` implementation says they are equal.
impl PartialEq for dyn VortexExpr {
    fn eq(&self, other: &Self) -> bool {
        self.dyn_eq(other.as_any())
    }
}

// Marker impl: equality is delegated to the concrete types' `Eq` implementations.
impl Eq for dyn VortexExpr {}

// Generates `Hash` for `dyn VortexExpr` on top of its `DynHash` supertrait (see the
// `dyn-hash` crate). Together with the `Eq` impl this lets expressions be hash keys.
dyn_hash::hash_trait_object!(VortexExpr);
187
/// Test fixtures, compiled only with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Returns a non-nullable struct dtype with five fields, where `?` marks a nullable field:
    /// `a: i32`, `col1: u16?`, `col2: u16?`, `bool1: bool`, `bool2: bool`.
    pub fn struct_dtype() -> DType {
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let non_null_bool = DType::Bool(Nullability::NonNullable);

        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16.clone(),
            nullable_u16,
            non_null_bool.clone(),
            non_null_bool,
        ];

        let fields = StructDType::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            field_dtypes,
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
217
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use super::*;

    /// A single comparison (no `And`) is returned as exactly one conjunct.
    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", ident());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    /// A top-level `And` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", ident());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` output of identity, column access, binary operators, `not`,
    /// selections and literals.
    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(Identity.to_string(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        // Nested binary expressions are fully parenthesized.
        assert_eq!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            )
            .to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1.clone()).to_string(), "!$.col1");

        // Field selection: `${...}` includes the listed fields, `$~{...}` excludes them.
        assert_eq!(
            select(vec![FieldName::from("col1")], ident()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident()
            )
            .to_string(),
            "$~{col1, col2}"
        );

        // Numeric literals render with a type suffix.
        assert_eq!(lit(Scalar::from(0_u8)).to_string(), "0_u8");
        assert_eq!(lit(Scalar::from(0.0_f32)).to_string(), "0_f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807_i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        // Struct literals render as `{name:value,...}` with quoted strings.
        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    Arc::new(StructDType::new(
                        Arc::from([Arc::from("dog"), Arc::from("cat")]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    )),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog:32_u32,cat:\"rufus\"}"
        );
    }

    /// Protobuf serialization tests, built only with the `proto` feature.
    #[cfg(feature = "proto")]
    mod tests_proto {

        use crate::{VortexExprExt, deserialize_expr, eq, ident, lit};

        /// Serializing and then deserializing an expression yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let expr = eq(ident(), lit(1));
            let res = expr.serialize().unwrap();
            let final_ = deserialize_expr(&res).unwrap();

            assert_eq!(&expr, &final_);
        }
    }
}