1use std::any::Any;
2use std::fmt::{Debug, Display};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use dyn_hash::DynHash;
7pub use exprs::*;
8mod analysis;
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod exprs;
12mod field;
13pub mod forms;
14pub mod pruning;
15#[cfg(feature = "proto")]
16mod registry;
17mod scope;
18pub mod transform;
19pub mod traversal;
20
21pub use analysis::*;
22pub use between::*;
23pub use binary::*;
24pub use cast::*;
25pub use get_item::*;
26pub use is_null::*;
27pub use like::*;
28pub use list_contains::*;
29pub use literal::*;
30pub use merge::*;
31pub use not::*;
32pub use operators::*;
33pub use pack::*;
34#[cfg(feature = "proto")]
35pub use registry::deserialize_expr;
36pub use scope::*;
37pub use select::*;
38pub use var::*;
39use vortex_array::{Array, ArrayRef};
40use vortex_dtype::{DType, FieldName, FieldPath};
41use vortex_error::{VortexResult, VortexUnwrap};
42#[cfg(feature = "proto")]
43use vortex_proto::expr;
44#[cfg(feature = "proto")]
45use vortex_proto::expr::{Expr, kind};
46use vortex_utils::aliases::hash_set::HashSet;
47
48use crate::traversal::{Node, ReferenceCollector, VarsCollector};
49
50pub type ExprRef = Arc<dyn VortexExpr>;
51
/// Exposes a static string identifier.
///
/// Only compiled when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the identifier as a `'static` string slice.
    fn id(&self) -> &'static str;
}
56
/// Builds an [`ExprRef`] from a protobuf `kind` payload and already-built children.
///
/// Implementors must also provide an [`Id`] and be [`Sync`]. Only compiled when the
/// `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Constructs an expression from `kind` and `children`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementor reports when the expression cannot be built.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
61
/// Serialization hooks required of every [`VortexExpr`] when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Returns the static identifier written into the serialized expression.
    fn id(&self) -> &'static str;

    /// Produces the protobuf `kind` payload for this expression node.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementor reports when the node cannot be serialized.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}

/// Marker stand-in used when the `proto` feature is disabled.
///
/// It has no methods, so the [`VortexExpr`] supertrait bound stays satisfiable without
/// any serialization support.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}

/// Without `proto`, every type trivially implements the marker trait.
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
73pub trait VortexExpr:
75 Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable + AnalysisExpr
76{
77 fn as_any(&self) -> &dyn Any;
79
80 fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
82 let result = self.unchecked_evaluate(scope)?;
83 assert_eq!(
84 result.dtype(),
85 &self.return_dtype(&scope.into())?,
86 "Expression {} returned dtype {} but declared return_dtype of {}",
87 self,
88 result.dtype(),
89 self.return_dtype(&scope.into())?,
90 );
91 Ok(result)
92 }
93
94 fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
100
101 fn children(&self) -> Vec<&ExprRef>;
102
103 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
104
105 fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType>;
107}
108
109pub trait VortexExprExt {
110 fn field_references(&self) -> HashSet<FieldName>;
112
113 fn vars(&self) -> HashSet<Identifier>;
114
115 #[cfg(feature = "proto")]
116 fn serialize(&self) -> VortexResult<Expr>;
117}
118
119impl VortexExprExt for ExprRef {
120 fn field_references(&self) -> HashSet<FieldName> {
121 let mut collector = ReferenceCollector::new();
122 self.accept(&mut collector).vortex_unwrap();
124 collector.into_fields()
125 }
126
127 fn vars(&self) -> HashSet<Identifier> {
128 let mut collector = VarsCollector::new();
129 self.accept(&mut collector).vortex_unwrap();
131 collector.into_vars()
132 }
133
134 #[cfg(feature = "proto")]
135 fn serialize(&self) -> VortexResult<Expr> {
136 let children = self
137 .children()
138 .iter()
139 .map(|e| e.serialize())
140 .collect::<VortexResult<_>>()?;
141
142 Ok(Expr {
143 id: self.id().to_string(),
144 children,
145 kind: Some(expr::Kind {
146 kind: Some(self.serialize_kind()?),
147 }),
148 })
149 }
150}
151
152#[derive(Clone, Debug, Hash, PartialEq, Eq)]
153pub struct AccessPath {
154 field_path: FieldPath,
155 identifier: Identifier,
156}
157
158impl AccessPath {
159 pub fn root_field(path: FieldName) -> Self {
160 Self {
161 field_path: FieldPath::from_name(path),
162 identifier: Identifier::Identity,
163 }
164 }
165
166 pub fn new(path: FieldPath, identifier: Identifier) -> Self {
167 Self {
168 field_path: path,
169 identifier,
170 }
171 }
172
173 pub fn identifier(&self) -> &Identifier {
174 &self.identifier
175 }
176
177 pub fn field_path(&self) -> &FieldPath {
178 &self.field_path
179 }
180}
181
182pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
184 let mut conjunctions = vec![];
185 split_inner(expr, &mut conjunctions);
186 conjunctions
187}
188
189fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
190 match expr.as_any().downcast_ref::<BinaryExpr>() {
191 Some(bexp) if bexp.op() == Operator::And => {
192 split_inner(bexp.lhs(), exprs);
193 split_inner(bexp.rhs(), exprs);
194 }
195 Some(_) | None => {
196 exprs.push(expr.clone());
197 }
198 }
199}
200
/// Equality against a type-erased value.
///
/// This lets values behind a trait object be compared without knowing their concrete type.
pub trait DynEq {
    /// Returns `true` when `other` has the same concrete type as `self` and compares equal to it.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

/// Every `Eq + Any` type gets [`DynEq`] by downcasting `other` to `Self` first.
impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means the concrete types differ, so the values are unequal.
        other
            .downcast_ref::<Self>()
            .is_some_and(|that| that == self)
    }
}
213
214impl PartialEq for dyn VortexExpr {
215 fn eq(&self, other: &Self) -> bool {
216 self.dyn_eq(other.as_any())
217 }
218}
219
220impl Eq for dyn VortexExpr {}
221
222dyn_hash::hash_trait_object!(VortexExpr);
223
224#[derive(Clone)]
226pub struct ExactExpr(pub ExprRef);
227
228impl PartialEq for ExactExpr {
229 fn eq(&self, other: &Self) -> bool {
230 Arc::ptr_eq(&self.0, &other.0)
231 }
232}
233
234impl Eq for ExactExpr {}
235
236impl Hash for ExactExpr {
237 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
238 Arc::as_ptr(&self.0).hash(state)
239 }
240}
241
/// Shared fixtures for tests; compiled only with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns a non-nullable struct dtype with five fields:
    /// `a` (non-nullable `i32`), `col1` and `col2` (nullable `u16`), and
    /// `bool1` and `bool2` (non-nullable booleans).
    pub fn struct_dtype() -> DType {
        // Listed in the same order as the field names below.
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];

        let fields = StructFields::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            field_dtypes,
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
271
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// A lone comparison has no top-level `and`, so it splits into one term.
    #[test]
    fn basic_expr_split_test() {
        let comparison = eq(get_item("col1", root()), lit(1));
        let terms = split_conjunction(&comparison);
        assert_eq!(terms.len(), 1);
    }

    /// A single `and` splits into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let both = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&both);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` output of the expression constructors.
    #[test]
    fn expr_display() {
        // Column references and the root.
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: ExprRef = col("col1");
        let col2: ExprRef = col("col2");

        // Binary operators, including one nested case.
        let binary_cases = [
            (and(col1.clone(), col2.clone()), "($.col1 and $.col2)"),
            (or(col1.clone(), col2.clone()), "($.col1 or $.col2)"),
            (eq(col1.clone(), col2.clone()), "($.col1 = $.col2)"),
            (not_eq(col1.clone(), col2.clone()), "($.col1 != $.col2)"),
            (gt(col1.clone(), col2.clone()), "($.col1 > $.col2)"),
            (gt_eq(col1.clone(), col2.clone()), "($.col1 >= $.col2)"),
            (lt(col1.clone(), col2.clone()), "($.col1 < $.col2)"),
            (lt_eq(col1.clone(), col2.clone()), "($.col1 <= $.col2)"),
            (
                or(
                    lt(col1.clone(), col2.clone()),
                    not_eq(col1.clone(), col2.clone()),
                ),
                "(($.col1 < $.col2) or ($.col1 != $.col2))",
            ),
        ];
        for (expr, expected) in binary_cases {
            assert_eq!(expr.to_string(), expected);
        }

        // Negation.
        assert_eq!(not(col1.clone()).to_string(), "!$.col1");

        // Field selection and exclusion.
        let one_field = vec![FieldName::from("col1")];
        assert_eq!(select(one_field, root()).to_string(), "${col1}");

        let two_fields = || vec![FieldName::from("col1"), FieldName::from("col2")];
        assert_eq!(select(two_fields(), root()).to_string(), "${col1, col2}");
        assert_eq!(
            select_exclude(two_fields(), root()).to_string(),
            "$~{col1, col2}"
        );

        // Scalar literals.
        let literal_cases = [
            (lit(Scalar::from(0u8)), "0u8"),
            (lit(Scalar::from(0.0f32)), "0f32"),
            (lit(Scalar::from(i64::MAX)), "9223372036854775807i64"),
            (lit(Scalar::from(true)), "true"),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                "null",
            ),
        ];
        for (expr, expected) in literal_cases {
            assert_eq!(expr.to_string(), expected);
        }

        // Struct literal.
        let pet_dtype = DType::Struct(
            Arc::new(StructFields::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }

    #[cfg(feature = "proto")]
    mod tests_proto {
        use crate::{VortexExprExt, deserialize_expr, eq, lit, root};

        /// Serializing and then deserializing yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let original = eq(root(), lit(1));
            let encoded = original.serialize().unwrap();
            let decoded = deserialize_expr(&encoded).unwrap();

            assert_eq!(&original, &decoded);
        }
    }
}