vortex_array/expr/expression.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_dtype::DType;
14use vortex_error::VortexResult;
15use vortex_error::vortex_ensure;
16
17use crate::expr::Root;
18use crate::expr::ScalarFn;
19use crate::expr::StatsCatalog;
20use crate::expr::display::DisplayTreeExpr;
21use crate::expr::stats::Stat;
22
23/// A node in a Vortex expression tree.
24///
25/// Expressions represent scalar computations that can be performed on data. Each
26/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
27#[derive(Clone, Debug, PartialEq, Eq, Hash)]
28pub struct Expression {
29 /// The scalar fn for this node.
30 scalar_fn: ScalarFn,
31 /// Any children of this expression.
32 children: Arc<Vec<Expression>>,
33}
34
35impl Deref for Expression {
36 type Target = ScalarFn;
37
38 fn deref(&self) -> &Self::Target {
39 &self.scalar_fn
40 }
41}
42
43impl Expression {
44 /// Create a new expression node from a scalar_fn expression and its children.
45 pub fn try_new(
46 scalar_fn: ScalarFn,
47 children: impl IntoIterator<Item = Expression>,
48 ) -> VortexResult<Self> {
49 let children = Vec::from_iter(children);
50
51 vortex_ensure!(
52 scalar_fn.signature().arity().matches(children.len()),
53 "Expression arity mismatch: expected {} children but got {}",
54 scalar_fn.signature().arity(),
55 children.len()
56 );
57
58 Ok(Self {
59 scalar_fn,
60 children: children.into(),
61 })
62 }
63
64 /// Returns the scalar fn vtable for this expression.
65 pub fn scalar_fn(&self) -> &ScalarFn {
66 &self.scalar_fn
67 }
68
69 /// Returns the children of this expression.
70 pub fn children(&self) -> &Arc<Vec<Expression>> {
71 &self.children
72 }
73
74 /// Returns the n'th child of this expression.
75 pub fn child(&self, n: usize) -> &Expression {
76 &self.children[n]
77 }
78
79 /// Replace the children of this expression with the provided new children.
80 pub fn with_children(
81 mut self,
82 children: impl IntoIterator<Item = Expression>,
83 ) -> VortexResult<Self> {
84 let children = Vec::from_iter(children);
85 vortex_ensure!(
86 self.signature().arity().matches(children.len()),
87 "Expression arity mismatch: expected {} children but got {}",
88 self.signature().arity(),
89 children.len()
90 );
91 self.children = Arc::new(children);
92 Ok(self)
93 }
94
95 /// Computes the return dtype of this expression given the input dtype.
96 pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
97 if self.is::<Root>() {
98 return Ok(scope.clone());
99 }
100
101 let dtypes: Vec<_> = self
102 .children
103 .iter()
104 .map(|c| c.return_dtype(scope))
105 .try_collect()?;
106 self.scalar_fn.return_dtype(&dtypes)
107 }
108
109 /// Returns a new expression representing the validity mask output of this expression.
110 ///
111 /// The returned expression evaluates to a non-nullable boolean array.
112 pub fn validity(&self) -> VortexResult<Expression> {
113 self.scalar_fn.validity(self)
114 }
115
116 /// An expression over zone-statistics which implies all records in the zone evaluate to false.
117 ///
118 /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed
119 /// that `e` evaluates to false on all records in the zone. However, the inverse is not
120 /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to
121 /// true on all records.
122 ///
123 /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr.
124 ///
125 /// # Examples
126 ///
127 /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum
128 /// value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`.
129 /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`.
130 /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x).
131 ///
132 /// Some expressions, in theory, have falsifications but this function does not support them
133 /// such as `x < (y < z)` or `x LIKE "needle%"`.
134 pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
135 self.vtable().as_dyn().stat_falsification(self, catalog)
136 }
137
138 /// Returns an expression representing the zoned statistic for the given stat, if available.
139 ///
140 /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a
141 /// scope. Expressions can implement this function to propagate such statistics through the
142 /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`.
143 ///
144 /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an
145 /// issue to discuss a solution to this.
146 pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option<Expression> {
147 self.vtable().as_dyn().stat_expression(self, stat, catalog)
148 }
149
150 /// Returns an expression representing the zoned maximum statistic, if available.
151 pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
152 self.stat_expression(Stat::Min, catalog)
153 }
154
155 /// Returns an expression representing the zoned maximum statistic, if available.
156 pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
157 self.stat_expression(Stat::Max, catalog)
158 }
159
160 /// Format the expression as a compact string.
161 ///
162 /// Since this is a recursive formatter, it is exposed on the public Expression type.
163 /// See fmt_data that is only implemented on the vtable trait.
164 pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
165 self.vtable().as_dyn().fmt_sql(self, f)
166 }
167
168 /// Display the expression as a formatted tree structure.
169 ///
170 /// This provides a hierarchical view of the expression that shows the relationships
171 /// between parent and child expressions, making complex nested expressions easier
172 /// to understand and debug.
173 ///
174 /// # Example
175 ///
176 /// ```rust
177 /// # use vortex_array::expr::LikeOptions;
178 /// # use vortex_array::expr::VTableExt;
179 /// # use vortex_dtype::{DType, Nullability, PType};
180 /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select, Like};
181 /// // Build a complex nested expression
182 /// let complex_expr = select(
183 /// ["result"],
184 /// and(
185 /// not(eq(get_item("status", root()), lit("inactive"))),
186 /// and(
187 /// Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
188 /// gt(
189 /// cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
190 /// lit(75.0)
191 /// )
192 /// )
193 /// )
194 /// );
195 ///
196 /// println!("{}", complex_expr.display_tree());
197 /// ```
198 ///
199 /// This produces output like:
200 ///
201 /// ```text
202 /// Select(include): {result}
203 /// └── Binary(and)
204 /// ├── lhs: Not
205 /// │ └── Binary(=)
206 /// │ ├── lhs: GetItem(status)
207 /// │ │ └── Root
208 /// │ └── rhs: Literal(value: "inactive", dtype: utf8)
209 /// └── rhs: Binary(and)
210 /// ├── lhs: Like
211 /// │ ├── child: GetItem(name)
212 /// │ │ └── Root
213 /// │ └── pattern: Literal(value: "%admin%", dtype: utf8)
214 /// └── rhs: Binary(>)
215 /// ├── lhs: Cast(target: f64)
216 /// │ └── GetItem(score)
217 /// │ └── Root
218 /// └── rhs: Literal(value: 75f64, dtype: f64)
219 /// ```
220 pub fn display_tree(&self) -> impl Display {
221 DisplayTreeExpr(self)
222 }
223}
224
225/// The default display implementation for expressions uses the 'SQL'-style format.
226impl Display for Expression {
227 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228 self.fmt_sql(f)
229 }
230}
231
232/// Iterative drop for expression to avoid stack overflows.
233impl Drop for Expression {
234 fn drop(&mut self) {
235 if let Some(children) = Arc::get_mut(&mut self.children) {
236 let mut children_to_drop = std::mem::take(children);
237
238 while let Some(mut child) = children_to_drop.pop() {
239 if let Some(expr_children) = Arc::get_mut(&mut child.children) {
240 children_to_drop.append(expr_children);
241 }
242 }
243 }
244 }
245}