// NOTE: no `#![feature(option_zip)]` gate is declared here on purpose. `Option::zip` has been
// stable since Rust 1.46, and `#![feature]` attributes are rejected on the stable release channel.
2
3use std::any::Any;
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6
7use dyn_hash::DynHash;
8
9mod binary;
10
11mod analysis;
12mod between;
13mod cast;
14mod field;
15pub mod forms;
16mod get_item;
17mod is_null;
18mod let_;
19mod like;
20mod list_contains;
21mod literal;
22mod merge;
23mod not;
24mod operators;
25mod pack;
26pub mod pruning;
27#[cfg(feature = "proto")]
28mod registry;
29mod scope;
30mod select;
31pub mod transform;
32pub mod traversal;
33mod var;
34
35pub use analysis::*;
36pub use between::*;
37pub use binary::*;
38pub use cast::*;
39pub use get_item::*;
40pub use is_null::*;
41pub use let_::*;
42pub use like::*;
43pub use list_contains::*;
44pub use literal::*;
45pub use merge::*;
46pub use not::*;
47pub use operators::*;
48pub use pack::*;
49#[cfg(feature = "proto")]
50pub use registry::deserialize_expr;
51pub use scope::*;
52pub use select::*;
53pub use var::*;
54use vortex_array::{Array, ArrayRef};
55use vortex_dtype::{DType, FieldName, FieldPath};
56use vortex_error::{VortexResult, VortexUnwrap};
57#[cfg(feature = "proto")]
58use vortex_proto::expr;
59#[cfg(feature = "proto")]
60use vortex_proto::expr::{Expr, kind};
61use vortex_utils::aliases::hash_set::HashSet;
62
63use crate::traversal::{Node, ReferenceCollector};
64
65pub type ExprRef = Arc<dyn VortexExpr>;
66
/// Exposes a static string identifier.
///
/// Only compiled when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the identifier as a `'static` string slice.
    fn id(&self) -> &'static str;
}
71
/// Builds an expression from its serialized `kind` and its child expressions.
///
/// Implementors must also provide an [`Id`] and be [`Sync`]. Only compiled when the `proto`
/// feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Constructs an [`ExprRef`] from the given `kind` payload and `children`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when it cannot build the expression.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
76
/// Serialization hooks required of every [`VortexExpr`] when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Static identifier of the expression; [`VortexExprExt::serialize`] stores it in the `id`
    /// field of the produced `Expr`.
    fn id(&self) -> &'static str;

    /// Produces the `kind::Kind` payload that [`VortexExprExt::serialize`] embeds in the
    /// produced `Expr`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when the payload cannot be built.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}
83
/// Method-less stand-in used when the `proto` feature is disabled, so that the
/// `ExprSerializable` supertrait bound on [`VortexExpr`] still names an existing trait.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}

/// Without `proto`, every sized type trivially satisfies [`ExprSerializable`].
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
88pub trait VortexExpr:
90 Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable + AnalysisExpr
91{
92 fn as_any(&self) -> &dyn Any;
94
95 fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
97 let result = self.unchecked_evaluate(scope)?;
98 assert_eq!(
99 result.dtype(),
100 &self.return_dtype(&scope.into())?,
101 "Expression {} returned dtype {} but declared return_dtype of {}",
102 self,
103 result.dtype(),
104 self.return_dtype(&scope.into())?,
105 );
106 Ok(result)
107 }
108
109 fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
115
116 fn children(&self) -> Vec<&ExprRef>;
117
118 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
119
120 fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType>;
122}
123
124pub trait VortexExprExt {
125 fn references(&self) -> HashSet<FieldName>;
127
128 #[cfg(feature = "proto")]
129 fn serialize(&self) -> VortexResult<Expr>;
130}
131
132impl VortexExprExt for ExprRef {
133 fn references(&self) -> HashSet<FieldName> {
134 let mut collector = ReferenceCollector::new();
135 self.accept(&mut collector).vortex_unwrap();
137 collector.into_fields()
138 }
139
140 #[cfg(feature = "proto")]
141 fn serialize(&self) -> VortexResult<Expr> {
142 let children = self
143 .children()
144 .iter()
145 .map(|e| e.serialize())
146 .collect::<VortexResult<_>>()?;
147
148 Ok(Expr {
149 id: self.id().to_string(),
150 children,
151 kind: Some(expr::Kind {
152 kind: Some(self.serialize_kind()?),
153 }),
154 })
155 }
156}
157
158#[derive(Clone, Debug, Hash, PartialEq, Eq)]
159pub struct AccessPath {
160 field_path: FieldPath,
161 identifier: Identifier,
162}
163
164impl AccessPath {
165 pub fn root_field(path: FieldName) -> Self {
166 Self {
167 field_path: FieldPath::from_name(path),
168 identifier: Identifier::Identity,
169 }
170 }
171
172 pub fn new(path: FieldPath, identifier: Identifier) -> Self {
173 Self {
174 field_path: path,
175 identifier,
176 }
177 }
178
179 pub fn identifier(&self) -> &Identifier {
180 &self.identifier
181 }
182
183 pub fn field_path(&self) -> &FieldPath {
184 &self.field_path
185 }
186}
187
188pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
190 let mut conjunctions = vec![];
191 split_inner(expr, &mut conjunctions);
192 conjunctions
193}
194
195fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
196 match expr.as_any().downcast_ref::<BinaryExpr>() {
197 Some(bexp) if bexp.op() == Operator::And => {
198 split_inner(bexp.lhs(), exprs);
199 split_inner(bexp.rhs(), exprs);
200 }
201 Some(_) | None => {
202 exprs.push(expr.clone());
203 }
204 }
205}
206
/// Equality against a type-erased value, usable through trait objects.
pub trait DynEq {
    /// Reports whether `other` is equal to `self`.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

/// Every `Eq + Any` type gets [`DynEq`]: `other` is equal only when it downcasts to the same
/// concrete type and that value compares equal to `self`.
impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        other
            .downcast_ref::<Self>()
            .is_some_and(|same_type| same_type == self)
    }
}
219
220impl PartialEq for dyn VortexExpr {
221 fn eq(&self, other: &Self) -> bool {
222 self.dyn_eq(other.as_any())
223 }
224}
225
226impl Eq for dyn VortexExpr {}
227
228dyn_hash::hash_trait_object!(VortexExpr);
229
/// Fixtures compiled only when the `test-harness` feature is enabled.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns a non-nullable struct dtype with five fields:
    ///
    /// * `a`: non-nullable `i32`
    /// * `col1`, `col2`: nullable `u16`
    /// * `bool1`, `bool2`: non-nullable `bool`
    pub fn struct_dtype() -> DType {
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];
        let fields = StructFields::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            field_dtypes,
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
259
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// Asserts that the `Display` rendering of an expression equals the expected text.
    macro_rules! assert_displays {
        ($expr:expr, $expected:expr $(,)?) => {
            assert_eq!($expr.to_string(), $expected)
        };
    }

    /// A single comparison is not a conjunction, so it must come back unsplit.
    #[test]
    fn basic_expr_split_test() {
        let expr = eq(get_item("col1", root()), lit(1));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    /// A top-level `and` splits into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the textual rendering of the built-in expression constructors.
    #[test]
    fn expr_display() {
        // Column access and the root expression.
        assert_displays!(col("a"), "$.a");
        assert_displays!(root(), "$");

        // Binary operators.
        let col1: ExprRef = col("col1");
        let col2: ExprRef = col("col2");
        assert_displays!(and(col1.clone(), col2.clone()), "($.col1 and $.col2)");
        assert_displays!(or(col1.clone(), col2.clone()), "($.col1 or $.col2)");
        assert_displays!(eq(col1.clone(), col2.clone()), "($.col1 = $.col2)");
        assert_displays!(not_eq(col1.clone(), col2.clone()), "($.col1 != $.col2)");
        assert_displays!(gt(col1.clone(), col2.clone()), "($.col1 > $.col2)");
        assert_displays!(gt_eq(col1.clone(), col2.clone()), "($.col1 >= $.col2)");
        assert_displays!(lt(col1.clone(), col2.clone()), "($.col1 < $.col2)");
        assert_displays!(lt_eq(col1.clone(), col2.clone()), "($.col1 <= $.col2)");

        // Nested binary expressions and negation.
        assert_displays!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            ),
            "(($.col1 < $.col2) or ($.col1 != $.col2))",
        );
        assert_displays!(not(col1.clone()), "!$.col1");

        // Field selection and exclusion.
        assert_displays!(select(vec![FieldName::from("col1")], root()), "${col1}");
        assert_displays!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            ),
            "${col1, col2}",
        );
        assert_displays!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            ),
            "$~{col1, col2}",
        );

        // Scalar literals.
        assert_displays!(lit(Scalar::from(0u8)), "0u8");
        assert_displays!(lit(Scalar::from(0.0f32)), "0f32");
        assert_displays!(lit(Scalar::from(i64::MAX)), "9223372036854775807i64");
        assert_displays!(lit(Scalar::from(true)), "true");
        assert_displays!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))),
            "null",
        );

        // Struct literal.
        let pet_dtype = DType::Struct(
            Arc::new(StructFields::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_displays!(lit(pet), "{dog: 32u32, cat: \"rufus\"}");
    }

    #[cfg(feature = "proto")]
    mod tests_proto {

        use crate::{VortexExprExt, deserialize_expr, eq, lit, root};

        /// Serializing and then deserializing an expression yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let original = eq(root(), lit(1));
            let encoded = original.serialize().unwrap();
            let decoded = deserialize_expr(&encoded).unwrap();

            assert_eq!(&original, &decoded);
        }
    }
}