vortex_expr/
lib.rs

1use std::any::Any;
2use std::fmt::{Debug, Display};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use dyn_hash::DynHash;
7pub use exprs::*;
8mod analysis;
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod exprs;
12mod field;
13pub mod forms;
14pub mod pruning;
15#[cfg(feature = "proto")]
16mod registry;
17mod scope;
18mod scope_vars;
19pub mod transform;
20pub mod traversal;
21
22pub use analysis::*;
23pub use between::*;
24pub use binary::*;
25pub use cast::*;
26pub use get_item::*;
27pub use is_null::*;
28pub use like::*;
29pub use list_contains::*;
30pub use literal::*;
31pub use merge::*;
32pub use not::*;
33pub use operators::*;
34pub use pack::*;
35#[cfg(feature = "proto")]
36pub use registry::deserialize_expr;
37pub use scope::*;
38pub use select::*;
39pub use var::*;
40use vortex_array::{Array, ArrayRef};
41use vortex_dtype::{DType, FieldName, FieldPath};
42use vortex_error::{VortexResult, VortexUnwrap};
43#[cfg(feature = "proto")]
44use vortex_proto::expr;
45#[cfg(feature = "proto")]
46use vortex_proto::expr::{Expr, kind};
47use vortex_utils::aliases::hash_set::HashSet;
48
49use crate::traversal::{Node, ReferenceCollector, VarsCollector};
50
51pub type ExprRef = Arc<dyn VortexExpr>;
52
/// Names an expression kind with a static string identifier.
///
/// Only available with the `proto` feature.
#[cfg(feature = "proto")]
pub trait Id {
    /// Returns the identifier of this expression kind.
    fn id(&self) -> &'static str;
}

/// Rebuilds an expression from its serialized `kind` payload and its already-deserialized
/// children.
///
/// Only available with the `proto` feature.
#[cfg(feature = "proto")]
pub trait ExprDeserialize
where
    Self: Id + Sync,
{
    /// Constructs an [`ExprRef`] from the decoded `kind` payload and the child expressions.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
62
/// Produces the per-node pieces of an expression's protobuf representation.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Returns the identifier stored in the serialized expression's `id` field.
    fn id(&self) -> &'static str;

    /// Serializes this node's own payload (its "kind"); children are serialized separately.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}

/// Without the `proto` feature this is an empty marker trait, blanket-implemented for every
/// sized type, so that [`VortexExpr`] keeps the same supertrait list in both configurations.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}

#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
74/// Represents logical operation on [`ArrayRef`]s
75pub trait VortexExpr:
76    Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable + AnalysisExpr
77{
78    /// Convert expression reference to reference of [`Any`] type
79    fn as_any(&self) -> &dyn Any;
80
81    /// Compute result of expression on given batch producing a new batch
82    fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
83        let result = self.unchecked_evaluate(scope)?;
84        assert_eq!(
85            result.dtype(),
86            &self.return_dtype(&scope.into())?,
87            "Expression {} returned dtype {} but declared return_dtype of {}",
88            self,
89            result.dtype(),
90            self.return_dtype(&scope.into())?,
91        );
92        Ok(result)
93    }
94
95    /// Compute result of expression on given batch producing a new batch
96    ///
97    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
98    /// the [VortexExpr::return_dtype] method. Use instead the [VortexExpr::evaluate] function which
99    /// includes such an assertion.
100    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
101
102    fn children(&self) -> Vec<&ExprRef>;
103
104    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
105
106    /// Compute the type of the array returned by [VortexExpr::evaluate].
107    fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType>;
108}
109
110pub trait VortexExprExt {
111    /// Accumulate all field references from this expression and its children in a set
112    fn field_references(&self) -> HashSet<FieldName>;
113
114    fn vars(&self) -> HashSet<Identifier>;
115
116    #[cfg(feature = "proto")]
117    fn serialize(&self) -> VortexResult<Expr>;
118}
119
120impl VortexExprExt for ExprRef {
121    fn field_references(&self) -> HashSet<FieldName> {
122        let mut collector = ReferenceCollector::new();
123        // The collector is infallible, so we can unwrap the result
124        self.accept(&mut collector).vortex_unwrap();
125        collector.into_fields()
126    }
127
128    fn vars(&self) -> HashSet<Identifier> {
129        let mut collector = VarsCollector::new();
130        // The collector is infallible, so we can unwrap the result
131        self.accept(&mut collector).vortex_unwrap();
132        collector.into_vars()
133    }
134
135    #[cfg(feature = "proto")]
136    fn serialize(&self) -> VortexResult<Expr> {
137        let children = self
138            .children()
139            .iter()
140            .map(|e| e.serialize())
141            .collect::<VortexResult<_>>()?;
142
143        Ok(Expr {
144            id: self.id().to_string(),
145            children,
146            kind: Some(expr::Kind {
147                kind: Some(self.serialize_kind()?),
148            }),
149        })
150    }
151}
152
153#[derive(Clone, Debug, Hash, PartialEq, Eq)]
154pub struct AccessPath {
155    field_path: FieldPath,
156    identifier: Identifier,
157}
158
159impl AccessPath {
160    pub fn root_field(path: FieldName) -> Self {
161        Self {
162            field_path: FieldPath::from_name(path),
163            identifier: Identifier::Identity,
164        }
165    }
166
167    pub fn new(path: FieldPath, identifier: Identifier) -> Self {
168        Self {
169            field_path: path,
170            identifier,
171        }
172    }
173
174    pub fn identifier(&self) -> &Identifier {
175        &self.identifier
176    }
177
178    pub fn field_path(&self) -> &FieldPath {
179        &self.field_path
180    }
181}
182
183/// Splits top level and operations into separate expressions
184pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
185    let mut conjunctions = vec![];
186    split_inner(expr, &mut conjunctions);
187    conjunctions
188}
189
190fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
191    match expr.as_any().downcast_ref::<BinaryExpr>() {
192        Some(bexp) if bexp.op() == Operator::And => {
193            split_inner(bexp.lhs(), exprs);
194            split_inner(bexp.rhs(), exprs);
195        }
196        Some(_) | None => {
197            exprs.push(expr.clone());
198        }
199    }
200}
201
// Adapted from apache/datafusion https://github.com/apache/datafusion/blob/f31ca5b927c040ce03f6a3c8c8dc3d7f4ef5be34/datafusion/physical-expr-common/src/physical_expr.rs#L156
/// Object-safe equality for [`VortexExpr`].
///
/// [`VortexExpr`] cannot be constrained by [`Eq`] directly because it must remain object
/// safe. To ease implementation, a blanket implementation is provided for every type that is
/// `Eq + Any`.
pub trait DynEq {
    /// Compares `self` against a type-erased value.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    /// Returns `true` only if `other` has the same concrete type as `self` and compares equal.
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means the concrete types differ, so the values cannot be equal.
        matches!(other.downcast_ref::<Self>(), Some(same_type) if same_type == self)
    }
}
214
215impl PartialEq for dyn VortexExpr {
216    fn eq(&self, other: &Self) -> bool {
217        self.dyn_eq(other.as_any())
218    }
219}
220
221impl Eq for dyn VortexExpr {}
222
223dyn_hash::hash_trait_object!(VortexExpr);
224
225/// An expression wrapper that performs pointer equality.
226#[derive(Clone)]
227pub struct ExactExpr(pub ExprRef);
228
229impl PartialEq for ExactExpr {
230    fn eq(&self, other: &Self) -> bool {
231        Arc::ptr_eq(&self.0, &other.0)
232    }
233}
234
235impl Eq for ExactExpr {}
236
237impl Hash for ExactExpr {
238    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
239        Arc::as_ptr(&self.0).hash(state)
240    }
241}
242
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns the non-nullable struct dtype shared by expression tests.
    ///
    /// | field   | dtype               |
    /// |---------|---------------------|
    /// | `a`     | non-nullable `i32`  |
    /// | `col1`  | nullable `u16`      |
    /// | `col2`  | nullable `u16`      |
    /// | `bool1` | non-nullable `bool` |
    /// | `bool2` | non-nullable `bool` |
    pub fn struct_dtype() -> DType {
        let nullable_u16 = || DType::Primitive(PType::U16, Nullability::Nullable);
        let non_nullable_bool = || DType::Bool(Nullability::NonNullable);

        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16(),
            nullable_u16(),
            non_nullable_bool(),
            non_nullable_bool(),
        ];
        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            field_dtypes,
        );

        DType::Struct(fields, Nullability::NonNullable)
    }
}
264
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// An expression without a top-level `AND` is kept as a single conjunct.
    #[test]
    fn basic_expr_split_test() {
        let expr = eq(get_item("col1", root()), lit(1));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    /// A top-level `AND` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Checks the `Display` rendering of the common expression kinds.
    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        // Binary operators render as `(<lhs> <op> <rhs>)`.
        let binary_cases = [
            (and(col1.clone(), col2.clone()).to_string(), "($.col1 and $.col2)"),
            (or(col1.clone(), col2.clone()).to_string(), "($.col1 or $.col2)"),
            (eq(col1.clone(), col2.clone()).to_string(), "($.col1 = $.col2)"),
            (not_eq(col1.clone(), col2.clone()).to_string(), "($.col1 != $.col2)"),
            (gt(col1.clone(), col2.clone()).to_string(), "($.col1 > $.col2)"),
            (gt_eq(col1.clone(), col2.clone()).to_string(), "($.col1 >= $.col2)"),
            (lt(col1.clone(), col2.clone()).to_string(), "($.col1 < $.col2)"),
            (lt_eq(col1.clone(), col2.clone()).to_string(), "($.col1 <= $.col2)"),
            (
                or(
                    lt(col1.clone(), col2.clone()),
                    not_eq(col1.clone(), col2.clone()),
                )
                .to_string(),
                "(($.col1 < $.col2) or ($.col1 != $.col2))",
            ),
        ];
        for (rendered, expected) in &binary_cases {
            assert_eq!(rendered, expected);
        }

        assert_eq!(not(col1.clone()).to_string(), "!$.col1");

        // Selection renders as `${...}`; exclusion as `$~{...}`.
        let selection_cases = [
            (
                select(vec![FieldName::from("col1")], root()).to_string(),
                "${col1}",
            ),
            (
                select(
                    vec![FieldName::from("col1"), FieldName::from("col2")],
                    root(),
                )
                .to_string(),
                "${col1, col2}",
            ),
            (
                select_exclude(
                    vec![FieldName::from("col1"), FieldName::from("col2")],
                    root(),
                )
                .to_string(),
                "$~{col1, col2}",
            ),
        ];
        for (rendered, expected) in &selection_cases {
            assert_eq!(rendered, expected);
        }

        // Scalar literals render with a type suffix where applicable.
        let literal_cases = [
            (lit(Scalar::from(0u8)).to_string(), "0u8"),
            (lit(Scalar::from(0.0f32)).to_string(), "0f32"),
            (
                lit(Scalar::from(i64::MAX)).to_string(),
                "9223372036854775807i64",
            ),
            (lit(Scalar::from(true)).to_string(), "true"),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
                "null",
            ),
        ];
        for (rendered, expected) in &literal_cases {
            assert_eq!(rendered, expected);
        }

        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }

    #[cfg(feature = "proto")]
    mod tests_proto {
        use crate::{VortexExprExt, deserialize_expr, eq, lit, root};

        /// Serializing and then deserializing an expression yields an equal expression.
        #[test]
        fn round_trip_serde() {
            let original = eq(root(), lit(1));
            let proto = original.serialize().unwrap();
            let decoded = deserialize_expr(&proto).unwrap();

            assert_eq!(&original, &decoded);
        }
    }
}