1extern crate core;
2
3use std::any::Any;
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6
7use dyn_hash::DynHash;
8
9mod binary;
10
11mod between;
12pub mod datafusion;
13mod field;
14pub mod forms;
15mod get_item;
16mod identity;
17mod like;
18mod literal;
19mod merge;
20mod not;
21mod operators;
22mod pack;
23pub mod pruning;
24#[cfg(feature = "proto")]
25mod registry;
26mod select;
27pub mod transform;
28pub mod traversal;
29
30pub use between::*;
31pub use binary::*;
32pub use get_item::*;
33pub use identity::*;
34pub use like::*;
35pub use literal::*;
36pub use merge::*;
37pub use not::*;
38pub use operators::*;
39pub use pack::*;
40#[cfg(feature = "proto")]
41pub use registry::deserialize_expr;
42pub use select::*;
43use vortex_array::aliases::hash_set::HashSet;
44use vortex_array::{Array, ArrayRef};
45use vortex_dtype::{DType, FieldName};
46use vortex_error::{VortexResult, VortexUnwrap};
47#[cfg(feature = "proto")]
48use vortex_proto::expr;
49#[cfg(feature = "proto")]
50use vortex_proto::expr::{Expr, kind};
51
52use crate::traversal::{Node, ReferenceCollector};
53
54pub type ExprRef = Arc<dyn VortexExpr>;
55
/// Exposes a static string identifier (only compiled with the `proto` feature).
///
/// NOTE(review): presumably this identifier is the key used to find the matching
/// deserializer for a serialized expression — confirm against the `registry` module.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the identifier as a `'static` string slice.
    fn id(&self) -> &'static str;
}
60
/// Rebuilds an expression from its serialized form (only compiled with the `proto` feature).
///
/// Implementors must also implement [`Id`] and be [`Sync`].
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Builds an [`ExprRef`] from the serialized `kind` payload and the supplied `children`.
    ///
    /// # Errors
    ///
    /// Returns an error when no expression can be produced from the inputs; the exact
    /// failure conditions are defined by each implementation.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
65
/// Serialization hooks every expression provides when the `proto` feature is enabled.
///
/// This trait is a supertrait of [`VortexExpr`]; `VortexExprExt::serialize` combines the
/// values returned by both methods with the serialized children of the expression.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Returns the static identifier written into the `id` field of the serialized expression.
    fn id(&self) -> &'static str;

    /// Produces the expression-specific `kind` payload of the serialized expression.
    ///
    /// # Errors
    ///
    /// Returns an error when the payload cannot be produced; the conditions are
    /// implementation-defined.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}
72
/// Without the `proto` feature, `ExprSerializable` is an empty marker trait.
///
/// It exists so that the `ExprSerializable` supertrait bound on [`VortexExpr`] can be written
/// unconditionally, regardless of which features are enabled.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}
// Blanket impl: every (sized) type satisfies the marker, so implementors of `VortexExpr`
// have nothing extra to implement when serialization support is compiled out.
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
77
78pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable {
80    fn as_any(&self) -> &dyn Any;
82
83    fn evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
86        let result = self.unchecked_evaluate(batch)?;
87        assert_eq!(
88            result.dtype(),
89            &self.return_dtype(batch.dtype())?,
90            "Expression {} returned dtype {} but declared return_dtype of {}",
91            self,
92            result.dtype(),
93            self.return_dtype(batch.dtype())?,
94        );
95        Ok(result)
96    }
97
98    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef>;
104
105    fn children(&self) -> Vec<&ExprRef>;
106
107    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
108
109    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType>;
111}
112
/// Convenience methods layered on top of expression references.
pub trait VortexExprExt {
    /// Returns the set of field names referenced by the expression.
    fn references(&self) -> HashSet<FieldName>;

    /// Converts the expression into its `vortex_proto` [`Expr`] message.
    ///
    /// Only available when the `proto` feature is enabled.
    #[cfg(feature = "proto")]
    fn serialize(&self) -> VortexResult<Expr>;
}
120
121impl VortexExprExt for ExprRef {
122    fn references(&self) -> HashSet<FieldName> {
123        let mut collector = ReferenceCollector::new();
124        self.accept(&mut collector).vortex_unwrap();
126        collector.into_fields()
127    }
128
129    #[cfg(feature = "proto")]
130    fn serialize(&self) -> VortexResult<Expr> {
131        let children = self
132            .children()
133            .iter()
134            .map(|e| e.serialize())
135            .collect::<VortexResult<_>>()?;
136
137        Ok(Expr {
138            id: self.id().to_string(),
139            children,
140            kind: Some(expr::Kind {
141                kind: Some(self.serialize_kind()?),
142            }),
143        })
144    }
145}
146
147pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
149    let mut conjunctions = vec![];
150    split_inner(expr, &mut conjunctions);
151    conjunctions
152}
153
154fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
155    match expr.as_any().downcast_ref::<BinaryExpr>() {
156        Some(bexp) if bexp.op() == Operator::And => {
157            split_inner(bexp.lhs(), exprs);
158            split_inner(bexp.rhs(), exprs);
159        }
160        Some(_) | None => {
161            exprs.push(expr.clone());
162        }
163    }
164}
165
/// Object-safe equality: lets a value be compared against a type-erased [`Any`].
///
/// `Eq`/`PartialEq` cannot be used through a trait object directly, so this trait provides
/// the comparison in a form that can.
pub trait DynEq {
    /// Returns `true` when `other` has the same concrete type as `self` and compares equal.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

/// Every `'static` type with total equality gets [`DynEq`] for free.
impl<T> DynEq for T
where
    T: Eq + Any,
{
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<T>() {
            // Same concrete type: defer to the type's own equality.
            Some(same_type) => same_type == self,
            // A different concrete type is never equal.
            None => false,
        }
    }
}
178
179impl PartialEq for dyn VortexExpr {
180    fn eq(&self, other: &Self) -> bool {
181        self.dyn_eq(other.as_any())
182    }
183}
184
185impl Eq for dyn VortexExpr {}
186
187dyn_hash::hash_trait_object!(VortexExpr);
188
/// Shared fixtures, compiled only with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Builds a non-nullable struct dtype with five fields:
    /// `a` (non-nullable `i32`), `col1` and `col2` (nullable `u16`),
    /// `bool1` and `bool2` (non-nullable booleans).
    pub fn struct_dtype() -> DType {
        // Names and dtypes are positional: entry `i` of one describes entry `i` of the other.
        let field_names = [
            "a".into(),
            "col1".into(),
            "col2".into(),
            "bool1".into(),
            "bool2".into(),
        ];
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];

        let fields = StructDType::new(field_names.into(), field_dtypes);
        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
218
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use super::*;

    /// An expression without a top-level `and` is returned as a single element.
    #[test]
    fn basic_expr_split_test() {
        let comparison = eq(get_item("col1", ident()), lit(1));
        assert_eq!(split_conjunction(&comparison).len(), 1);
    }

    /// A top-level `and` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", ident()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` rendering of the expression kinds exercised below.
    #[test]
    fn expr_display() {
        // Field access and identity.
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(Identity.to_string(), "$");

        let col1: ExprRef = col("col1");
        let col2: ExprRef = col("col2");

        // Binary operators, including one nested combination.
        let binary_cases = [
            (
                and(col1.clone(), col2.clone()).to_string(),
                "($.col1 and $.col2)",
            ),
            (
                or(col1.clone(), col2.clone()).to_string(),
                "($.col1 or $.col2)",
            ),
            (
                eq(col1.clone(), col2.clone()).to_string(),
                "($.col1 = $.col2)",
            ),
            (
                not_eq(col1.clone(), col2.clone()).to_string(),
                "($.col1 != $.col2)",
            ),
            (
                gt(col1.clone(), col2.clone()).to_string(),
                "($.col1 > $.col2)",
            ),
            (
                gt_eq(col1.clone(), col2.clone()).to_string(),
                "($.col1 >= $.col2)",
            ),
            (
                lt(col1.clone(), col2.clone()).to_string(),
                "($.col1 < $.col2)",
            ),
            (
                lt_eq(col1.clone(), col2.clone()).to_string(),
                "($.col1 <= $.col2)",
            ),
            (
                or(
                    lt(col1.clone(), col2.clone()),
                    not_eq(col1.clone(), col2.clone()),
                )
                .to_string(),
                "(($.col1 < $.col2) or ($.col1 != $.col2))",
            ),
        ];
        for (rendered, expected) in binary_cases {
            assert_eq!(rendered, expected);
        }

        // Negation.
        assert_eq!(not(col1.clone()).to_string(), "!$.col1");

        // Field selection and exclusion.
        let both_fields = || vec![FieldName::from("col1"), FieldName::from("col2")];
        assert_eq!(
            select(vec![FieldName::from("col1")], ident()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(both_fields(), ident()).to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(both_fields(), ident()).to_string(),
            "$~{col1, col2}"
        );

        // Scalar literals.
        let literal_cases = [
            (lit(Scalar::from(0u8)).to_string(), "0u8"),
            (lit(Scalar::from(0.0f32)).to_string(), "0f32"),
            (
                lit(Scalar::from(i64::MAX)).to_string(),
                "9223372036854775807i64",
            ),
            (lit(Scalar::from(true)).to_string(), "true"),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
                "null",
            ),
        ];
        for (rendered, expected) in literal_cases {
            assert_eq!(rendered, expected);
        }

        // Struct literal with two named fields.
        let pet_dtype = DType::Struct(
            Arc::new(StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }

    #[cfg(feature = "proto")]
    mod tests_proto {

        use crate::{VortexExprExt, deserialize_expr, eq, ident, lit};

        /// Serializing and then deserializing yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let original = eq(ident(), lit(1));
            let encoded = original.serialize().unwrap();
            let decoded = deserialize_expr(&encoded).unwrap();

            assert_eq!(&original, &decoded);
        }
    }
}