vortex_expr/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Vortex's expression language.
5//!
6//! All expressions are serializable, and own their own wire format.
7//!
8//! The implementation takes inspiration from [Postgres] and [Apache Datafusion].
9//!
10//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
11//! [Apache Datafusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::any::Any;
14use std::fmt::{Debug, Display, Formatter};
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17
18use dyn_hash::DynHash;
19pub use exprs::*;
20pub mod aliases;
21mod analysis;
22#[cfg(feature = "arbitrary")]
23pub mod arbitrary;
24pub mod dyn_traits;
25mod encoding;
26mod exprs;
27mod field;
28pub mod forms;
29pub mod proto;
30pub mod pruning;
31mod registry;
32mod scope;
33mod scope_vars;
34pub mod transform;
35pub mod traversal;
36mod vtable;
37
38pub use analysis::*;
39pub use between::*;
40pub use binary::*;
41pub use cast::*;
42pub use concat::*;
43pub use encoding::*;
44pub use get_item::*;
45pub use is_null::*;
46pub use like::*;
47pub use list_contains::*;
48pub use literal::*;
49pub use merge::*;
50pub use not::*;
51pub use operators::*;
52pub use pack::*;
53pub use registry::*;
54pub use root::*;
55pub use scope::*;
56pub use scope_vars::*;
57pub use select::*;
58use vortex_array::operator::OperatorRef;
59use vortex_array::{Array, ArrayRef, SerializeMetadata};
60use vortex_dtype::{DType, FieldName, FieldPath};
61use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail};
62use vortex_utils::aliases::hash_set::HashSet;
63pub use vtable::*;
64
65pub mod display;
66
67use crate::display::{DisplayAs, DisplayFormat};
68use crate::dyn_traits::DynEq;
69use crate::traversal::{NodeExt, ReferenceCollector};
70
/// Conversion of a value into a shared expression reference ([`ExprRef`]).
///
/// The conversion consumes `self`.
pub trait IntoExpr {
    /// Convert this type into an expression reference.
    fn into_expr(self) -> ExprRef;
}
75
/// A reference-counted expression: an [`Arc`] around a [`VortexExpr`] trait object.
pub type ExprRef = Arc<dyn VortexExpr>;
77
/// Represents a logical operation on [`ArrayRef`]s.
///
/// The `private::Sealed` supertrait is declared in a private module of this crate; the only
/// implementation visible in this file is the one for [`ExprAdapter`].
pub trait VortexExpr:
    'static + Send + Sync + Debug + DisplayAs + DynEq + DynHash + AnalysisExpr + private::Sealed
{
    /// Convert expression reference to reference of [`Any`] type
    fn as_any(&self) -> &dyn Any;

    /// Convert the expression to an [`ExprRef`].
    fn to_expr(&self) -> ExprRef;

    /// Return the encoding of the expression.
    fn encoding(&self) -> ExprEncodingRef;

    /// Serialize the metadata of this expression into a bytes vector.
    ///
    /// Returns `None` if the expression does not support serialization.
    /// The default implementation always returns `None`.
    fn metadata(&self) -> Option<Vec<u8>> {
        None
    }

    /// Compute the result of the expression on the given batch, producing a new batch.
    ///
    /// "Unchecked" means that this function does not assert that the dtype of the returned
    /// array matches the one reported by [VortexExpr::return_dtype]. Prefer
    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate),
    /// which performs that assertion.
    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;

    /// Returns the children of this expression.
    fn children(&self) -> Vec<&ExprRef>;

    /// Returns a new instance of this expression with the children replaced.
    ///
    /// The [`ExprAdapter`] implementation returns an error when the number of supplied
    /// children differs from the number of current children.
    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef>;

    /// Compute the type of the array returned by
    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
    ///
    /// `scope` is the dtype of the scope the expression is evaluated against.
    fn return_dtype(&self, scope: &DType) -> VortexResult<DType>;

    /// Optionally produce an [`OperatorRef`] for this expression given the operator `scope`.
    ///
    /// NOTE(review): `Ok(None)` presumably means the expression has no operator form —
    /// confirm against the `VTable` implementations.
    fn operator(&self, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>>;
}
118
// Generates the `Hash` implementation for `dyn VortexExpr` trait objects (backed by the
// `DynHash` supertrait), so that `ExprRef`s can be hashed.
dyn_hash::hash_trait_object!(VortexExpr);
120
/// Equality between expression trait objects, delegated to [`DynEq::dyn_eq`].
impl PartialEq for dyn VortexExpr {
    fn eq(&self, other: &Self) -> bool {
        // `other` is handed over type-erased as `&dyn Any`; the `DynEq` implementation
        // decides how to compare against it.
        self.dyn_eq(other.as_any())
    }
}
126
// Marker impl: declares the `PartialEq` implementation for `dyn VortexExpr` to be a full
// equivalence relation.
impl Eq for dyn VortexExpr {}
128
129impl dyn VortexExpr + '_ {
130    pub fn id(&self) -> ExprId {
131        self.encoding().id()
132    }
133
134    pub fn is<V: VTable>(&self) -> bool {
135        self.as_opt::<V>().is_some()
136    }
137
138    pub fn as_<V: VTable>(&self) -> &V::Expr {
139        self.as_opt::<V>()
140            .vortex_expect("Expr is not of the expected type")
141    }
142
143    pub fn as_opt<V: VTable>(&self) -> Option<&V::Expr> {
144        VortexExpr::as_any(self)
145            .downcast_ref::<ExprAdapter<V>>()
146            .map(|e| &e.0)
147    }
148
149    /// Compute result of expression on given batch producing a new batch
150    pub fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
151        let result = self.unchecked_evaluate(scope)?;
152        assert_eq!(
153            result.dtype(),
154            &self.return_dtype(scope.dtype())?,
155            "Expression {} returned dtype {} but declared return_dtype of {}",
156            &self,
157            result.dtype(),
158            self.return_dtype(scope.dtype())?,
159        );
160        Ok(result)
161    }
162
163    /// Display the expression as a formatted tree structure.
164    ///
165    /// This provides a hierarchical view of the expression that shows the relationships
166    /// between parent and child expressions, making complex nested expressions easier
167    /// to understand and debug.
168    ///
169    /// # Example
170    ///
171    /// ```rust
172    /// # use vortex_dtype::{DType, Nullability, PType};
173    /// # use vortex_expr::{and, cast, eq, get_item, gt, lit, not, root, select, IntoExpr, LikeExpr};
174    /// // Build a complex nested expression
175    /// let complex_expr = select(
176    ///     ["result"],
177    ///     and(
178    ///         not(eq(get_item("status", root()), lit("inactive"))),
179    ///         and(
180    ///             LikeExpr::new(get_item("name", root()), lit("%admin%"), false, false).into_expr(),
181    ///             gt(
182    ///                 cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
183    ///                 lit(75.0)
184    ///             )
185    ///         )
186    ///     )
187    /// );
188    ///
189    /// println!("{}", complex_expr.display_tree());
190    /// ```
191    ///
192    /// This produces output like:
193    ///
194    /// ```text
195    /// Select(include): {result}
196    /// └── Binary(and)
197    ///     ├── lhs: Not
198    ///     │   └── Binary(=)
199    ///     │       ├── lhs: GetItem(status)
200    ///     │       │   └── Root
201    ///     │       └── rhs: Literal(value: "inactive", dtype: utf8)
202    ///     └── rhs: Binary(and)
203    ///         ├── lhs: Like
204    ///         │   ├── child: GetItem(name)
205    ///         │   │   └── Root
206    ///         │   └── pattern: Literal(value: "%admin%", dtype: utf8)
207    ///         └── rhs: Binary(>)
208    ///             ├── lhs: Cast(target: f64)
209    ///             │   └── GetItem(score)
210    ///             │       └── Root
211    ///             └── rhs: Literal(value: 75f64, dtype: f64)
212    /// ```
213    pub fn display_tree(&self) -> impl Display {
214        display::DisplayTreeExpr(self)
215    }
216}
217
218impl Display for dyn VortexExpr + '_ {
219    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
220        DisplayAs::fmt_as(self, DisplayFormat::Compact, f)
221    }
222}
223
/// Extension methods for expressions.
pub trait VortexExprExt {
    /// Accumulate all field references from this expression and its children in a set
    fn field_references(&self) -> HashSet<FieldName>;
}
228
229impl VortexExprExt for ExprRef {
230    fn field_references(&self) -> HashSet<FieldName> {
231        let mut collector = ReferenceCollector::new();
232        // The collector is infallible, so we can unwrap the result
233        self.accept(&mut collector).vortex_unwrap();
234        collector.into_fields()
235    }
236}
237
/// Adapter that implements the [`VortexExpr`] trait for the expression type (`V::Expr`) of a
/// [`VTable`].
///
/// `#[repr(transparent)]` gives the wrapper the same layout as the wrapped `V::Expr`.
#[derive(Clone)]
#[repr(transparent)]
pub struct ExprAdapter<V: VTable>(V::Expr);
241
242impl<V: VTable> VortexExpr for ExprAdapter<V> {
243    fn as_any(&self) -> &dyn Any {
244        self
245    }
246
247    fn to_expr(&self) -> ExprRef {
248        Arc::new(ExprAdapter::<V>(self.0.clone()))
249    }
250
251    fn encoding(&self) -> ExprEncodingRef {
252        V::encoding(&self.0)
253    }
254
255    fn metadata(&self) -> Option<Vec<u8>> {
256        V::metadata(&self.0).map(|m| m.serialize())
257    }
258
259    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
260        V::evaluate(&self.0, ctx)
261    }
262
263    fn children(&self) -> Vec<&ExprRef> {
264        V::children(&self.0)
265    }
266
267    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
268        if self.children().len() != children.len() {
269            vortex_bail!(
270                "Expected {} children, got {}",
271                self.children().len(),
272                children.len()
273            );
274        }
275        Ok(V::with_children(&self.0, children)?.to_expr())
276    }
277
278    fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
279        V::return_dtype(&self.0, scope)
280    }
281
282    fn operator(&self, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
283        V::operator(&self.0, scope)
284    }
285}
286
/// Debug output is that of the wrapped expression; the adapter itself adds nothing.
impl<V: VTable> Debug for ExprAdapter<V> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Forward the formatter unchanged so flags such as `{:#?}` reach the inner value.
        Debug::fmt(&self.0, f)
    }
}
292
293impl<V: VTable> Display for ExprAdapter<V> {
294    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
295        DisplayAs::fmt_as(&self.0, DisplayFormat::Compact, f)
296    }
297}
298
/// Forwards [`DisplayAs`] to the wrapped expression.
impl<V: VTable> DisplayAs for ExprAdapter<V> {
    /// Formats the wrapped expression in the requested display format `df`.
    fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
        DisplayAs::fmt_as(&self.0, df, f)
    }

    /// Returns the child names reported by the wrapped expression, if any.
    fn child_names(&self) -> Option<Vec<String>> {
        DisplayAs::child_names(&self.0)
    }
}
308
309impl<V: VTable> PartialEq for ExprAdapter<V> {
310    fn eq(&self, other: &Self) -> bool {
311        PartialEq::eq(&self.0, &other.0)
312    }
313}
314
// Marker impl: equality on the adapter is declared to be a full equivalence relation.
impl<V: VTable> Eq for ExprAdapter<V> {}
316
317impl<V: VTable> Hash for ExprAdapter<V> {
318    fn hash<H: Hasher>(&self, state: &mut H) {
319        Hash::hash(&self.0, state);
320    }
321}
322
323impl<V: VTable> AnalysisExpr for ExprAdapter<V> {
324    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
325        <V::Expr as AnalysisExpr>::stat_falsification(&self.0, catalog)
326    }
327
328    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
329        <V::Expr as AnalysisExpr>::max(&self.0, catalog)
330    }
331
332    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
333        <V::Expr as AnalysisExpr>::min(&self.0, catalog)
334    }
335
336    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
337        <V::Expr as AnalysisExpr>::nan_count(&self.0, catalog)
338    }
339
340    fn field_path(&self) -> Option<FieldPath> {
341        <V::Expr as AnalysisExpr>::field_path(&self.0)
342    }
343}
344
// Private module holding the `Sealed` supertrait of `VortexExpr`. Because the module itself
// is not public, code outside this crate cannot name `Sealed`.
mod private {
    use super::*;

    /// Supertrait of `VortexExpr` used to seal it.
    pub trait Sealed {}

    // Every adapter over a vtable is allowed to implement `VortexExpr`.
    impl<V: VTable> Sealed for ExprAdapter<V> {}
}
352
353/// Splits top level and operations into separate expressions.
354pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
355    let mut conjunctions = vec![];
356    split_inner(expr, &mut conjunctions);
357    conjunctions
358}
359
360fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
361    match expr.as_opt::<BinaryVTable>() {
362        Some(bexp) if bexp.op() == Operator::And => {
363            split_inner(bexp.lhs(), exprs);
364            split_inner(bexp.rhs(), exprs);
365        }
366        Some(_) | None => {
367            exprs.push(expr.clone());
368        }
369    }
370}
371
/// An expression wrapper that performs pointer equality.
///
/// Two `ExactExpr`s compare equal only when they wrap the same `Arc` allocation, rather than
/// comparing the expressions themselves.
#[derive(Clone)]
pub struct ExactExpr(pub ExprRef);
375
376impl PartialEq for ExactExpr {
377    fn eq(&self, other: &Self) -> bool {
378        Arc::ptr_eq(&self.0, &other.0)
379    }
380}
381
// Marker impl: pointer identity is a full equivalence relation.
impl Eq for ExactExpr {}
383
384impl Hash for ExactExpr {
385    fn hash<H: Hasher>(&self, state: &mut H) {
386        Arc::as_ptr(&self.0).hash(state)
387    }
388}
389
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns a non-nullable struct dtype with the fields `a` (non-nullable `i32`), `col1`
    /// and `col2` (nullable `u16`), and `bool1` and `bool2` (non-nullable booleans).
    pub fn struct_dtype() -> DType {
        let nullable_u16 = || DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = || DType::Bool(Nullability::NonNullable);

        // Dtypes are listed in the same order as the field names below.
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16(),
            nullable_u16(),
            required_bool(),
            required_bool(),
        ];
        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            field_dtypes,
        );

        DType::Struct(fields, Nullability::NonNullable)
    }
}
410
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// An expression that is not an `and` is returned as a single element.
    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    /// A top-level `and` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the compact `Display` rendering of the built-in expressions.
    #[test]
    fn expr_display() {
        // Column access and the root expression.
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        // Binary operators are rendered infix and wrapped in parentheses.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        // Nested binary expressions keep the parentheses of each operand.
        assert_eq!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            )
            .to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1.clone()).to_string(), "(!$.col1)");

        // Field selection (include and exclude forms).
        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "$~{col1, col2}"
        );

        // Literals: numeric values carry a type suffix; booleans and null do not.
        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        // Struct literals are rendered as `{name: value, ...}`.
        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    StructFields::new(
                        FieldNames::from(["dog", "cat"]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    ),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }
}