vortex_expr/
lib.rs

1use std::any::Any;
2use std::fmt::{Debug, Display};
3use std::sync::Arc;
4
5use dyn_hash::DynHash;
6
7mod binary;
8
9mod between;
10pub mod datafusion;
11mod field;
12pub mod forms;
13mod get_item;
14mod identity;
15mod like;
16mod literal;
17mod merge;
18mod not;
19mod operators;
20mod pack;
21pub mod pruning;
22#[cfg(feature = "proto")]
23mod registry;
24mod select;
25pub mod transform;
26pub mod traversal;
27
28pub use between::*;
29pub use binary::*;
30pub use get_item::*;
31pub use identity::*;
32pub use like::*;
33pub use literal::*;
34pub use merge::*;
35pub use not::*;
36pub use operators::*;
37pub use pack::*;
38#[cfg(feature = "proto")]
39pub use registry::deserialize_expr;
40pub use select::*;
41use vortex_array::aliases::hash_set::HashSet;
42use vortex_array::{Array, ArrayRef};
43use vortex_dtype::{DType, FieldName};
44use vortex_error::{VortexResult, VortexUnwrap};
45#[cfg(feature = "proto")]
46use vortex_proto::expr;
47#[cfg(feature = "proto")]
48use vortex_proto::expr::{Expr, kind};
49
50use crate::traversal::{Node, ReferenceCollector};
51
/// Shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
53
/// Exposes a static string identifier for the implementing type.
///
/// Only compiled when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the identifier as a `'static` string slice.
    fn id(&self) -> &'static str;
}
58
/// Builds an [`ExprRef`] from its protobuf `kind::Kind` payload and a list of child expressions.
///
/// Implementors must also provide an [`Id`] and be [`Sync`]. Only compiled when the `proto`
/// feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Constructs an expression from `kind` and `children`.
    ///
    /// NOTE(review): `children` are presumably the already-deserialized operands of the
    /// expression being rebuilt -- confirm against the registry that calls this.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
63
/// Serialization hooks required of every [`VortexExpr`] when the `proto` feature is enabled.
///
/// [`VortexExprExt::serialize`] stores the value of [`ExprSerializable::id`] in the `id` field
/// of the protobuf `Expr` and wraps the value of [`ExprSerializable::serialize_kind`] in its
/// `kind` field.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Returns the static string written to the `id` field of the serialized expression.
    fn id(&self) -> &'static str;

    /// Produces the protobuf `kind::Kind` payload for this expression.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}
70
/// Empty stand-in for the serialization trait when the `proto` feature is disabled.
///
/// [`VortexExpr`] names `ExprSerializable` as a supertrait regardless of features, so this
/// method-less version keeps that bound well-formed without pulling in protobuf types.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}
// Blanket impl: without `proto`, every type trivially satisfies the supertrait bound.
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
75
76/// Represents logical operation on [`ArrayRef`]s
77pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable {
78    /// Convert expression reference to reference of [`Any`] type
79    fn as_any(&self) -> &dyn Any;
80
81    /// Compute result of expression on given batch producing a new batch
82    ///
83    fn evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
84        let result = self.unchecked_evaluate(batch)?;
85        assert_eq!(
86            result.dtype(),
87            &self.return_dtype(batch.dtype())?,
88            "Expression {} returned dtype {} but declared return_dtype of {}",
89            self,
90            result.dtype(),
91            self.return_dtype(batch.dtype())?,
92        );
93        Ok(result)
94    }
95
96    /// Compute result of expression on given batch producing a new batch
97    ///
98    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
99    /// the [VortexExpr::return_dtype] method. Use instead the [VortexExpr::evaluate] function which
100    /// includes such an assertion.
101    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef>;
102
103    fn children(&self) -> Vec<&ExprRef>;
104
105    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
106
107    /// Compute the type of the array returned by [VortexExpr::evaluate].
108    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType>;
109}
110
/// Convenience operations layered on top of [`VortexExpr`]; implemented for [`ExprRef`].
pub trait VortexExprExt {
    /// Accumulate all field references from this expression and its children in a set
    fn references(&self) -> HashSet<FieldName>;

    /// Serialize this expression, including its children, into a protobuf `Expr`.
    ///
    /// Only available when the `proto` feature is enabled.
    #[cfg(feature = "proto")]
    fn serialize(&self) -> VortexResult<Expr>;
}
118
119impl VortexExprExt for ExprRef {
120    fn references(&self) -> HashSet<FieldName> {
121        let mut collector = ReferenceCollector::new();
122        // The collector is infallible, so we can unwrap the result
123        self.accept(&mut collector).vortex_unwrap();
124        collector.into_fields()
125    }
126
127    #[cfg(feature = "proto")]
128    fn serialize(&self) -> VortexResult<Expr> {
129        let children = self
130            .children()
131            .iter()
132            .map(|e| e.serialize())
133            .collect::<VortexResult<_>>()?;
134
135        Ok(Expr {
136            id: self.id().to_string(),
137            children,
138            kind: Some(expr::Kind {
139                kind: Some(self.serialize_kind()?),
140            }),
141        })
142    }
143}
144
145/// Splits top level and operations into separate expressions
146pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
147    let mut conjunctions = vec![];
148    split_inner(expr, &mut conjunctions);
149    conjunctions
150}
151
152fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
153    match expr.as_any().downcast_ref::<BinaryExpr>() {
154        Some(bexp) if bexp.op() == Operator::And => {
155            split_inner(bexp.lhs(), exprs);
156            split_inner(bexp.rhs(), exprs);
157        }
158        Some(_) | None => {
159            exprs.push(expr.clone());
160        }
161    }
162}
163
// Adapted from apache/datafusion https://github.com/apache/datafusion/blob/f31ca5b927c040ce03f6a3c8c8dc3d7f4ef5be34/datafusion/physical-expr-common/src/physical_expr.rs#L156
/// Object-safe counterpart of [`Eq`].
///
/// [`VortexExpr`] cannot require [`Eq`] directly because it must remain object safe, so
/// equality is routed through [`Any`] instead. A blanket implementation covers every
/// `Eq + Any` type, so implementors normally only need to derive or implement [`Eq`].
pub trait DynEq {
    /// Returns `true` when `other` holds a value of the same concrete type as `self` and
    /// that value compares equal to `self`.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T> DynEq for T
where
    T: Eq + Any,
{
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means the concrete types differ, which is never equal.
        other
            .downcast_ref::<Self>()
            .is_some_and(|that| that == self)
    }
}
176
177impl PartialEq for dyn VortexExpr {
178    fn eq(&self, other: &Self) -> bool {
179        self.dyn_eq(other.as_any())
180    }
181}
182
// Marker impl: upgrades the `PartialEq` impl for `dyn VortexExpr` to full `Eq`.
impl Eq for dyn VortexExpr {}

// Macro from the `dyn_hash` crate, invoked for the `VortexExpr` trait object.
// NOTE(review): expected to generate `Hash` for `dyn VortexExpr` on top of the `DynHash`
// supertrait -- confirm against the dyn-hash documentation.
dyn_hash::hash_trait_object!(VortexExpr);
186
/// Fixtures shared with tests; only compiled with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};

    /// Returns a non-nullable struct dtype with five fields:
    /// `a` (non-nullable `i32`), `col1` and `col2` (nullable `u16`), and `bool1` and `bool2`
    /// (non-nullable `bool`).
    pub fn struct_dtype() -> DType {
        let field_names = [
            "a".into(),
            "col1".into(),
            "col2".into(),
            "bool1".into(),
            "bool2".into(),
        ];
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];

        let fields = StructDType::new(field_names.into(), field_dtypes);
        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
216
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use super::*;

    /// A single comparison contains no top-level `and`, so it is returned as one expression.
    #[test]
    fn basic_expr_split_test() {
        let expr = eq(get_item("col1", ident()), lit(1));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    /// A top-level `and` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", ident()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` rendering of every expression kind built in this module.
    #[test]
    fn expr_display() {
        let col1: ExprRef = col("col1");
        let col2: ExprRef = col("col2");

        // Field access, identity, binary operators, nesting and negation.
        let operator_cases = [
            (col("a").to_string(), "$.a"),
            (Identity.to_string(), "$"),
            (
                and(col1.clone(), col2.clone()).to_string(),
                "($.col1 and $.col2)",
            ),
            (
                or(col1.clone(), col2.clone()).to_string(),
                "($.col1 or $.col2)",
            ),
            (
                eq(col1.clone(), col2.clone()).to_string(),
                "($.col1 = $.col2)",
            ),
            (
                not_eq(col1.clone(), col2.clone()).to_string(),
                "($.col1 != $.col2)",
            ),
            (
                gt(col1.clone(), col2.clone()).to_string(),
                "($.col1 > $.col2)",
            ),
            (
                gt_eq(col1.clone(), col2.clone()).to_string(),
                "($.col1 >= $.col2)",
            ),
            (
                lt(col1.clone(), col2.clone()).to_string(),
                "($.col1 < $.col2)",
            ),
            (
                lt_eq(col1.clone(), col2.clone()).to_string(),
                "($.col1 <= $.col2)",
            ),
            (
                or(
                    lt(col1.clone(), col2.clone()),
                    not_eq(col1.clone(), col2.clone()),
                )
                .to_string(),
                "(($.col1 < $.col2) or ($.col1 != $.col2))",
            ),
            (not(col1.clone()).to_string(), "!$.col1"),
        ];
        for (rendered, expected) in operator_cases {
            assert_eq!(rendered, expected);
        }

        // Field selection, both inclusive and exclusive.
        let select_cases = [
            (
                select(vec![FieldName::from("col1")], ident()).to_string(),
                "${col1}",
            ),
            (
                select(
                    vec![FieldName::from("col1"), FieldName::from("col2")],
                    ident(),
                )
                .to_string(),
                "${col1, col2}",
            ),
            (
                select_exclude(
                    vec![FieldName::from("col1"), FieldName::from("col2")],
                    ident(),
                )
                .to_string(),
                "$~{col1, col2}",
            ),
        ];
        for (rendered, expected) in select_cases {
            assert_eq!(rendered, expected);
        }

        // Scalar literals.
        let literal_cases = [
            (lit(Scalar::from(0u8)).to_string(), "0u8"),
            (lit(Scalar::from(0.0f32)).to_string(), "0f32"),
            (
                lit(Scalar::from(i64::MAX)).to_string(),
                "9223372036854775807i64",
            ),
            (lit(Scalar::from(true)).to_string(), "true"),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
                "null",
            ),
        ];
        for (rendered, expected) in literal_cases {
            assert_eq!(rendered, expected);
        }

        // Struct literal with one numeric and one string field.
        let pet_dtype = DType::Struct(
            Arc::new(StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }

    #[cfg(feature = "proto")]
    mod tests_proto {

        use crate::{VortexExprExt, deserialize_expr, eq, ident, lit};

        /// Serializing and then deserializing an expression yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let expr = eq(ident(), lit(1));
            let serialized = expr.serialize().unwrap();
            let round_tripped = deserialize_expr(&serialized).unwrap();

            assert_eq!(&expr, &round_tripped);
        }
    }
}