vortex_expr/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Vortex's expression language.
5//!
6//! All expressions are serializable, and own their own wire format.
7//!
8//! The implementation takes inspiration from [Postgres] and [Apache Datafusion].
9//!
10//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
11//! [Apache Datafusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::any::Any;
14use std::fmt::{Debug, Display, Formatter};
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17
18use dyn_hash::DynHash;
19pub use exprs::*;
20pub mod aliases;
21mod analysis;
22#[cfg(feature = "arbitrary")]
23pub mod arbitrary;
24pub mod dyn_traits;
25mod encoding;
26mod exprs;
27mod field;
28pub mod forms;
29mod operator;
30pub mod proto;
31pub mod pruning;
32mod registry;
33mod scope;
34mod scope_vars;
35pub mod transform;
36pub mod traversal;
37mod vtable;
38
39pub use analysis::*;
40pub use between::*;
41pub use binary::*;
42pub use cast::*;
43pub use encoding::*;
44pub use get_item::*;
45pub use is_null::*;
46pub use like::*;
47pub use list_contains::*;
48pub use literal::*;
49pub use merge::*;
50pub use not::*;
51pub use operator::*;
52pub use operators::*;
53pub use pack::*;
54pub use registry::*;
55pub use root::*;
56pub use scope::*;
57pub use scope_vars::*;
58pub use select::*;
59use vortex_array::pipeline::OperatorRef;
60use vortex_array::{Array, ArrayRef, SerializeMetadata};
61use vortex_dtype::{DType, FieldName, FieldPath};
62use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail};
63use vortex_utils::aliases::hash_set::HashSet;
64pub use vtable::*;
65
66pub mod display;
67
68use crate::display::{DisplayAs, DisplayFormat};
69use crate::dyn_traits::DynEq;
70use crate::traversal::{NodeExt, ReferenceCollector};
71
/// Conversion of a value into a shared expression reference ([`ExprRef`]).
pub trait IntoExpr {
    /// Convert this type into an expression reference.
    fn into_expr(self) -> ExprRef;
}
76
/// A shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
78
/// Represents a logical operation on [`ArrayRef`]s.
///
/// The trait requires the private `Sealed` supertrait; within this file the only implementor
/// is [`ExprAdapter`].
pub trait VortexExpr:
    'static + Send + Sync + Debug + DisplayAs + DynEq + DynHash + AnalysisExpr + private::Sealed
{
    /// Convert expression reference to reference of [`Any`] type
    fn as_any(&self) -> &dyn Any;

    /// Convert the expression to an [`ExprRef`].
    fn to_expr(&self) -> ExprRef;

    /// Return the encoding of the expression.
    fn encoding(&self) -> ExprEncodingRef;

    /// Serialize the metadata of this expression into a bytes vector.
    ///
    /// Returns `None` if the expression does not support serialization.
    fn metadata(&self) -> Option<Vec<u8>> {
        None
    }

    /// Compute result of expression on given batch producing a new batch
    ///
    /// "Unchecked" means that this function does not assert that the returned array matches
    /// the dtype declared by [`VortexExpr::return_dtype`]. Prefer
    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate),
    /// which includes such an assertion.
    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;

    /// Returns the children of this expression.
    fn children(&self) -> Vec<&ExprRef>;

    /// Returns a new instance of this expression with the children replaced.
    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef>;

    /// Compute the type of the array returned by
    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
    fn return_dtype(&self, scope: &DType) -> VortexResult<DType>;

    /// Build the pipeline operator for this expression from the operators of its children.
    ///
    /// NOTE(review): `None` presumably means the expression has no operator form — confirm
    /// against the vtable implementations.
    fn operator(&self, _children: Vec<OperatorRef>) -> Option<OperatorRef>;
}
119
// Generates a `Hash` impl for `dyn VortexExpr` trait objects on top of the `DynHash` supertrait.
dyn_hash::hash_trait_object!(VortexExpr);
121
122impl PartialEq for dyn VortexExpr {
123    fn eq(&self, other: &Self) -> bool {
124        self.dyn_eq(other.as_any())
125    }
126}
127
// Marker impl: equality on `dyn VortexExpr` is declared to be a full equivalence relation.
impl Eq for dyn VortexExpr {}
129
130impl dyn VortexExpr + '_ {
131    pub fn id(&self) -> ExprId {
132        self.encoding().id()
133    }
134
135    pub fn is<V: VTable>(&self) -> bool {
136        self.as_opt::<V>().is_some()
137    }
138
139    pub fn as_<V: VTable>(&self) -> &V::Expr {
140        self.as_opt::<V>()
141            .vortex_expect("Expr is not of the expected type")
142    }
143
144    pub fn as_opt<V: VTable>(&self) -> Option<&V::Expr> {
145        VortexExpr::as_any(self)
146            .downcast_ref::<ExprAdapter<V>>()
147            .map(|e| &e.0)
148    }
149
150    /// Compute result of expression on given batch producing a new batch
151    pub fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
152        let result = self.unchecked_evaluate(scope)?;
153        assert_eq!(
154            result.dtype(),
155            &self.return_dtype(scope.dtype())?,
156            "Expression {} returned dtype {} but declared return_dtype of {}",
157            &self,
158            result.dtype(),
159            self.return_dtype(scope.dtype())?,
160        );
161        Ok(result)
162    }
163
164    /// Display the expression as a formatted tree structure.
165    ///
166    /// This provides a hierarchical view of the expression that shows the relationships
167    /// between parent and child expressions, making complex nested expressions easier
168    /// to understand and debug.
169    ///
170    /// # Example
171    ///
172    /// ```rust
173    /// # use vortex_dtype::{DType, Nullability, PType};
174    /// # use vortex_expr::{and, cast, eq, get_item, gt, lit, not, root, select, IntoExpr, LikeExpr};
175    /// // Build a complex nested expression
176    /// let complex_expr = select(
177    ///     ["result"],
178    ///     and(
179    ///         not(eq(get_item("status", root()), lit("inactive"))),
180    ///         and(
181    ///             LikeExpr::new(get_item("name", root()), lit("%admin%"), false, false).into_expr(),
182    ///             gt(
183    ///                 cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
184    ///                 lit(75.0)
185    ///             )
186    ///         )
187    ///     )
188    /// );
189    ///
190    /// println!("{}", complex_expr.display_tree());
191    /// ```
192    ///
193    /// This produces output like:
194    ///
195    /// ```text
196    /// Select(include): {result}
197    /// └── Binary(and)
198    ///     ├── lhs: Not
199    ///     │   └── Binary(=)
200    ///     │       ├── lhs: GetItem(status)
201    ///     │       │   └── Root
202    ///     │       └── rhs: Literal(value: "inactive", dtype: utf8)
203    ///     └── rhs: Binary(and)
204    ///         ├── lhs: Like
205    ///         │   ├── child: GetItem(name)
206    ///         │   │   └── Root
207    ///         │   └── pattern: Literal(value: "%admin%", dtype: utf8)
208    ///         └── rhs: Binary(>)
209    ///             ├── lhs: Cast(target: f64)
210    ///             │   └── GetItem(score)
211    ///             │       └── Root
212    ///             └── rhs: Literal(value: 75f64, dtype: f64)
213    /// ```
214    pub fn display_tree(&self) -> impl Display {
215        display::DisplayTreeExpr(self)
216    }
217}
218
219impl Display for dyn VortexExpr + '_ {
220    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
221        DisplayAs::fmt_as(self, DisplayFormat::Compact, f)
222    }
223}
224
/// Convenience operations on expressions; implemented in this file for [`ExprRef`].
pub trait VortexExprExt {
    /// Accumulate all field references from this expression and its children in a set
    fn field_references(&self) -> HashSet<FieldName>;

    /// Convert this expression into an operator over `root`.
    ///
    /// Returns `Ok(None)` when no operator is produced. In the [`ExprRef`] implementation this
    /// is [`Self::to_operator_unoptimized`] followed by a `reduce_up` pass over the result.
    fn to_operator(&self, root: &dyn Array) -> VortexResult<Option<OperatorRef>>;

    /// Convert this expression into an operator over `root` without the extra pass applied by
    /// [`Self::to_operator`].
    ///
    /// Returns `Ok(None)` when no operator is produced, e.g. when `root` itself has none.
    fn to_operator_unoptimized(&self, root: &dyn Array) -> VortexResult<Option<OperatorRef>>;
}
233
234impl VortexExprExt for ExprRef {
235    fn field_references(&self) -> HashSet<FieldName> {
236        let mut collector = ReferenceCollector::new();
237        // The collector is infallible, so we can unwrap the result
238        self.accept(&mut collector).vortex_unwrap();
239        collector.into_fields()
240    }
241
242    fn to_operator(&self, root: &dyn Array) -> VortexResult<Option<OperatorRef>> {
243        let Some(operator) = self.to_operator_unoptimized(root)? else {
244            return Ok(None);
245        };
246        reduce_up(operator).map(Some)
247    }
248
249    fn to_operator_unoptimized(&self, root: &dyn Array) -> VortexResult<Option<OperatorRef>> {
250        let Some(root_op) = root.to_operator()? else {
251            return Ok(None);
252        };
253        let mut converter = ExprOperatorConverter::new(root_op);
254        self.clone().fold(&mut converter).map(|op| op.value())
255    }
256}
257
/// Adapter that wraps the concrete expression type of a [`VTable`] (`V::Expr`) and implements
/// [`VortexExpr`] for it by forwarding to the vtable's functions.
///
/// `#[repr(transparent)]` gives the adapter the same layout as the wrapped `V::Expr`.
#[derive(Clone)]
#[repr(transparent)]
pub struct ExprAdapter<V: VTable>(V::Expr);
261
262impl<V: VTable> VortexExpr for ExprAdapter<V> {
263    fn as_any(&self) -> &dyn Any {
264        self
265    }
266
267    fn to_expr(&self) -> ExprRef {
268        Arc::new(ExprAdapter::<V>(self.0.clone()))
269    }
270
271    fn encoding(&self) -> ExprEncodingRef {
272        V::encoding(&self.0)
273    }
274
275    fn metadata(&self) -> Option<Vec<u8>> {
276        V::metadata(&self.0).map(|m| m.serialize())
277    }
278
279    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
280        V::evaluate(&self.0, ctx)
281    }
282
283    fn children(&self) -> Vec<&ExprRef> {
284        V::children(&self.0)
285    }
286
287    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
288        if self.children().len() != children.len() {
289            vortex_bail!(
290                "Expected {} children, got {}",
291                self.children().len(),
292                children.len()
293            );
294        }
295        Ok(V::with_children(&self.0, children)?.to_expr())
296    }
297
298    fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
299        V::return_dtype(&self.0, scope)
300    }
301
302    fn operator(&self, children: Vec<OperatorRef>) -> Option<OperatorRef> {
303        V::operator(&self.0, children)
304    }
305}
306
307impl<V: VTable> Debug for ExprAdapter<V> {
308    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
309        Debug::fmt(&self.0, f)
310    }
311}
312
313impl<V: VTable> Display for ExprAdapter<V> {
314    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
315        DisplayAs::fmt_as(&self.0, DisplayFormat::Compact, f)
316    }
317}
318
319impl<V: VTable> DisplayAs for ExprAdapter<V> {
320    fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
321        DisplayAs::fmt_as(&self.0, df, f)
322    }
323
324    fn child_names(&self) -> Option<Vec<String>> {
325        DisplayAs::child_names(&self.0)
326    }
327}
328
329impl<V: VTable> PartialEq for ExprAdapter<V> {
330    fn eq(&self, other: &Self) -> bool {
331        PartialEq::eq(&self.0, &other.0)
332    }
333}
334
// Marker impl: adapter equality (delegated to the wrapped expression) is a full equivalence.
impl<V: VTable> Eq for ExprAdapter<V> {}
336
337impl<V: VTable> Hash for ExprAdapter<V> {
338    fn hash<H: Hasher>(&self, state: &mut H) {
339        Hash::hash(&self.0, state);
340    }
341}
342
343impl<V: VTable> AnalysisExpr for ExprAdapter<V> {
344    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
345        <V::Expr as AnalysisExpr>::stat_falsification(&self.0, catalog)
346    }
347
348    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
349        <V::Expr as AnalysisExpr>::max(&self.0, catalog)
350    }
351
352    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
353        <V::Expr as AnalysisExpr>::min(&self.0, catalog)
354    }
355
356    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
357        <V::Expr as AnalysisExpr>::nan_count(&self.0, catalog)
358    }
359
360    fn field_path(&self) -> Option<FieldPath> {
361        <V::Expr as AnalysisExpr>::field_path(&self.0)
362    }
363}
364
// `Sealed` is declared `pub` inside this private module and is a supertrait of `VortexExpr`.
// Code outside the crate cannot name it, so `VortexExpr` can only be implemented through
// `ExprAdapter`.
mod private {
    use super::*;

    /// Marker supertrait required by `VortexExpr`.
    pub trait Sealed {}

    impl<V: VTable> Sealed for ExprAdapter<V> {}
}
372
373/// Splits top level and operations into separate expressions.
374pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
375    let mut conjunctions = vec![];
376    split_inner(expr, &mut conjunctions);
377    conjunctions
378}
379
380fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
381    match expr.as_opt::<BinaryVTable>() {
382        Some(bexp) if bexp.op() == Operator::And => {
383            split_inner(bexp.lhs(), exprs);
384            split_inner(bexp.rhs(), exprs);
385        }
386        Some(_) | None => {
387            exprs.push(expr.clone());
388        }
389    }
390}
391
/// An expression wrapper that performs pointer equality: two `ExactExpr`s compare equal only
/// when they wrap the same `Arc` allocation, regardless of the expressions' contents.
#[derive(Clone)]
pub struct ExactExpr(pub ExprRef);
395
396impl PartialEq for ExactExpr {
397    fn eq(&self, other: &Self) -> bool {
398        Arc::ptr_eq(&self.0, &other.0)
399    }
400}
401
// Marker impl: pointer identity is a full equivalence relation.
impl Eq for ExactExpr {}
403
404impl Hash for ExactExpr {
405    fn hash<H: Hasher>(&self, state: &mut H) {
406        Arc::as_ptr(&self.0).hash(state)
407    }
408}
409
#[cfg(feature = "test-harness")]
pub mod test_harness {

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns the struct dtype shared by tests.
    ///
    /// Fields, in order: `a` (non-nullable `i32`), `col1` and `col2` (nullable `u16`),
    /// `bool1` and `bool2` (non-nullable bool). The struct itself is non-nullable.
    pub fn struct_dtype() -> DType {
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];
        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            field_dtypes,
        );
        DType::Struct(fields, Nullability::NonNullable)
    }
}
431
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// A lone comparison contains no top-level `and`, so it stays in one piece.
    #[test]
    fn basic_expr_split_test() {
        let comparison = eq(get_item("col1", root()), lit(1));
        let parts = split_conjunction(&comparison);
        assert_eq!(parts.len(), 1);
    }

    /// A top-level `and` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let both = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&both);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the compact `Display` rendering of the built-in expressions.
    #[test]
    fn expr_display() {
        // Column access and the root.
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        // Binary operators.
        let first: ExprRef = col("col1");
        let second: ExprRef = col("col2");
        assert_eq!(
            and(Arc::clone(&first), Arc::clone(&second)).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(Arc::clone(&first), Arc::clone(&second)).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(Arc::clone(&first), Arc::clone(&second)).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(Arc::clone(&first), Arc::clone(&second)).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(Arc::clone(&first), Arc::clone(&second)).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(Arc::clone(&first), Arc::clone(&second)).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(Arc::clone(&first), Arc::clone(&second)).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(Arc::clone(&first), Arc::clone(&second)).to_string(),
            "($.col1 <= $.col2)"
        );

        // Nested binary operators.
        let less_than = lt(Arc::clone(&first), Arc::clone(&second));
        let differs = not_eq(Arc::clone(&first), Arc::clone(&second));
        assert_eq!(
            or(less_than, differs).to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        // Negation.
        assert_eq!(not(Arc::clone(&first)).to_string(), "(!$.col1)");

        // Field selection, inclusive and exclusive.
        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        let both_fields = vec![FieldName::from("col1"), FieldName::from("col2")];
        assert_eq!(select(both_fields, root()).to_string(), "${col1, col2}");
        let excluded_fields = vec![FieldName::from("col1"), FieldName::from("col2")];
        assert_eq!(
            select_exclude(excluded_fields, root()).to_string(),
            "$~{col1, col2}"
        );

        // Scalar literals.
        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        // Struct literal.
        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }
}