vortex_expr/
lib.rs

1extern crate core;
2
3use std::any::Any;
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6
7use dyn_hash::DynHash;
8
9mod binary;
10
11mod between;
12pub mod datafusion;
13mod field;
14pub mod forms;
15mod get_item;
16mod identity;
17mod like;
18mod literal;
19mod merge;
20mod not;
21mod operators;
22mod pack;
23pub mod pruning;
24mod select;
25pub mod transform;
26pub mod traversal;
27
28pub use binary::*;
29pub use get_item::*;
30pub use identity::*;
31pub use like::*;
32pub use literal::*;
33pub use merge::*;
34pub use not::*;
35pub use operators::*;
36pub use pack::*;
37pub use select::*;
38use vortex_array::aliases::hash_set::HashSet;
39use vortex_array::{Array, ArrayRef};
40use vortex_dtype::{DType, FieldName};
41use vortex_error::{VortexResult, VortexUnwrap};
42
43use crate::traversal::{Node, ReferenceCollector};
44
/// Shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
46
47/// Represents logical operation on [`ArrayRef`]s
48pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display {
49    /// Convert expression reference to reference of [`Any`] type
50    fn as_any(&self) -> &dyn Any;
51
52    /// Compute result of expression on given batch producing a new batch
53    ///
54    fn evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
55        let result = self.unchecked_evaluate(batch)?;
56        assert_eq!(
57            result.dtype(),
58            &self.return_dtype(batch.dtype())?,
59            "Expression {} returned dtype {} but declared return_dtype of {}",
60            self,
61            result.dtype(),
62            self.return_dtype(batch.dtype())?,
63        );
64        Ok(result)
65    }
66
67    /// Compute result of expression on given batch producing a new batch
68    ///
69    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
70    /// the [VortexExpr::return_dtype] method. Use instead the [VortexExpr::evaluate] function which
71    /// includes such an assertion.
72    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef>;
73
74    fn children(&self) -> Vec<&ExprRef>;
75
76    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
77
78    /// Compute the type of the array returned by [VortexExpr::evaluate].
79    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType>;
80}
81
/// Extension methods on expression handles such as [`ExprRef`].
pub trait VortexExprExt {
    /// Accumulate all field references from this expression and its children in a set
    fn references(&self) -> HashSet<FieldName>;
}
86
87impl VortexExprExt for ExprRef {
88    fn references(&self) -> HashSet<FieldName> {
89        let mut collector = ReferenceCollector::new();
90        // The collector is infallible, so we can unwrap the result
91        self.accept(&mut collector).vortex_unwrap();
92        collector.into_fields()
93    }
94}
95
96/// Splits top level and operations into separate expressions
97pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
98    let mut conjunctions = vec![];
99    split_inner(expr, &mut conjunctions);
100    conjunctions
101}
102
103fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
104    match expr.as_any().downcast_ref::<BinaryExpr>() {
105        Some(bexp) if bexp.op() == Operator::And => {
106            split_inner(bexp.lhs(), exprs);
107            split_inner(bexp.rhs(), exprs);
108        }
109        Some(_) | None => {
110            exprs.push(expr.clone());
111        }
112    }
113}
114
// Adapted from apache/datafusion https://github.com/apache/datafusion/blob/f31ca5b927c040ce03f6a3c8c8dc3d7f4ef5be34/datafusion/physical-expr-common/src/physical_expr.rs#L156
/// Object-safe equality used by [`VortexExpr`].
///
/// [`VortexExpr`] can't be constrained by [`Eq`] directly because it must remain object
/// safe. To ease implementation, a blanket implementation is provided for all [`Eq`] types.
pub trait DynEq {
    /// Returns `true` when `other` has the same concrete type as `self` and compares equal to it.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means the concrete types differ, which never counts as equal.
        match other.downcast_ref::<T>() {
            Some(same_type) => same_type == self,
            None => false,
        }
    }
}
127
128impl PartialEq for dyn VortexExpr {
129    fn eq(&self, other: &Self) -> bool {
130        self.dyn_eq(other.as_any())
131    }
132}
133
134impl Eq for dyn VortexExpr {}
135
// Uses the `dyn_hash` crate's macro to implement `Hash` for `dyn VortexExpr` through the
// `DynHash` supertrait, so type-erased expressions (and `ExprRef`s) can be hashed.
dyn_hash::hash_trait_object!(VortexExpr);
137
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Builds the non-nullable struct [`DType`] shared by expression tests.
    ///
    /// Fields, in order: `a: i32`, `col1: u16` (nullable), `col2: u16` (nullable),
    /// `bool1: bool`, `bool2: bool`.
    pub fn struct_dtype() -> DType {
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16.clone(),
            nullable_u16,
            non_nullable_bool.clone(),
            non_nullable_bool,
        ];
        let fields = StructDType::new(
            ["a".into(), "col1".into(), "col2".into(), "bool1".into(), "bool2".into()].into(),
            field_dtypes,
        );
        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
167
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use super::*;

    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", ident());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", ident());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Nested `and`s are flattened completely (on either side) and conjuncts keep their
    /// left-to-right order.
    #[test]
    fn nested_conjunction_split_preserves_order() {
        let a: ExprRef = col("a");
        let col1: ExprRef = col("col1");
        let col2: ExprRef = col("col2");

        let left_nested = and(and(a.clone(), col1.clone()), col2.clone());
        assert_eq!(
            split_conjunction(&left_nested),
            vec![a.clone(), col1.clone(), col2.clone()]
        );

        let right_nested = and(a.clone(), and(col1.clone(), col2.clone()));
        assert_eq!(split_conjunction(&right_nested), vec![a, col1, col2]);
    }

    /// Only `and` is split; a top-level `or` is returned as a single expression.
    #[test]
    fn disjunction_is_not_split() {
        let expr = or(col("col1"), col("col2"));
        assert_eq!(split_conjunction(&expr), vec![expr.clone()]);
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(Identity.to_string(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        assert_eq!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            )
            .to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1.clone()).to_string(), "!$.col1");

        assert_eq!(
            select(vec![FieldName::from("col1")], ident()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                ident()
            )
            .to_string(),
            "$~{col1, col2}"
        );

        assert_eq!(lit(Scalar::from(0_u8)).to_string(), "0_u8");
        assert_eq!(lit(Scalar::from(0.0_f32)).to_string(), "0_f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807_i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    Arc::new(StructDType::new(
                        Arc::from([Arc::from("dog"), Arc::from("cat")]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    )),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog:32_u32,cat:\"rufus\"}"
        );
    }
}