vortex_expr/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Vortex's expression language.
5//!
6//! All expressions are serializable, and own their own wire format.
7//!
8//! The implementation takes inspiration from [Postgres] and [Apache Datafusion].
9//!
10//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
11//! [Apache Datafusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::any::Any;
14use std::fmt::{Debug, Display, Formatter};
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17
18use dyn_hash::DynHash;
19pub use exprs::*;
20pub mod aliases;
21mod analysis;
22#[cfg(feature = "arbitrary")]
23pub mod arbitrary;
24pub mod dyn_traits;
25mod encoding;
26mod exprs;
27mod field;
28pub mod forms;
29pub mod proto;
30pub mod pruning;
31mod registry;
32mod scope;
33mod scope_vars;
34pub mod transform;
35pub mod traversal;
36mod vtable;
37
38pub use analysis::*;
39pub use between::*;
40pub use binary::*;
41pub use cast::*;
42pub use encoding::*;
43pub use get_item::*;
44pub use is_null::*;
45pub use like::*;
46pub use list_contains::*;
47pub use literal::*;
48pub use merge::*;
49pub use not::*;
50pub use operators::*;
51pub use pack::*;
52pub use registry::*;
53pub use root::*;
54pub use scope::*;
55pub use scope_vars::*;
56pub use select::*;
57use vortex_array::operator::OperatorRef;
58use vortex_array::{Array, ArrayRef, SerializeMetadata};
59use vortex_dtype::{DType, FieldName, FieldPath};
60use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail};
61use vortex_utils::aliases::hash_set::HashSet;
62pub use vtable::*;
63
64pub mod display;
65
66use crate::display::{DisplayAs, DisplayFormat};
67use crate::dyn_traits::DynEq;
68use crate::traversal::{NodeExt, ReferenceCollector};
69
70pub trait IntoExpr {
71    /// Convert this type into an expression reference.
72    fn into_expr(self) -> ExprRef;
73}
74
75pub type ExprRef = Arc<dyn VortexExpr>;
76
77/// Represents logical operation on [`ArrayRef`]s
78pub trait VortexExpr:
79    'static + Send + Sync + Debug + DisplayAs + DynEq + DynHash + AnalysisExpr + private::Sealed
80{
81    /// Convert expression reference to reference of [`Any`] type
82    fn as_any(&self) -> &dyn Any;
83
84    /// Convert the expression to an [`ExprRef`].
85    fn to_expr(&self) -> ExprRef;
86
87    /// Return the encoding of the expression.
88    fn encoding(&self) -> ExprEncodingRef;
89
90    /// Serialize the metadata of this expression into a bytes vector.
91    ///
92    /// Returns `None` if the expression does not support serialization.
93    fn metadata(&self) -> Option<Vec<u8>> {
94        None
95    }
96
97    /// Compute result of expression on given batch producing a new batch
98    ///
99    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
100    /// the [VortexExpr::return_dtype] method. Use instead the
101    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
102    /// function which includes such an assertion.
103    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
104
105    /// Returns the children of this expression.
106    fn children(&self) -> Vec<&ExprRef>;
107
108    /// Returns a new instance of this expression with the children replaced.
109    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
110
111    /// Compute the type of the array returned by
112    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
113    fn return_dtype(&self, scope: &DType) -> VortexResult<DType>;
114
115    fn operator(&self, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>>;
116}
117
118dyn_hash::hash_trait_object!(VortexExpr);
119
120impl PartialEq for dyn VortexExpr {
121    fn eq(&self, other: &Self) -> bool {
122        self.dyn_eq(other.as_any())
123    }
124}
125
126impl Eq for dyn VortexExpr {}
127
128impl dyn VortexExpr + '_ {
129    pub fn id(&self) -> ExprId {
130        self.encoding().id()
131    }
132
133    pub fn is<V: VTable>(&self) -> bool {
134        self.as_opt::<V>().is_some()
135    }
136
137    pub fn as_<V: VTable>(&self) -> &V::Expr {
138        self.as_opt::<V>()
139            .vortex_expect("Expr is not of the expected type")
140    }
141
142    pub fn as_opt<V: VTable>(&self) -> Option<&V::Expr> {
143        VortexExpr::as_any(self)
144            .downcast_ref::<ExprAdapter<V>>()
145            .map(|e| &e.0)
146    }
147
148    /// Compute result of expression on given batch producing a new batch
149    pub fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
150        let result = self.unchecked_evaluate(scope)?;
151        assert_eq!(
152            result.dtype(),
153            &self.return_dtype(scope.dtype())?,
154            "Expression {} returned dtype {} but declared return_dtype of {}",
155            &self,
156            result.dtype(),
157            self.return_dtype(scope.dtype())?,
158        );
159        Ok(result)
160    }
161
162    /// Display the expression as a formatted tree structure.
163    ///
164    /// This provides a hierarchical view of the expression that shows the relationships
165    /// between parent and child expressions, making complex nested expressions easier
166    /// to understand and debug.
167    ///
168    /// # Example
169    ///
170    /// ```rust
171    /// # use vortex_dtype::{DType, Nullability, PType};
172    /// # use vortex_expr::{and, cast, eq, get_item, gt, lit, not, root, select, IntoExpr, LikeExpr};
173    /// // Build a complex nested expression
174    /// let complex_expr = select(
175    ///     ["result"],
176    ///     and(
177    ///         not(eq(get_item("status", root()), lit("inactive"))),
178    ///         and(
179    ///             LikeExpr::new(get_item("name", root()), lit("%admin%"), false, false).into_expr(),
180    ///             gt(
181    ///                 cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
182    ///                 lit(75.0)
183    ///             )
184    ///         )
185    ///     )
186    /// );
187    ///
188    /// println!("{}", complex_expr.display_tree());
189    /// ```
190    ///
191    /// This produces output like:
192    ///
193    /// ```text
194    /// Select(include): {result}
195    /// └── Binary(and)
196    ///     ├── lhs: Not
197    ///     │   └── Binary(=)
198    ///     │       ├── lhs: GetItem(status)
199    ///     │       │   └── Root
200    ///     │       └── rhs: Literal(value: "inactive", dtype: utf8)
201    ///     └── rhs: Binary(and)
202    ///         ├── lhs: Like
203    ///         │   ├── child: GetItem(name)
204    ///         │   │   └── Root
205    ///         │   └── pattern: Literal(value: "%admin%", dtype: utf8)
206    ///         └── rhs: Binary(>)
207    ///             ├── lhs: Cast(target: f64)
208    ///             │   └── GetItem(score)
209    ///             │       └── Root
210    ///             └── rhs: Literal(value: 75f64, dtype: f64)
211    /// ```
212    pub fn display_tree(&self) -> impl Display {
213        display::DisplayTreeExpr(self)
214    }
215}
216
217impl Display for dyn VortexExpr + '_ {
218    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
219        DisplayAs::fmt_as(self, DisplayFormat::Compact, f)
220    }
221}
222
223pub trait VortexExprExt {
224    /// Accumulate all field references from this expression and its children in a set
225    fn field_references(&self) -> HashSet<FieldName>;
226}
227
228impl VortexExprExt for ExprRef {
229    fn field_references(&self) -> HashSet<FieldName> {
230        let mut collector = ReferenceCollector::new();
231        // The collector is infallible, so we can unwrap the result
232        self.accept(&mut collector).vortex_unwrap();
233        collector.into_fields()
234    }
235}
236
237#[derive(Clone)]
238#[repr(transparent)]
239pub struct ExprAdapter<V: VTable>(V::Expr);
240
241impl<V: VTable> VortexExpr for ExprAdapter<V> {
242    fn as_any(&self) -> &dyn Any {
243        self
244    }
245
246    fn to_expr(&self) -> ExprRef {
247        Arc::new(ExprAdapter::<V>(self.0.clone()))
248    }
249
250    fn encoding(&self) -> ExprEncodingRef {
251        V::encoding(&self.0)
252    }
253
254    fn metadata(&self) -> Option<Vec<u8>> {
255        V::metadata(&self.0).map(|m| m.serialize())
256    }
257
258    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
259        V::evaluate(&self.0, ctx)
260    }
261
262    fn children(&self) -> Vec<&ExprRef> {
263        V::children(&self.0)
264    }
265
266    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
267        if self.children().len() != children.len() {
268            vortex_bail!(
269                "Expected {} children, got {}",
270                self.children().len(),
271                children.len()
272            );
273        }
274        Ok(V::with_children(&self.0, children)?.to_expr())
275    }
276
277    fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
278        V::return_dtype(&self.0, scope)
279    }
280
281    fn operator(&self, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
282        V::operator(&self.0, scope)
283    }
284}
285
286impl<V: VTable> Debug for ExprAdapter<V> {
287    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
288        Debug::fmt(&self.0, f)
289    }
290}
291
292impl<V: VTable> Display for ExprAdapter<V> {
293    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
294        DisplayAs::fmt_as(&self.0, DisplayFormat::Compact, f)
295    }
296}
297
298impl<V: VTable> DisplayAs for ExprAdapter<V> {
299    fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
300        DisplayAs::fmt_as(&self.0, df, f)
301    }
302
303    fn child_names(&self) -> Option<Vec<String>> {
304        DisplayAs::child_names(&self.0)
305    }
306}
307
308impl<V: VTable> PartialEq for ExprAdapter<V> {
309    fn eq(&self, other: &Self) -> bool {
310        PartialEq::eq(&self.0, &other.0)
311    }
312}
313
314impl<V: VTable> Eq for ExprAdapter<V> {}
315
316impl<V: VTable> Hash for ExprAdapter<V> {
317    fn hash<H: Hasher>(&self, state: &mut H) {
318        Hash::hash(&self.0, state);
319    }
320}
321
322impl<V: VTable> AnalysisExpr for ExprAdapter<V> {
323    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
324        <V::Expr as AnalysisExpr>::stat_falsification(&self.0, catalog)
325    }
326
327    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
328        <V::Expr as AnalysisExpr>::max(&self.0, catalog)
329    }
330
331    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
332        <V::Expr as AnalysisExpr>::min(&self.0, catalog)
333    }
334
335    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
336        <V::Expr as AnalysisExpr>::nan_count(&self.0, catalog)
337    }
338
339    fn field_path(&self) -> Option<FieldPath> {
340        <V::Expr as AnalysisExpr>::field_path(&self.0)
341    }
342}
343
344mod private {
345    use super::*;
346
347    pub trait Sealed {}
348
349    impl<V: VTable> Sealed for ExprAdapter<V> {}
350}
351
352/// Splits top level and operations into separate expressions.
353pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
354    let mut conjunctions = vec![];
355    split_inner(expr, &mut conjunctions);
356    conjunctions
357}
358
359fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
360    match expr.as_opt::<BinaryVTable>() {
361        Some(bexp) if bexp.op() == Operator::And => {
362            split_inner(bexp.lhs(), exprs);
363            split_inner(bexp.rhs(), exprs);
364        }
365        Some(_) | None => {
366            exprs.push(expr.clone());
367        }
368    }
369}
370
371/// An expression wrapper that performs pointer equality.
372#[derive(Clone)]
373pub struct ExactExpr(pub ExprRef);
374
375impl PartialEq for ExactExpr {
376    fn eq(&self, other: &Self) -> bool {
377        Arc::ptr_eq(&self.0, &other.0)
378    }
379}
380
381impl Eq for ExactExpr {}
382
383impl Hash for ExactExpr {
384    fn hash<H: Hasher>(&self, state: &mut H) {
385        Arc::as_ptr(&self.0).hash(state)
386    }
387}
388
#[cfg(feature = "test-harness")]
pub mod test_harness {
    //! Test fixtures, available behind the `test-harness` feature.

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns a non-nullable struct dtype with the fields `a: i32`, `col1: u16?`,
    /// `col2: u16?`, `bool1: bool` and `bool2: bool` (`?` marks a nullable field).
    pub fn struct_dtype() -> DType {
        let names = ["a", "col1", "col2", "bool1", "bool2"];
        let int_field = DType::Primitive(PType::I32, Nullability::NonNullable);
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let bool_field = DType::Bool(Nullability::NonNullable);

        let fields = StructFields::new(
            names.into(),
            vec![
                int_field,
                nullable_u16.clone(),
                nullable_u16,
                bool_field.clone(),
                bool_field,
            ],
        );
        DType::Struct(fields, Nullability::NonNullable)
    }
}
409
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    #[test]
    fn basic_expr_split_test() {
        // A lone comparison has no top-level `and`, so it stays in one piece.
        let expr = eq(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        // A single `and` splits into its two operands.
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        // Binary operators render parenthesised as `(lhs op rhs)`.
        let binary_cases = [
            (and(col1.clone(), col2.clone()), "($.col1 and $.col2)"),
            (or(col1.clone(), col2.clone()), "($.col1 or $.col2)"),
            (eq(col1.clone(), col2.clone()), "($.col1 = $.col2)"),
            (not_eq(col1.clone(), col2.clone()), "($.col1 != $.col2)"),
            (gt(col1.clone(), col2.clone()), "($.col1 > $.col2)"),
            (gt_eq(col1.clone(), col2.clone()), "($.col1 >= $.col2)"),
            (lt(col1.clone(), col2.clone()), "($.col1 < $.col2)"),
            (lt_eq(col1.clone(), col2.clone()), "($.col1 <= $.col2)"),
        ];
        for (expr, expected) in binary_cases {
            assert_eq!(expr.to_string(), expected);
        }

        let nested = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(
            nested.to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1.clone()).to_string(), "(!$.col1)");

        let single = vec![FieldName::from("col1")];
        let pair = vec![FieldName::from("col1"), FieldName::from("col2")];
        assert_eq!(select(single, root()).to_string(), "${col1}");
        assert_eq!(select(pair.clone(), root()).to_string(), "${col1, col2}");
        assert_eq!(select_exclude(pair, root()).to_string(), "$~{col1, col2}");

        // Scalar literals carry their type suffix; booleans and nulls do not.
        let literal_cases = [
            (lit(Scalar::from(0u8)), "0u8"),
            (lit(Scalar::from(0.0f32)), "0f32"),
            (lit(Scalar::from(i64::MAX)), "9223372036854775807i64"),
            (lit(Scalar::from(true)), "true"),
            (lit(Scalar::null(DType::Bool(Nullability::Nullable))), "null"),
        ];
        for (expr, expected) in literal_cases {
            assert_eq!(expr.to_string(), expected);
        }

        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }
}