Skip to main content

vortex_array/expr/
expression.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_ensure;
15use vortex_session::VortexSession;
16
17use crate::dtype::DType;
18use crate::expr::StatsCatalog;
19use crate::expr::display::DisplayTreeExpr;
20use crate::expr::stats::Stat;
21use crate::scalar_fn::ScalarFnRef;
22use crate::scalar_fn::fns::root::Root;
23
24/// A node in a Vortex expression tree.
25///
26/// Expressions represent scalar computations that can be performed on data. Each
27/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
28#[derive(Clone, Debug, PartialEq, Eq, Hash)]
29pub struct Expression {
30    /// The scalar fn for this node.
31    scalar_fn: ScalarFnRef,
32    /// Any children of this expression.
33    children: Arc<Vec<Expression>>,
34}
35
36impl Deref for Expression {
37    type Target = ScalarFnRef;
38
39    fn deref(&self) -> &Self::Target {
40        &self.scalar_fn
41    }
42}
43
44impl Expression {
45    /// Create a new expression node from a scalar_fn expression and its children.
46    pub fn try_new(
47        scalar_fn: ScalarFnRef,
48        children: impl IntoIterator<Item = Expression>,
49    ) -> VortexResult<Self> {
50        let children = Vec::from_iter(children);
51
52        vortex_ensure!(
53            scalar_fn.signature().arity().matches(children.len()),
54            "Expression arity mismatch: expected {} children but got {}",
55            scalar_fn.signature().arity(),
56            children.len()
57        );
58
59        Ok(Self {
60            scalar_fn,
61            children: children.into(),
62        })
63    }
64
65    /// Returns the scalar fn vtable for this expression.
66    pub fn scalar_fn(&self) -> &ScalarFnRef {
67        &self.scalar_fn
68    }
69
70    /// Returns the children of this expression.
71    pub fn children(&self) -> &Arc<Vec<Expression>> {
72        &self.children
73    }
74
75    /// Returns the n'th child of this expression.
76    pub fn child(&self, n: usize) -> &Expression {
77        &self.children[n]
78    }
79
80    /// Replace the children of this expression with the provided new children.
81    pub fn with_children(
82        mut self,
83        children: impl IntoIterator<Item = Expression>,
84    ) -> VortexResult<Self> {
85        let children = Vec::from_iter(children);
86        vortex_ensure!(
87            self.signature().arity().matches(children.len()),
88            "Expression arity mismatch: expected {} children but got {}",
89            self.signature().arity(),
90            children.len()
91        );
92        self.children = Arc::new(children);
93        Ok(self)
94    }
95
96    /// Computes the return dtype of this expression given the input dtype.
97    pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
98        if self.is::<Root>() {
99            return Ok(scope.clone());
100        }
101
102        let dtypes: Vec<_> = self
103            .children
104            .iter()
105            .map(|c| c.return_dtype(scope))
106            .try_collect()?;
107        self.scalar_fn.return_dtype(&dtypes)
108    }
109
110    /// Returns a new expression representing the validity mask output of this expression.
111    ///
112    /// The returned expression evaluates to a non-nullable boolean array.
113    pub fn validity(&self) -> VortexResult<Expression> {
114        self.scalar_fn.validity(self)
115    }
116
117    /// An expression over zone-statistics which implies all records in the zone evaluate to false.
118    ///
119    /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed
120    /// that `e` evaluates to false on all records in the zone. However, the inverse is not
121    /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to
122    /// true on all records.
123    ///
124    /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr.
125    ///
126    /// # Examples
127    ///
128    /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum
129    ///   value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`.
130    /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`.
131    /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x).
132    ///
133    /// Some expressions, in theory, have falsifications but this function does not support them
134    /// such as `x < (y < z)` or `x LIKE "needle%"`.
135    pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
136        self.scalar_fn().stat_falsification(self, catalog)
137    }
138
139    /// Returns an expression that proves this predicate is definitely false from stats.
140    ///
141    /// `scope` is the dtype of the row this expression evaluates over.
142    ///
143    /// If the returned expression evaluates to `true` for a stats scope, this expression is
144    /// guaranteed to be false for every row in that scope. `false` and `null` are unknown.
145    pub fn falsify(
146        &self,
147        scope: &DType,
148        session: &VortexSession,
149    ) -> VortexResult<Option<Expression>> {
150        crate::stats::rewrite::StatsRewriteCtx::new(session, scope).falsify(self)
151    }
152
153    /// Returns an expression that proves this predicate is definitely true from stats.
154    ///
155    /// `scope` is the dtype of the row this expression evaluates over.
156    ///
157    /// If the returned expression evaluates to `true` for a stats scope, this expression is
158    /// guaranteed to be true for every row in that scope. `false` and `null` are unknown.
159    pub fn satisfy(
160        &self,
161        scope: &DType,
162        session: &VortexSession,
163    ) -> VortexResult<Option<Expression>> {
164        crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self)
165    }
166
167    /// Returns an expression representing the zoned statistic for the given stat, if available.
168    ///
169    /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a
170    /// scope. Expressions can implement this function to propagate such statistics through the
171    /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`.
172    ///
173    /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an
174    /// issue to discuss a solution to this.
175    pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option<Expression> {
176        self.scalar_fn().stat_expression(self, stat, catalog)
177    }
178
179    /// Returns an expression representing the zoned maximum statistic, if available.
180    pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
181        self.stat_expression(Stat::Min, catalog)
182    }
183
184    /// Returns an expression representing the zoned maximum statistic, if available.
185    pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
186        self.stat_expression(Stat::Max, catalog)
187    }
188
189    /// Format the expression as a compact string.
190    ///
191    /// Since this is a recursive formatter, it is exposed on the public Expression type.
192    /// See fmt_data that is only implemented on the vtable trait.
193    pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
194        self.scalar_fn().fmt_sql(self, f)
195    }
196
197    /// Display the expression as a formatted tree structure.
198    ///
199    /// This provides a hierarchical view of the expression that shows the relationships
200    /// between parent and child expressions, making complex nested expressions easier
201    /// to understand and debug.
202    ///
203    /// # Example
204    ///
205    /// ```rust
206    /// # use vortex_array::dtype::{DType, Nullability, PType};
207    /// # use vortex_array::scalar_fn::fns::like::{Like, LikeOptions};
208    /// # use vortex_array::scalar_fn::ScalarFnVTableExt;
209    /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select};
210    /// // Build a complex nested expression
211    /// let complex_expr = select(
212    ///     ["result"],
213    ///     and(
214    ///         not(eq(get_item("status", root()), lit("inactive"))),
215    ///         and(
216    ///             Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
217    ///             gt(
218    ///                 cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
219    ///                 lit(75.0)
220    ///             )
221    ///         )
222    ///     )
223    /// );
224    ///
225    /// println!("{}", complex_expr.display_tree());
226    /// ```
227    ///
228    /// This produces output like:
229    ///
230    /// ```text
231    /// Select(include): {result}
232    /// └── Binary(and)
233    ///     ├── lhs: Not
234    ///     │   └── Binary(=)
235    ///     │       ├── lhs: GetItem(status)
236    ///     │       │   └── Root
237    ///     │       └── rhs: Literal(value: "inactive", dtype: utf8)
238    ///     └── rhs: Binary(and)
239    ///         ├── lhs: Like
240    ///         │   ├── child: GetItem(name)
241    ///         │   │   └── Root
242    ///         │   └── pattern: Literal(value: "%admin%", dtype: utf8)
243    ///         └── rhs: Binary(>)
244    ///             ├── lhs: Cast(target: f64)
245    ///             │   └── GetItem(score)
246    ///             │       └── Root
247    ///             └── rhs: Literal(value: 75f64, dtype: f64)
248    /// ```
249    pub fn display_tree(&self) -> impl Display {
250        DisplayTreeExpr(self)
251    }
252}
253
254/// The default display implementation for expressions uses the 'SQL'-style format.
255impl Display for Expression {
256    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
257        self.fmt_sql(f)
258    }
259}
260
261/// Iterative drop for expression to avoid stack overflows.
262impl Drop for Expression {
263    fn drop(&mut self) {
264        if let Some(children) = Arc::get_mut(&mut self.children) {
265            let mut children_to_drop = std::mem::take(children);
266
267            while let Some(mut child) = children_to_drop.pop() {
268                if let Some(expr_children) = Arc::get_mut(&mut child.children) {
269                    children_to_drop.append(expr_children);
270                }
271            }
272        }
273    }
274}