vortex_array/expr/expression.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_ensure;
15use vortex_session::VortexSession;
16
17use crate::dtype::DType;
18use crate::expr::StatsCatalog;
19use crate::expr::display::DisplayTreeExpr;
20use crate::expr::stats::Stat;
21use crate::scalar_fn::ScalarFnRef;
22use crate::scalar_fn::fns::root::Root;
23
24/// A node in a Vortex expression tree.
25///
26/// Expressions represent scalar computations that can be performed on data. Each
27/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
28#[derive(Clone, Debug, PartialEq, Eq, Hash)]
29pub struct Expression {
30 /// The scalar fn for this node.
31 scalar_fn: ScalarFnRef,
32 /// Any children of this expression.
33 children: Arc<Vec<Expression>>,
34}
35
36impl Deref for Expression {
37 type Target = ScalarFnRef;
38
39 fn deref(&self) -> &Self::Target {
40 &self.scalar_fn
41 }
42}
43
44impl Expression {
45 /// Create a new expression node from a scalar_fn expression and its children.
46 pub fn try_new(
47 scalar_fn: ScalarFnRef,
48 children: impl IntoIterator<Item = Expression>,
49 ) -> VortexResult<Self> {
50 let children = Vec::from_iter(children);
51
52 vortex_ensure!(
53 scalar_fn.signature().arity().matches(children.len()),
54 "Expression arity mismatch: expected {} children but got {}",
55 scalar_fn.signature().arity(),
56 children.len()
57 );
58
59 Ok(Self {
60 scalar_fn,
61 children: children.into(),
62 })
63 }
64
65 /// Returns the scalar fn vtable for this expression.
66 pub fn scalar_fn(&self) -> &ScalarFnRef {
67 &self.scalar_fn
68 }
69
70 /// Returns the children of this expression.
71 pub fn children(&self) -> &Arc<Vec<Expression>> {
72 &self.children
73 }
74
75 /// Returns the n'th child of this expression.
76 pub fn child(&self, n: usize) -> &Expression {
77 &self.children[n]
78 }
79
80 /// Replace the children of this expression with the provided new children.
81 pub fn with_children(
82 mut self,
83 children: impl IntoIterator<Item = Expression>,
84 ) -> VortexResult<Self> {
85 let children = Vec::from_iter(children);
86 vortex_ensure!(
87 self.signature().arity().matches(children.len()),
88 "Expression arity mismatch: expected {} children but got {}",
89 self.signature().arity(),
90 children.len()
91 );
92 self.children = Arc::new(children);
93 Ok(self)
94 }
95
96 /// Computes the return dtype of this expression given the input dtype.
97 pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
98 if self.is::<Root>() {
99 return Ok(scope.clone());
100 }
101
102 let dtypes: Vec<_> = self
103 .children
104 .iter()
105 .map(|c| c.return_dtype(scope))
106 .try_collect()?;
107 self.scalar_fn.return_dtype(&dtypes)
108 }
109
110 /// Returns a new expression representing the validity mask output of this expression.
111 ///
112 /// The returned expression evaluates to a non-nullable boolean array.
113 pub fn validity(&self) -> VortexResult<Expression> {
114 self.scalar_fn.validity(self)
115 }
116
117 /// An expression over zone-statistics which implies all records in the zone evaluate to false.
118 ///
119 /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed
120 /// that `e` evaluates to false on all records in the zone. However, the inverse is not
121 /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to
122 /// true on all records.
123 ///
124 /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr.
125 ///
126 /// # Examples
127 ///
128 /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum
129 /// value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`.
130 /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`.
131 /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x).
132 ///
133 /// Some expressions, in theory, have falsifications but this function does not support them
134 /// such as `x < (y < z)` or `x LIKE "needle%"`.
135 pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
136 self.scalar_fn().stat_falsification(self, catalog)
137 }
138
139 /// Returns an expression that proves this predicate is definitely false from stats.
140 ///
141 /// If the returned expression evaluates to `true` for a stats scope, this expression is
142 /// guaranteed to be false for every row in that scope. `false` and `null` are unknown.
143 pub fn falsify(&self, session: &VortexSession) -> VortexResult<Option<Expression>> {
144 crate::stats::rewrite::StatsRewriteCtx::new(session).falsify(self)
145 }
146
147 /// Returns an expression that proves this predicate is definitely true from stats.
148 ///
149 /// If the returned expression evaluates to `true` for a stats scope, this expression is
150 /// guaranteed to be true for every row in that scope. `false` and `null` are unknown.
151 pub fn satisfy(&self, session: &VortexSession) -> VortexResult<Option<Expression>> {
152 crate::stats::rewrite::StatsRewriteCtx::new(session).satisfy(self)
153 }
154
155 /// Returns an expression representing the zoned statistic for the given stat, if available.
156 ///
157 /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a
158 /// scope. Expressions can implement this function to propagate such statistics through the
159 /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`.
160 ///
161 /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an
162 /// issue to discuss a solution to this.
163 pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option<Expression> {
164 self.scalar_fn().stat_expression(self, stat, catalog)
165 }
166
167 /// Returns an expression representing the zoned maximum statistic, if available.
168 pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
169 self.stat_expression(Stat::Min, catalog)
170 }
171
172 /// Returns an expression representing the zoned maximum statistic, if available.
173 pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
174 self.stat_expression(Stat::Max, catalog)
175 }
176
177 /// Format the expression as a compact string.
178 ///
179 /// Since this is a recursive formatter, it is exposed on the public Expression type.
180 /// See fmt_data that is only implemented on the vtable trait.
181 pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
182 self.scalar_fn().fmt_sql(self, f)
183 }
184
185 /// Display the expression as a formatted tree structure.
186 ///
187 /// This provides a hierarchical view of the expression that shows the relationships
188 /// between parent and child expressions, making complex nested expressions easier
189 /// to understand and debug.
190 ///
191 /// # Example
192 ///
193 /// ```rust
194 /// # use vortex_array::dtype::{DType, Nullability, PType};
195 /// # use vortex_array::scalar_fn::fns::like::{Like, LikeOptions};
196 /// # use vortex_array::scalar_fn::ScalarFnVTableExt;
197 /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select};
198 /// // Build a complex nested expression
199 /// let complex_expr = select(
200 /// ["result"],
201 /// and(
202 /// not(eq(get_item("status", root()), lit("inactive"))),
203 /// and(
204 /// Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
205 /// gt(
206 /// cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
207 /// lit(75.0)
208 /// )
209 /// )
210 /// )
211 /// );
212 ///
213 /// println!("{}", complex_expr.display_tree());
214 /// ```
215 ///
216 /// This produces output like:
217 ///
218 /// ```text
219 /// Select(include): {result}
220 /// └── Binary(and)
221 /// ├── lhs: Not
222 /// │ └── Binary(=)
223 /// │ ├── lhs: GetItem(status)
224 /// │ │ └── Root
225 /// │ └── rhs: Literal(value: "inactive", dtype: utf8)
226 /// └── rhs: Binary(and)
227 /// ├── lhs: Like
228 /// │ ├── child: GetItem(name)
229 /// │ │ └── Root
230 /// │ └── pattern: Literal(value: "%admin%", dtype: utf8)
231 /// └── rhs: Binary(>)
232 /// ├── lhs: Cast(target: f64)
233 /// │ └── GetItem(score)
234 /// │ └── Root
235 /// └── rhs: Literal(value: 75f64, dtype: f64)
236 /// ```
237 pub fn display_tree(&self) -> impl Display {
238 DisplayTreeExpr(self)
239 }
240}
241
242/// The default display implementation for expressions uses the 'SQL'-style format.
243impl Display for Expression {
244 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
245 self.fmt_sql(f)
246 }
247}
248
249/// Iterative drop for expression to avoid stack overflows.
250impl Drop for Expression {
251 fn drop(&mut self) {
252 if let Some(children) = Arc::get_mut(&mut self.children) {
253 let mut children_to_drop = std::mem::take(children);
254
255 while let Some(mut child) = children_to_drop.pop() {
256 if let Some(expr_children) = Arc::get_mut(&mut child.children) {
257 children_to_drop.append(expr_children);
258 }
259 }
260 }
261 }
262}