vortex_expr/
lib.rs

1extern crate core;
2
3use std::any::Any;
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6
7use dyn_hash::DynHash;
8
9mod binary;
10
11mod between;
12pub mod datafusion;
13mod field;
14pub mod forms;
15mod get_item;
16mod identity;
17mod like;
18mod literal;
19mod merge;
20mod not;
21mod operators;
22mod pack;
23pub mod pruning;
24#[cfg(feature = "proto")]
25mod registry;
26mod select;
27pub mod transform;
28pub mod traversal;
29
30pub use binary::*;
31pub use get_item::*;
32pub use identity::*;
33pub use like::*;
34pub use literal::*;
35pub use merge::*;
36pub use not::*;
37pub use operators::*;
38pub use pack::*;
39#[cfg(feature = "proto")]
40pub use registry::deserialize_expr;
41pub use select::*;
42use vortex_array::aliases::hash_set::HashSet;
43use vortex_array::{Array, ArrayRef};
44use vortex_dtype::{DType, FieldName};
45use vortex_error::{VortexResult, VortexUnwrap};
46#[cfg(feature = "proto")]
47use vortex_proto::expr;
48#[cfg(feature = "proto")]
49use vortex_proto::expr::{Expr, kind};
50
51use crate::traversal::{Node, ReferenceCollector};
52
53pub type ExprRef = Arc<dyn VortexExpr>;
54
/// Exposes a static string identifier.
///
/// Only compiled when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the identifier as a `'static` string slice.
    fn id(&self) -> &'static str;
}
59
/// Builds an [`ExprRef`] from its `vortex_proto` representation.
///
/// Implementors must also provide an [`Id`] and be [`Sync`]. Only compiled when the `proto`
/// feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Reconstructs an expression from `kind` and the supplied `children`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementor cannot build an expression from the inputs.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
64
/// Serialization hooks for expressions, available when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Static identifier recorded as the `id` of the serialized expression.
    fn id(&self) -> &'static str;

    /// Produces the `kind::Kind` payload describing this expression.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementor cannot produce the payload.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}

/// Empty stand-in used when the `proto` feature is disabled, so that the supertrait bound on
/// [`VortexExpr`] is satisfied without any serialization support.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}

// Without `proto`, every sized type trivially implements the empty trait.
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
76
77/// Represents logical operation on [`ArrayRef`]s
78pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable {
79    /// Convert expression reference to reference of [`Any`] type
80    fn as_any(&self) -> &dyn Any;
81
82    /// Compute result of expression on given batch producing a new batch
83    ///
84    fn evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
85        let result = self.unchecked_evaluate(batch)?;
86        assert_eq!(
87            result.dtype(),
88            &self.return_dtype(batch.dtype())?,
89            "Expression {} returned dtype {} but declared return_dtype of {}",
90            self,
91            result.dtype(),
92            self.return_dtype(batch.dtype())?,
93        );
94        Ok(result)
95    }
96
97    /// Compute result of expression on given batch producing a new batch
98    ///
99    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
100    /// the [VortexExpr::return_dtype] method. Use instead the [VortexExpr::evaluate] function which
101    /// includes such an assertion.
102    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef>;
103
104    fn children(&self) -> Vec<&ExprRef>;
105
106    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
107
108    /// Compute the type of the array returned by [VortexExpr::evaluate].
109    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType>;
110}
111
112pub trait VortexExprExt {
113    /// Accumulate all field references from this expression and its children in a set
114    fn references(&self) -> HashSet<FieldName>;
115
116    #[cfg(feature = "proto")]
117    fn serialize(&self) -> VortexResult<Expr>;
118}
119
120impl VortexExprExt for ExprRef {
121    fn references(&self) -> HashSet<FieldName> {
122        let mut collector = ReferenceCollector::new();
123        // The collector is infallible, so we can unwrap the result
124        self.accept(&mut collector).vortex_unwrap();
125        collector.into_fields()
126    }
127
128    #[cfg(feature = "proto")]
129    fn serialize(&self) -> VortexResult<Expr> {
130        let children = self
131            .children()
132            .iter()
133            .map(|e| e.serialize())
134            .collect::<VortexResult<_>>()?;
135
136        Ok(Expr {
137            id: self.id().to_string(),
138            children,
139            kind: Some(expr::Kind {
140                kind: Some(self.serialize_kind()?),
141            }),
142        })
143    }
144}
145
146/// Splits top level and operations into separate expressions
147pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
148    let mut conjunctions = vec![];
149    split_inner(expr, &mut conjunctions);
150    conjunctions
151}
152
153fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
154    match expr.as_any().downcast_ref::<BinaryExpr>() {
155        Some(bexp) if bexp.op() == Operator::And => {
156            split_inner(bexp.lhs(), exprs);
157            split_inner(bexp.rhs(), exprs);
158        }
159        Some(_) | None => {
160            exprs.push(expr.clone());
161        }
162    }
163}
164
// Adapted from apache/datafusion https://github.com/apache/datafusion/blob/f31ca5b927c040ce03f6a3c8c8dc3d7f4ef5be34/datafusion/physical-expr-common/src/physical_expr.rs#L156
/// Object-safe equality against a type-erased value.
///
/// [`VortexExpr`] can't be bounded by [`Eq`] directly because it must remain object safe.
/// To ease implementation, a blanket implementation is provided for all [`Eq`] types.
pub trait DynEq {
    /// Compares `self` with the type-erased `other`.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T> DynEq for T
where
    T: Eq + Any,
{
    /// Returns `true` only when `other` is a `T` that compares equal to `self`.
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        other.downcast_ref::<Self>().is_some_and(|rhs| rhs == self)
    }
}
177
178impl PartialEq for dyn VortexExpr {
179    fn eq(&self, other: &Self) -> bool {
180        self.dyn_eq(other.as_any())
181    }
182}
183
184impl Eq for dyn VortexExpr {}
185
186dyn_hash::hash_trait_object!(VortexExpr);
187
/// Shared fixtures, compiled only with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Builds a non-nullable struct dtype with five fields:
    /// `a` (non-nullable `i32`), `col1` and `col2` (nullable `u16`), and `bool1` and `bool2`
    /// (non-nullable booleans).
    pub fn struct_dtype() -> DType {
        let nullable_u16 = || DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = || DType::Bool(Nullability::NonNullable);

        let fields = StructDType::new(
            ["a", "col1", "col2", "bool1", "bool2"]
                .map(Into::into)
                .into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                nullable_u16(),
                nullable_u16(),
                required_bool(),
                required_bool(),
            ],
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
217
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use super::*;

    /// A non-`And` expression is returned as a single conjunct.
    #[test]
    fn basic_expr_split_test() {
        let expr = eq(get_item("col1", ident()), lit(1));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    /// A top-level `And` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", ident()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` output of the expression constructors.
    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(Identity.to_string(), "$");

        let c1: ExprRef = col("col1");
        let c2: ExprRef = col("col2");

        // Binary operators, including one nested case.
        let binary_cases = [
            (and(c1.clone(), c2.clone()), "($.col1 and $.col2)"),
            (or(c1.clone(), c2.clone()), "($.col1 or $.col2)"),
            (eq(c1.clone(), c2.clone()), "($.col1 = $.col2)"),
            (not_eq(c1.clone(), c2.clone()), "($.col1 != $.col2)"),
            (gt(c1.clone(), c2.clone()), "($.col1 > $.col2)"),
            (gt_eq(c1.clone(), c2.clone()), "($.col1 >= $.col2)"),
            (lt(c1.clone(), c2.clone()), "($.col1 < $.col2)"),
            (lt_eq(c1.clone(), c2.clone()), "($.col1 <= $.col2)"),
            (
                or(
                    lt(c1.clone(), c2.clone()),
                    not_eq(c1.clone(), c2.clone()),
                ),
                "(($.col1 < $.col2) or ($.col1 != $.col2))",
            ),
        ];
        for (expr, expected) in binary_cases {
            assert_eq!(expr.to_string(), expected);
        }

        assert_eq!(not(c1.clone()).to_string(), "!$.col1");

        // Field selection and exclusion.
        let both_columns = || vec![FieldName::from("col1"), FieldName::from("col2")];
        let select_cases = [
            (select(vec![FieldName::from("col1")], ident()), "${col1}"),
            (select(both_columns(), ident()), "${col1, col2}"),
            (select_exclude(both_columns(), ident()), "$~{col1, col2}"),
        ];
        for (expr, expected) in select_cases {
            assert_eq!(expr.to_string(), expected);
        }

        // Scalar literals.
        let literal_cases = [
            (lit(Scalar::from(0_u8)), "0_u8"),
            (lit(Scalar::from(0.0_f32)), "0_f32"),
            (lit(Scalar::from(i64::MAX)), "9223372036854775807_i64"),
            (lit(Scalar::from(true)), "true"),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                "null",
            ),
        ];
        for (expr, expected) in literal_cases {
            assert_eq!(expr.to_string(), expected);
        }

        // Struct literal.
        let pet_dtype = DType::Struct(
            Arc::new(StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog:32_u32,cat:\"rufus\"}");
    }

    #[cfg(feature = "proto")]
    mod tests_proto {

        use crate::{VortexExprExt, deserialize_expr, eq, ident, lit};

        /// Serializing and then deserializing an expression yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let original = eq(ident(), lit(1));
            let encoded = original.serialize().unwrap();
            let decoded = deserialize_expr(&encoded).unwrap();

            assert_eq!(&original, &decoded);
        }
    }
}