1use std::any::Any;
2use std::fmt::{Debug, Display};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use dyn_hash::DynHash;
7pub use exprs::*;
8mod analysis;
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod exprs;
12mod field;
13pub mod forms;
14pub mod pruning;
15#[cfg(feature = "proto")]
16mod registry;
17mod scope;
18mod scope_vars;
19pub mod transform;
20pub mod traversal;
21
22pub use analysis::*;
23pub use between::*;
24pub use binary::*;
25pub use cast::*;
26pub use get_item::*;
27pub use is_null::*;
28pub use like::*;
29pub use list_contains::*;
30pub use literal::*;
31pub use merge::*;
32pub use not::*;
33pub use operators::*;
34pub use pack::*;
35#[cfg(feature = "proto")]
36pub use registry::deserialize_expr;
37pub use scope::*;
38pub use select::*;
39pub use var::*;
40use vortex_array::{Array, ArrayRef};
41use vortex_dtype::{DType, FieldName, FieldPath};
42use vortex_error::{VortexResult, VortexUnwrap};
43#[cfg(feature = "proto")]
44use vortex_proto::expr;
45#[cfg(feature = "proto")]
46use vortex_proto::expr::{Expr, kind};
47use vortex_utils::aliases::hash_set::HashSet;
48
49use crate::traversal::{Node, ReferenceCollector, VarsCollector};
50
/// Shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
52
/// Exposes a static string identifier for a type.
///
/// Only compiled when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the static identifier string.
    fn id(&self) -> &'static str;
}
57
/// Builds an [`ExprRef`] from its protobuf `kind` payload and a list of child expressions.
///
/// Only compiled when the `proto` feature is enabled. Implementors must also provide an
/// [`Id`] and be `Sync`.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Constructs an expression from `kind` and `children`.
    ///
    /// NOTE(review): `children` are presumably the already-deserialized child expressions in
    /// the order they were serialized — confirm against the registry implementation.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
62
/// Serialization hooks required of every expression when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Returns the static identifier written into the serialized expression's `id` field.
    fn id(&self) -> &'static str;

    /// Produces the protobuf `kind` payload for this expression (children are not included
    /// in the return type; they are serialized separately).
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}
69
/// Empty stand-in for `ExprSerializable` used when the `proto` feature is disabled.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}
// Blanket impl: without `proto`, every type trivially satisfies the `ExprSerializable`
// supertrait bound on `VortexExpr`.
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
74pub trait VortexExpr:
76 Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable + AnalysisExpr
77{
78 fn as_any(&self) -> &dyn Any;
80
81 fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
83 let result = self.unchecked_evaluate(scope)?;
84 assert_eq!(
85 result.dtype(),
86 &self.return_dtype(&scope.into())?,
87 "Expression {} returned dtype {} but declared return_dtype of {}",
88 self,
89 result.dtype(),
90 self.return_dtype(&scope.into())?,
91 );
92 Ok(result)
93 }
94
95 fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
101
102 fn children(&self) -> Vec<&ExprRef>;
103
104 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
105
106 fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType>;
108}
109
/// Convenience operations layered on top of an expression handle.
pub trait VortexExprExt {
    /// Returns the set of field names collected from this expression.
    fn field_references(&self) -> HashSet<FieldName>;

    /// Returns the set of variable identifiers collected from this expression.
    fn vars(&self) -> HashSet<Identifier>;

    /// Serializes this expression into its protobuf representation.
    #[cfg(feature = "proto")]
    fn serialize(&self) -> VortexResult<Expr>;
}
119
120impl VortexExprExt for ExprRef {
121 fn field_references(&self) -> HashSet<FieldName> {
122 let mut collector = ReferenceCollector::new();
123 self.accept(&mut collector).vortex_unwrap();
125 collector.into_fields()
126 }
127
128 fn vars(&self) -> HashSet<Identifier> {
129 let mut collector = VarsCollector::new();
130 self.accept(&mut collector).vortex_unwrap();
132 collector.into_vars()
133 }
134
135 #[cfg(feature = "proto")]
136 fn serialize(&self) -> VortexResult<Expr> {
137 let children = self
138 .children()
139 .iter()
140 .map(|e| e.serialize())
141 .collect::<VortexResult<_>>()?;
142
143 Ok(Expr {
144 id: self.id().to_string(),
145 children,
146 kind: Some(expr::Kind {
147 kind: Some(self.serialize_kind()?),
148 }),
149 })
150 }
151}
152
153#[derive(Clone, Debug, Hash, PartialEq, Eq)]
154pub struct AccessPath {
155 field_path: FieldPath,
156 identifier: Identifier,
157}
158
159impl AccessPath {
160 pub fn root_field(path: FieldName) -> Self {
161 Self {
162 field_path: FieldPath::from_name(path),
163 identifier: Identifier::Identity,
164 }
165 }
166
167 pub fn new(path: FieldPath, identifier: Identifier) -> Self {
168 Self {
169 field_path: path,
170 identifier,
171 }
172 }
173
174 pub fn identifier(&self) -> &Identifier {
175 &self.identifier
176 }
177
178 pub fn field_path(&self) -> &FieldPath {
179 &self.field_path
180 }
181}
182
183pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
185 let mut conjunctions = vec![];
186 split_inner(expr, &mut conjunctions);
187 conjunctions
188}
189
190fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
191 match expr.as_any().downcast_ref::<BinaryExpr>() {
192 Some(bexp) if bexp.op() == Operator::And => {
193 split_inner(bexp.lhs(), exprs);
194 split_inner(bexp.rhs(), exprs);
195 }
196 Some(_) | None => {
197 exprs.push(expr.clone());
198 }
199 }
200}
201
/// Equality against a type-erased value, usable through trait objects.
pub trait DynEq {
    /// Returns `true` when `other` has the same concrete type as `self` and compares equal.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

/// Every `Eq + Any` type gets `DynEq` by downcasting and delegating to `==`.
impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<Self>() {
            Some(same_type) => same_type == self,
            // A different concrete type is never equal.
            None => false,
        }
    }
}
214
215impl PartialEq for dyn VortexExpr {
216 fn eq(&self, other: &Self) -> bool {
217 self.dyn_eq(other.as_any())
218 }
219}
220
221impl Eq for dyn VortexExpr {}
222
223dyn_hash::hash_trait_object!(VortexExpr);
224
225#[derive(Clone)]
227pub struct ExactExpr(pub ExprRef);
228
229impl PartialEq for ExactExpr {
230 fn eq(&self, other: &Self) -> bool {
231 Arc::ptr_eq(&self.0, &other.0)
232 }
233}
234
235impl Eq for ExactExpr {}
236
237impl Hash for ExactExpr {
238 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
239 Arc::as_ptr(&self.0).hash(state)
240 }
241}
242
/// Fixtures shared by tests; only compiled with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns a non-nullable struct dtype with the fields
    /// `a: i32`, `col1: u16?`, `col2: u16?`, `bool1: bool`, `bool2: bool`.
    pub fn struct_dtype() -> DType {
        let nullable_u16 = || DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = || DType::Bool(Nullability::NonNullable);

        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                nullable_u16(),
                nullable_u16(),
                required_bool(),
                required_bool(),
            ],
        );

        DType::Struct(fields, Nullability::NonNullable)
    }
}
264
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// An expression with no top-level `and` splits into exactly one part.
    #[test]
    fn basic_expr_split_test() {
        let expr = eq(get_item("col1", root()), lit(1));
        let parts = split_conjunction(&expr);
        assert_eq!(parts.len(), 1);
    }

    /// A single `and` splits into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` rendering of the built-in expression constructors.
    #[test]
    fn expr_display() {
        // Compares an expression's `Display` output with the expected text.
        macro_rules! assert_renders {
            ($expr:expr, $expected:expr $(,)?) => {
                assert_eq!($expr.to_string(), $expected)
            };
        }

        assert_renders!(col("a"), "$.a");
        assert_renders!(root(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        // Binary operators.
        assert_renders!(and(col1.clone(), col2.clone()), "($.col1 and $.col2)");
        assert_renders!(or(col1.clone(), col2.clone()), "($.col1 or $.col2)");
        assert_renders!(eq(col1.clone(), col2.clone()), "($.col1 = $.col2)");
        assert_renders!(not_eq(col1.clone(), col2.clone()), "($.col1 != $.col2)");
        assert_renders!(gt(col1.clone(), col2.clone()), "($.col1 > $.col2)");
        assert_renders!(gt_eq(col1.clone(), col2.clone()), "($.col1 >= $.col2)");
        assert_renders!(lt(col1.clone(), col2.clone()), "($.col1 < $.col2)");
        assert_renders!(lt_eq(col1.clone(), col2.clone()), "($.col1 <= $.col2)");

        // Nested binary expressions keep their own parentheses.
        let nested = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_renders!(nested, "(($.col1 < $.col2) or ($.col1 != $.col2))");

        assert_renders!(not(col1.clone()), "!$.col1");

        // Field selection and exclusion.
        assert_renders!(select(vec![FieldName::from("col1")], root()), "${col1}");
        assert_renders!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            ),
            "${col1, col2}",
        );
        assert_renders!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            ),
            "$~{col1, col2}",
        );

        // Scalar literals.
        assert_renders!(lit(Scalar::from(0u8)), "0u8");
        assert_renders!(lit(Scalar::from(0.0f32)), "0f32");
        assert_renders!(lit(Scalar::from(i64::MAX)), "9223372036854775807i64");
        assert_renders!(lit(Scalar::from(true)), "true");
        assert_renders!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))),
            "null",
        );

        // Struct literal.
        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_renders!(lit(pet), "{dog: 32u32, cat: \"rufus\"}");
    }

    #[cfg(feature = "proto")]
    mod tests_proto {
        use crate::{VortexExprExt, deserialize_expr, eq, lit, root};

        /// Serializing and then deserializing an expression yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let original = eq(root(), lit(1));
            let encoded = original.serialize().unwrap();
            let decoded = deserialize_expr(&encoded).unwrap();

            assert_eq!(&original, &decoded);
        }
    }
}