Skip to main content

vortex_array/expr/
expression.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_ensure;
15use vortex_session::VortexSession;
16
17use crate::dtype::DType;
18use crate::expr::StatsCatalog;
19use crate::expr::display::DisplayTreeExpr;
20use crate::expr::stats::Stat;
21use crate::scalar_fn::ScalarFnRef;
22use crate::scalar_fn::fns::root::Root;
23
24/// A node in a Vortex expression tree.
25///
26/// Expressions represent scalar computations that can be performed on data. Each
27/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
28#[derive(Clone, Debug, PartialEq, Eq, Hash)]
29pub struct Expression {
30    /// The scalar fn for this node.
31    scalar_fn: ScalarFnRef,
32    /// Any children of this expression.
33    children: Arc<Vec<Expression>>,
34}
35
36impl Deref for Expression {
37    type Target = ScalarFnRef;
38
39    fn deref(&self) -> &Self::Target {
40        &self.scalar_fn
41    }
42}
43
44impl Expression {
45    /// Create a new expression node from a scalar_fn expression and its children.
46    pub fn try_new(
47        scalar_fn: ScalarFnRef,
48        children: impl IntoIterator<Item = Expression>,
49    ) -> VortexResult<Self> {
50        let children = Vec::from_iter(children);
51
52        vortex_ensure!(
53            scalar_fn.signature().arity().matches(children.len()),
54            "Expression arity mismatch: expected {} children but got {}",
55            scalar_fn.signature().arity(),
56            children.len()
57        );
58
59        Ok(Self {
60            scalar_fn,
61            children: children.into(),
62        })
63    }
64
65    /// Returns the scalar fn vtable for this expression.
66    pub fn scalar_fn(&self) -> &ScalarFnRef {
67        &self.scalar_fn
68    }
69
70    /// Returns the children of this expression.
71    pub fn children(&self) -> &Arc<Vec<Expression>> {
72        &self.children
73    }
74
75    /// Returns the n'th child of this expression.
76    pub fn child(&self, n: usize) -> &Expression {
77        &self.children[n]
78    }
79
80    /// Replace the children of this expression with the provided new children.
81    pub fn with_children(
82        mut self,
83        children: impl IntoIterator<Item = Expression>,
84    ) -> VortexResult<Self> {
85        let children = Vec::from_iter(children);
86        vortex_ensure!(
87            self.signature().arity().matches(children.len()),
88            "Expression arity mismatch: expected {} children but got {}",
89            self.signature().arity(),
90            children.len()
91        );
92        self.children = Arc::new(children);
93        Ok(self)
94    }
95
96    /// Computes the return dtype of this expression given the input dtype.
97    pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
98        if self.is::<Root>() {
99            return Ok(scope.clone());
100        }
101
102        let dtypes: Vec<_> = self
103            .children
104            .iter()
105            .map(|c| c.return_dtype(scope))
106            .try_collect()?;
107        self.scalar_fn.return_dtype(&dtypes)
108    }
109
110    /// Returns a new expression representing the validity mask output of this expression.
111    ///
112    /// The returned expression evaluates to a non-nullable boolean array.
113    pub fn validity(&self) -> VortexResult<Expression> {
114        self.scalar_fn.validity(self)
115    }
116
117    /// An expression over zone-statistics which implies all records in the zone evaluate to false.
118    ///
119    /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed
120    /// that `e` evaluates to false on all records in the zone. However, the inverse is not
121    /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to
122    /// true on all records.
123    ///
124    /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr.
125    ///
126    /// # Examples
127    ///
128    /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum
129    ///   value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`.
130    /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`.
131    /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x).
132    ///
133    /// Some expressions, in theory, have falsifications but this function does not support them
134    /// such as `x < (y < z)` or `x LIKE "needle%"`.
135    pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
136        self.scalar_fn().stat_falsification(self, catalog)
137    }
138
139    /// Returns an expression that proves this predicate is definitely false from stats.
140    ///
141    /// If the returned expression evaluates to `true` for a stats scope, this expression is
142    /// guaranteed to be false for every row in that scope. `false` and `null` are unknown.
143    pub fn falsify(&self, session: &VortexSession) -> VortexResult<Option<Expression>> {
144        crate::stats::rewrite::StatsRewriteCtx::new(session).falsify(self)
145    }
146
147    /// Returns an expression that proves this predicate is definitely true from stats.
148    ///
149    /// If the returned expression evaluates to `true` for a stats scope, this expression is
150    /// guaranteed to be true for every row in that scope. `false` and `null` are unknown.
151    pub fn satisfy(&self, session: &VortexSession) -> VortexResult<Option<Expression>> {
152        crate::stats::rewrite::StatsRewriteCtx::new(session).satisfy(self)
153    }
154
155    /// Returns an expression representing the zoned statistic for the given stat, if available.
156    ///
157    /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a
158    /// scope. Expressions can implement this function to propagate such statistics through the
159    /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`.
160    ///
161    /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an
162    /// issue to discuss a solution to this.
163    pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option<Expression> {
164        self.scalar_fn().stat_expression(self, stat, catalog)
165    }
166
167    /// Returns an expression representing the zoned maximum statistic, if available.
168    pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
169        self.stat_expression(Stat::Min, catalog)
170    }
171
172    /// Returns an expression representing the zoned maximum statistic, if available.
173    pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
174        self.stat_expression(Stat::Max, catalog)
175    }
176
177    /// Format the expression as a compact string.
178    ///
179    /// Since this is a recursive formatter, it is exposed on the public Expression type.
180    /// See fmt_data that is only implemented on the vtable trait.
181    pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
182        self.scalar_fn().fmt_sql(self, f)
183    }
184
185    /// Display the expression as a formatted tree structure.
186    ///
187    /// This provides a hierarchical view of the expression that shows the relationships
188    /// between parent and child expressions, making complex nested expressions easier
189    /// to understand and debug.
190    ///
191    /// # Example
192    ///
193    /// ```rust
194    /// # use vortex_array::dtype::{DType, Nullability, PType};
195    /// # use vortex_array::scalar_fn::fns::like::{Like, LikeOptions};
196    /// # use vortex_array::scalar_fn::ScalarFnVTableExt;
197    /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select};
198    /// // Build a complex nested expression
199    /// let complex_expr = select(
200    ///     ["result"],
201    ///     and(
202    ///         not(eq(get_item("status", root()), lit("inactive"))),
203    ///         and(
204    ///             Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
205    ///             gt(
206    ///                 cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
207    ///                 lit(75.0)
208    ///             )
209    ///         )
210    ///     )
211    /// );
212    ///
213    /// println!("{}", complex_expr.display_tree());
214    /// ```
215    ///
216    /// This produces output like:
217    ///
218    /// ```text
219    /// Select(include): {result}
220    /// └── Binary(and)
221    ///     ├── lhs: Not
222    ///     │   └── Binary(=)
223    ///     │       ├── lhs: GetItem(status)
224    ///     │       │   └── Root
225    ///     │       └── rhs: Literal(value: "inactive", dtype: utf8)
226    ///     └── rhs: Binary(and)
227    ///         ├── lhs: Like
228    ///         │   ├── child: GetItem(name)
229    ///         │   │   └── Root
230    ///         │   └── pattern: Literal(value: "%admin%", dtype: utf8)
231    ///         └── rhs: Binary(>)
232    ///             ├── lhs: Cast(target: f64)
233    ///             │   └── GetItem(score)
234    ///             │       └── Root
235    ///             └── rhs: Literal(value: 75f64, dtype: f64)
236    /// ```
237    pub fn display_tree(&self) -> impl Display {
238        DisplayTreeExpr(self)
239    }
240}
241
242/// The default display implementation for expressions uses the 'SQL'-style format.
243impl Display for Expression {
244    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
245        self.fmt_sql(f)
246    }
247}
248
249/// Iterative drop for expression to avoid stack overflows.
250impl Drop for Expression {
251    fn drop(&mut self) {
252        if let Some(children) = Arc::get_mut(&mut self.children) {
253            let mut children_to_drop = std::mem::take(children);
254
255            while let Some(mut child) = children_to_drop.pop() {
256                if let Some(expr_children) = Arc::get_mut(&mut child.children) {
257                    children_to_drop.append(expr_children);
258                }
259            }
260        }
261    }
262}