vortex_array/expr/expression.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_dtype::DType;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_ensure;
17
18use crate::ArrayRef;
19use crate::expr::Root;
20use crate::expr::ScalarFn;
21use crate::expr::StatsCatalog;
22use crate::expr::VTable;
23use crate::expr::display::DisplayTreeExpr;
24use crate::expr::stats::Stat;
25
26/// A node in a Vortex expression tree.
27///
28/// Expressions represent scalar computations that can be performed on data. Each
29/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
30#[derive(Clone, Debug, PartialEq, Eq, Hash)]
31pub struct Expression {
32 /// The scalar fn for this node.
33 scalar_fn: ScalarFn,
34 /// Any children of this expression.
35 children: Arc<[Expression]>,
36}
37
38impl Deref for Expression {
39 type Target = ScalarFn;
40
41 fn deref(&self) -> &Self::Target {
42 &self.scalar_fn
43 }
44}
45
46impl Expression {
47 /// Create a new expression node from a scalar_fn expression and its children.
48 pub fn try_new(
49 scalar_fn: ScalarFn,
50 children: impl Into<Arc<[Expression]>>,
51 ) -> VortexResult<Self> {
52 let children: Arc<[Expression]> = children.into();
53
54 vortex_ensure!(
55 scalar_fn.signature().arity().matches(children.len()),
56 "Expression arity mismatch: expected {} children but got {}",
57 scalar_fn.signature().arity(),
58 children.len()
59 );
60
61 Ok(Self {
62 scalar_fn,
63 children,
64 })
65 }
66
67 /// Returns true if this expression is of the given vtable type.
68 pub fn is<V: VTable>(&self) -> bool {
69 self.vtable().is::<V>()
70 }
71
72 /// Returns the typed options for this expression if it matches the given vtable type.
73 pub fn as_opt<V: VTable>(&self) -> Option<&V::Options> {
74 self.options().as_any().downcast_ref::<V::Options>()
75 }
76
77 /// Returns the typed options for this expression if it matches the given vtable type.
78 pub fn as_<V: VTable>(&self) -> &V::Options {
79 self.as_opt::<V>()
80 .vortex_expect("Expression options type mismatch")
81 }
82
83 /// Returns the scalar fn vtable for this expression.
84 pub fn scalar_fn(&self) -> &ScalarFn {
85 &self.scalar_fn
86 }
87
88 /// Returns the children of this expression.
89 pub fn children(&self) -> &Arc<[Expression]> {
90 &self.children
91 }
92
93 /// Returns the n'th child of this expression.
94 pub fn child(&self, n: usize) -> &Expression {
95 &self.children[n]
96 }
97
98 /// Replace the children of this expression with the provided new children.
99 pub fn with_children(mut self, children: impl Into<Arc<[Expression]>>) -> VortexResult<Self> {
100 let children = children.into();
101 vortex_ensure!(
102 self.signature().arity().matches(children.len()),
103 "Expression arity mismatch: expected {} children but got {}",
104 self.signature().arity(),
105 children.len()
106 );
107 self.children = children;
108 Ok(self)
109 }
110
111 /// Computes the return dtype of this expression given the input dtype.
112 pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
113 if self.is::<Root>() {
114 return Ok(scope.clone());
115 }
116
117 let dtypes: Vec<_> = self
118 .children
119 .iter()
120 .map(|c| c.return_dtype(scope))
121 .try_collect()?;
122 self.scalar_fn.return_dtype(&dtypes)
123 }
124
125 /// Evaluates the expression in the given scope, returning an array.
126 pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult<ArrayRef> {
127 if self.is::<Root>() {
128 return Ok(scope.clone());
129 }
130 self.scalar_fn.evaluate(self, scope)
131 }
132
133 /// An expression over zone-statistics which implies all records in the zone evaluate to false.
134 ///
135 /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed
136 /// that `e` evaluates to false on all records in the zone. However, the inverse is not
137 /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to
138 /// true on all records.
139 ///
140 /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr.
141 ///
142 /// # Examples
143 ///
144 /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum
145 /// value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`.
146 /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`.
147 /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x).
148 ///
149 /// Some expressions, in theory, have falsifications but this function does not support them
150 /// such as `x < (y < z)` or `x LIKE "needle%"`.
151 pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
152 self.vtable().as_dyn().stat_falsification(self, catalog)
153 }
154
155 /// Returns an expression representing the zoned statistic for the given stat, if available.
156 ///
157 /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a
158 /// scope. Expressions can implement this function to propagate such statistics through the
159 /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`.
160 ///
161 /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an
162 /// issue to discuss a solution to this.
163 pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option<Expression> {
164 self.vtable().as_dyn().stat_expression(self, stat, catalog)
165 }
166
167 /// Returns an expression representing the zoned maximum statistic, if available.
168 pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
169 self.stat_expression(Stat::Min, catalog)
170 }
171
172 /// Returns an expression representing the zoned maximum statistic, if available.
173 pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
174 self.stat_expression(Stat::Max, catalog)
175 }
176
177 /// Format the expression as a compact string.
178 ///
179 /// Since this is a recursive formatter, it is exposed on the public Expression type.
180 /// See fmt_data that is only implemented on the vtable trait.
181 pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
182 self.vtable().as_dyn().fmt_sql(self, f)
183 }
184
185 /// Display the expression as a formatted tree structure.
186 ///
187 /// This provides a hierarchical view of the expression that shows the relationships
188 /// between parent and child expressions, making complex nested expressions easier
189 /// to understand and debug.
190 ///
191 /// # Example
192 ///
193 /// ```rust
194 /// # use vortex_array::compute::LikeOptions;
195 /// # use vortex_array::expr::VTableExt;
196 /// # use vortex_dtype::{DType, Nullability, PType};
197 /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select, Like};
198 /// // Build a complex nested expression
199 /// let complex_expr = select(
200 /// ["result"],
201 /// and(
202 /// not(eq(get_item("status", root()), lit("inactive"))),
203 /// and(
204 /// Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
205 /// gt(
206 /// cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
207 /// lit(75.0)
208 /// )
209 /// )
210 /// )
211 /// );
212 ///
213 /// println!("{}", complex_expr.display_tree());
214 /// ```
215 ///
216 /// This produces output like:
217 ///
218 /// ```text
219 /// Select(include): {result}
220 /// └── Binary(and)
221 /// ├── lhs: Not
222 /// │ └── Binary(=)
223 /// │ ├── lhs: GetItem(status)
224 /// │ │ └── Root
225 /// │ └── rhs: Literal(value: "inactive", dtype: utf8)
226 /// └── rhs: Binary(and)
227 /// ├── lhs: Like
228 /// │ ├── child: GetItem(name)
229 /// │ │ └── Root
230 /// │ └── pattern: Literal(value: "%admin%", dtype: utf8)
231 /// └── rhs: Binary(>)
232 /// ├── lhs: Cast(target: f64)
233 /// │ └── GetItem(score)
234 /// │ └── Root
235 /// └── rhs: Literal(value: 75f64, dtype: f64)
236 /// ```
237 pub fn display_tree(&self) -> impl Display {
238 DisplayTreeExpr(self)
239 }
240}
241
242/// The default display implementation for expressions uses the 'SQL'-style format.
243impl Display for Expression {
244 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
245 self.fmt_sql(f)
246 }
247}