1extern crate core;
2
3use std::any::Any;
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6
7use dyn_hash::DynHash;
8
9mod binary;
10
11mod between;
12pub mod datafusion;
13mod field;
14pub mod forms;
15mod get_item;
16mod identity;
17mod like;
18mod literal;
19mod merge;
20mod not;
21mod operators;
22mod pack;
23pub mod pruning;
24mod select;
25pub mod transform;
26pub mod traversal;
27
28pub use binary::*;
29pub use get_item::*;
30pub use identity::*;
31pub use like::*;
32pub use literal::*;
33pub use merge::*;
34pub use not::*;
35pub use operators::*;
36pub use pack::*;
37pub use select::*;
38use vortex_array::aliases::hash_set::HashSet;
39use vortex_array::{Array, ArrayRef};
40use vortex_dtype::{DType, FieldName};
41use vortex_error::{VortexResult, VortexUnwrap};
42
43use crate::traversal::{Node, ReferenceCollector};
44
45pub type ExprRef = Arc<dyn VortexExpr>;
46
47pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display {
49 fn as_any(&self) -> &dyn Any;
51
52 fn evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
55 let result = self.unchecked_evaluate(batch)?;
56 assert_eq!(
57 result.dtype(),
58 &self.return_dtype(batch.dtype())?,
59 "Expression {} returned dtype {} but declared return_dtype of {}",
60 self,
61 result.dtype(),
62 self.return_dtype(batch.dtype())?,
63 );
64 Ok(result)
65 }
66
67 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef>;
73
74 fn children(&self) -> Vec<&ExprRef>;
75
76 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
77
78 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType>;
80}
81
82pub trait VortexExprExt {
83 fn references(&self) -> HashSet<FieldName>;
85}
86
87impl VortexExprExt for ExprRef {
88 fn references(&self) -> HashSet<FieldName> {
89 let mut collector = ReferenceCollector::new();
90 self.accept(&mut collector).vortex_unwrap();
92 collector.into_fields()
93 }
94}
95
96pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
98 let mut conjunctions = vec![];
99 split_inner(expr, &mut conjunctions);
100 conjunctions
101}
102
103fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
104 match expr.as_any().downcast_ref::<BinaryExpr>() {
105 Some(bexp) if bexp.op() == Operator::And => {
106 split_inner(bexp.lhs(), exprs);
107 split_inner(bexp.rhs(), exprs);
108 }
109 Some(_) | None => {
110 exprs.push(expr.clone());
111 }
112 }
113}
114
/// Object-safe equality between a value and a type-erased [`Any`].
///
/// The blanket impl below provides it for every `Eq + Any` type. That lets equality be
/// expressed on trait objects, where `PartialEq` itself cannot be used directly.
pub trait DynEq {
    /// Returns `true` when `other` is the same concrete type as `self` and compares equal to it.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<T>() {
            Some(same_type) => same_type == self,
            None => false,
        }
    }
}
127
128impl PartialEq for dyn VortexExpr {
129 fn eq(&self, other: &Self) -> bool {
130 self.dyn_eq(other.as_any())
131 }
132}
133
134impl Eq for dyn VortexExpr {}
135
136dyn_hash::hash_trait_object!(VortexExpr);
137
/// Shared fixtures for expression tests, compiled only with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Returns a non-nullable struct dtype with five fields, where `?` marks a nullable field:
    /// `a: i32`, `col1: u16?`, `col2: u16?`, `bool1: bool`, `bool2: bool`.
    pub fn struct_dtype() -> DType {
        let nullable_u16 = || DType::Primitive(PType::U16, Nullability::Nullable);
        let non_null_bool = || DType::Bool(Nullability::NonNullable);

        let fields = StructDType::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                nullable_u16(),
                nullable_u16(),
                non_null_bool(),
                non_null_bool(),
            ],
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
167
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use super::*;

    /// Asserts that `expr` renders to exactly `expected` through `Display`.
    fn assert_displays(expr: impl std::fmt::Display, expected: &str) {
        assert_eq!(expr.to_string(), expected);
    }

    #[test]
    fn basic_expr_split_test() {
        // A lone comparison is not an AND, so it splits into just itself.
        let comparison = eq(get_item("col1", ident()), lit(1));
        assert_eq!(split_conjunction(&comparison).len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        // A single AND splits into its two operands.
        let conjunction = split_conjunction(&and(get_item("col1", ident()), lit(1)));
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    #[test]
    fn expr_display() {
        assert_displays(col("a"), "$.a");
        assert_displays(Identity, "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        // Binary operators render parenthesised, with the operator between the operands.
        assert_displays(and(col1.clone(), col2.clone()), "($.col1 and $.col2)");
        assert_displays(or(col1.clone(), col2.clone()), "($.col1 or $.col2)");
        assert_displays(eq(col1.clone(), col2.clone()), "($.col1 = $.col2)");
        assert_displays(not_eq(col1.clone(), col2.clone()), "($.col1 != $.col2)");
        assert_displays(gt(col1.clone(), col2.clone()), "($.col1 > $.col2)");
        assert_displays(gt_eq(col1.clone(), col2.clone()), "($.col1 >= $.col2)");
        assert_displays(lt(col1.clone(), col2.clone()), "($.col1 < $.col2)");
        assert_displays(lt_eq(col1.clone(), col2.clone()), "($.col1 <= $.col2)");

        // Nested binary expressions keep their own parentheses.
        assert_displays(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            ),
            "(($.col1 < $.col2) or ($.col1 != $.col2))",
        );

        assert_displays(not(col1.clone()), "!$.col1");

        // Field selection and exclusion.
        assert_displays(select(vec![FieldName::from("col1")], ident()), "${col1}");
        assert_displays(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident(),
            ),
            "${col1, col2}",
        );
        assert_displays(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident(),
            ),
            "$~{col1, col2}",
        );

        // Scalar literals: numeric values carry a type suffix.
        assert_displays(lit(Scalar::from(0_u8)), "0_u8");
        assert_displays(lit(Scalar::from(0.0_f32)), "0_f32");
        assert_displays(lit(Scalar::from(i64::MAX)), "9223372036854775807_i64");
        assert_displays(lit(Scalar::from(true)), "true");
        assert_displays(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))),
            "null",
        );

        // Struct literals render as `{name:value,...}`.
        let dog_cat_dtype = DType::Struct(
            Arc::new(StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let rufus = Scalar::struct_(
            dog_cat_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_displays(lit(rufus), "{dog:32_u32,cat:\"rufus\"}");
    }
}