vortex_array/expr/
expression.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_dtype::DType;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_ensure;
17
18use crate::ArrayRef;
19use crate::expr::Root;
20use crate::expr::ScalarFn;
21use crate::expr::StatsCatalog;
22use crate::expr::VTable;
23use crate::expr::display::DisplayTreeExpr;
24use crate::expr::stats::Stat;
25
26/// A node in a Vortex expression tree.
27///
28/// Expressions represent scalar computations that can be performed on data. Each
29/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
30#[derive(Clone, Debug, PartialEq, Eq, Hash)]
31pub struct Expression {
32    /// The scalar fn for this node.
33    scalar_fn: ScalarFn,
34    /// Any children of this expression.
35    children: Arc<[Expression]>,
36}
37
38impl Deref for Expression {
39    type Target = ScalarFn;
40
41    fn deref(&self) -> &Self::Target {
42        &self.scalar_fn
43    }
44}
45
46impl Expression {
47    /// Create a new expression node from a scalar_fn expression and its children.
48    pub fn try_new(
49        scalar_fn: ScalarFn,
50        children: impl Into<Arc<[Expression]>>,
51    ) -> VortexResult<Self> {
52        let children: Arc<[Expression]> = children.into();
53
54        vortex_ensure!(
55            scalar_fn.signature().arity().matches(children.len()),
56            "Expression arity mismatch: expected {} children but got {}",
57            scalar_fn.signature().arity(),
58            children.len()
59        );
60
61        Ok(Self {
62            scalar_fn,
63            children,
64        })
65    }
66
67    /// Returns true if this expression is of the given vtable type.
68    pub fn is<V: VTable>(&self) -> bool {
69        self.vtable().is::<V>()
70    }
71
72    /// Returns the typed options for this expression if it matches the given vtable type.
73    pub fn as_opt<V: VTable>(&self) -> Option<&V::Options> {
74        self.options().as_any().downcast_ref::<V::Options>()
75    }
76
77    /// Returns the typed options for this expression if it matches the given vtable type.
78    pub fn as_<V: VTable>(&self) -> &V::Options {
79        self.as_opt::<V>()
80            .vortex_expect("Expression options type mismatch")
81    }
82
83    /// Returns the scalar fn vtable for this expression.
84    pub fn scalar_fn(&self) -> &ScalarFn {
85        &self.scalar_fn
86    }
87
88    /// Returns the children of this expression.
89    pub fn children(&self) -> &Arc<[Expression]> {
90        &self.children
91    }
92
93    /// Returns the n'th child of this expression.
94    pub fn child(&self, n: usize) -> &Expression {
95        &self.children[n]
96    }
97
98    /// Replace the children of this expression with the provided new children.
99    pub fn with_children(mut self, children: impl Into<Arc<[Expression]>>) -> VortexResult<Self> {
100        let children = children.into();
101        vortex_ensure!(
102            self.signature().arity().matches(children.len()),
103            "Expression arity mismatch: expected {} children but got {}",
104            self.signature().arity(),
105            children.len()
106        );
107        self.children = children;
108        Ok(self)
109    }
110
111    /// Computes the return dtype of this expression given the input dtype.
112    pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
113        if self.is::<Root>() {
114            return Ok(scope.clone());
115        }
116
117        let dtypes: Vec<_> = self
118            .children
119            .iter()
120            .map(|c| c.return_dtype(scope))
121            .try_collect()?;
122        self.scalar_fn.return_dtype(&dtypes)
123    }
124
125    /// Evaluates the expression in the given scope, returning an array.
126    pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult<ArrayRef> {
127        if self.is::<Root>() {
128            return Ok(scope.clone());
129        }
130        self.scalar_fn.evaluate(self, scope)
131    }
132
133    /// An expression over zone-statistics which implies all records in the zone evaluate to false.
134    ///
135    /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed
136    /// that `e` evaluates to false on all records in the zone. However, the inverse is not
137    /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to
138    /// true on all records.
139    ///
140    /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr.
141    ///
142    /// # Examples
143    ///
144    /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum
145    ///   value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`.
146    /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`.
147    /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x).
148    ///
149    /// Some expressions, in theory, have falsifications but this function does not support them
150    /// such as `x < (y < z)` or `x LIKE "needle%"`.
151    pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
152        self.vtable().as_dyn().stat_falsification(self, catalog)
153    }
154
155    /// Returns an expression representing the zoned statistic for the given stat, if available.
156    ///
157    /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a
158    /// scope. Expressions can implement this function to propagate such statistics through the
159    /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`.
160    ///
161    /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an
162    /// issue to discuss a solution to this.
163    pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option<Expression> {
164        self.vtable().as_dyn().stat_expression(self, stat, catalog)
165    }
166
167    /// Returns an expression representing the zoned maximum statistic, if available.
168    pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
169        self.stat_expression(Stat::Min, catalog)
170    }
171
172    /// Returns an expression representing the zoned maximum statistic, if available.
173    pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
174        self.stat_expression(Stat::Max, catalog)
175    }
176
177    /// Format the expression as a compact string.
178    ///
179    /// Since this is a recursive formatter, it is exposed on the public Expression type.
180    /// See fmt_data that is only implemented on the vtable trait.
181    pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
182        self.vtable().as_dyn().fmt_sql(self, f)
183    }
184
185    /// Display the expression as a formatted tree structure.
186    ///
187    /// This provides a hierarchical view of the expression that shows the relationships
188    /// between parent and child expressions, making complex nested expressions easier
189    /// to understand and debug.
190    ///
191    /// # Example
192    ///
193    /// ```rust
194    /// # use vortex_array::compute::LikeOptions;
195    /// # use vortex_array::expr::VTableExt;
196    /// # use vortex_dtype::{DType, Nullability, PType};
197    /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select, Like};
198    /// // Build a complex nested expression
199    /// let complex_expr = select(
200    ///     ["result"],
201    ///     and(
202    ///         not(eq(get_item("status", root()), lit("inactive"))),
203    ///         and(
204    ///             Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
205    ///             gt(
206    ///                 cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
207    ///                 lit(75.0)
208    ///             )
209    ///         )
210    ///     )
211    /// );
212    ///
213    /// println!("{}", complex_expr.display_tree());
214    /// ```
215    ///
216    /// This produces output like:
217    ///
218    /// ```text
219    /// Select(include): {result}
220    /// └── Binary(and)
221    ///     ├── lhs: Not
222    ///     │   └── Binary(=)
223    ///     │       ├── lhs: GetItem(status)
224    ///     │       │   └── Root
225    ///     │       └── rhs: Literal(value: "inactive", dtype: utf8)
226    ///     └── rhs: Binary(and)
227    ///         ├── lhs: Like
228    ///         │   ├── child: GetItem(name)
229    ///         │   │   └── Root
230    ///         │   └── pattern: Literal(value: "%admin%", dtype: utf8)
231    ///         └── rhs: Binary(>)
232    ///             ├── lhs: Cast(target: f64)
233    ///             │   └── GetItem(score)
234    ///             │       └── Root
235    ///             └── rhs: Literal(value: 75f64, dtype: f64)
236    /// ```
237    pub fn display_tree(&self) -> impl Display {
238        DisplayTreeExpr(self)
239    }
240}
241
242/// The default display implementation for expressions uses the 'SQL'-style format.
243impl Display for Expression {
244    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
245        self.fmt_sql(f)
246    }
247}