vortex_array/expr/expression.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_ensure;
15use vortex_session::VortexSession;
16
17use crate::dtype::DType;
18use crate::expr::display::DisplayTreeExpr;
19use crate::scalar_fn::ScalarFnRef;
20use crate::scalar_fn::fns::root::Root;
21
22/// A node in a Vortex expression tree.
23///
24/// Expressions represent scalar computations that can be performed on data. Each
25/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
26#[derive(Clone, Debug, PartialEq, Eq, Hash)]
27pub struct Expression {
28 /// The scalar fn for this node.
29 scalar_fn: ScalarFnRef,
30 /// Any children of this expression.
31 children: Arc<Vec<Expression>>,
32}
33
34impl Deref for Expression {
35 type Target = ScalarFnRef;
36
37 fn deref(&self) -> &Self::Target {
38 &self.scalar_fn
39 }
40}
41
42impl Expression {
43 /// Create a new expression node from a scalar_fn expression and its children.
44 pub fn try_new(
45 scalar_fn: ScalarFnRef,
46 children: impl IntoIterator<Item = Expression>,
47 ) -> VortexResult<Self> {
48 let children = Vec::from_iter(children);
49
50 vortex_ensure!(
51 scalar_fn.signature().arity().matches(children.len()),
52 "Expression arity mismatch: expected {} children but got {}",
53 scalar_fn.signature().arity(),
54 children.len()
55 );
56
57 Ok(Self {
58 scalar_fn,
59 children: children.into(),
60 })
61 }
62
63 /// Returns the scalar fn vtable for this expression.
64 pub fn scalar_fn(&self) -> &ScalarFnRef {
65 &self.scalar_fn
66 }
67
68 /// Returns the children of this expression.
69 pub fn children(&self) -> &Arc<Vec<Expression>> {
70 &self.children
71 }
72
73 /// Returns the n'th child of this expression.
74 pub fn child(&self, n: usize) -> &Expression {
75 &self.children[n]
76 }
77
78 /// Replace the children of this expression with the provided new children.
79 pub fn with_children(
80 mut self,
81 children: impl IntoIterator<Item = Expression>,
82 ) -> VortexResult<Self> {
83 let children = Vec::from_iter(children);
84 vortex_ensure!(
85 self.signature().arity().matches(children.len()),
86 "Expression arity mismatch: expected {} children but got {}",
87 self.signature().arity(),
88 children.len()
89 );
90 self.children = Arc::new(children);
91 Ok(self)
92 }
93
94 /// Computes the return dtype of this expression given the input dtype.
95 pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
96 if self.is::<Root>() {
97 return Ok(scope.clone());
98 }
99
100 let dtypes: Vec<_> = self
101 .children
102 .iter()
103 .map(|c| c.return_dtype(scope))
104 .try_collect()?;
105 self.scalar_fn.return_dtype(&dtypes)
106 }
107
108 /// Returns a new expression representing the validity mask output of this expression.
109 ///
110 /// The returned expression evaluates to a non-nullable boolean array.
111 pub fn validity(&self) -> VortexResult<Expression> {
112 self.scalar_fn.validity(self)
113 }
114
115 /// Returns an expression that proves this predicate is definitely false from stats.
116 ///
117 /// `scope` is the dtype of the row this expression evaluates over.
118 ///
119 /// If the returned expression evaluates to `true` for a stats scope, this expression is
120 /// guaranteed to be false for every row in that scope. `false` and `null` are unknown.
121 pub fn falsify(
122 &self,
123 scope: &DType,
124 session: &VortexSession,
125 ) -> VortexResult<Option<Expression>> {
126 crate::stats::rewrite::StatsRewriteCtx::new(session, scope).falsify(self)
127 }
128
129 /// Returns an expression that proves this predicate is definitely true from stats.
130 ///
131 /// `scope` is the dtype of the row this expression evaluates over.
132 ///
133 /// If the returned expression evaluates to `true` for a stats scope, this expression is
134 /// guaranteed to be true for every row in that scope. `false` and `null` are unknown.
135 pub fn satisfy(
136 &self,
137 scope: &DType,
138 session: &VortexSession,
139 ) -> VortexResult<Option<Expression>> {
140 crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self)
141 }
142
143 /// Format the expression as a compact string.
144 ///
145 /// Since this is a recursive formatter, it is exposed on the public Expression type.
146 /// See fmt_data that is only implemented on the vtable trait.
147 pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
148 self.scalar_fn().fmt_sql(self, f)
149 }
150
151 /// Display the expression as a formatted tree structure.
152 ///
153 /// This provides a hierarchical view of the expression that shows the relationships
154 /// between parent and child expressions, making complex nested expressions easier
155 /// to understand and debug.
156 ///
157 /// # Example
158 ///
159 /// ```rust
160 /// # use vortex_array::dtype::{DType, Nullability, PType};
161 /// # use vortex_array::scalar_fn::fns::like::{Like, LikeOptions};
162 /// # use vortex_array::scalar_fn::ScalarFnVTableExt;
163 /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select};
164 /// // Build a complex nested expression
165 /// let complex_expr = select(
166 /// ["result"],
167 /// and(
168 /// not(eq(get_item("status", root()), lit("inactive"))),
169 /// and(
170 /// Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
171 /// gt(
172 /// cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
173 /// lit(75.0)
174 /// )
175 /// )
176 /// )
177 /// );
178 ///
179 /// println!("{}", complex_expr.display_tree());
180 /// ```
181 ///
182 /// This produces output like:
183 ///
184 /// ```text
185 /// Select(include): {result}
186 /// └── Binary(and)
187 /// ├── lhs: Not
188 /// │ └── Binary(=)
189 /// │ ├── lhs: GetItem(status)
190 /// │ │ └── Root
191 /// │ └── rhs: Literal(value: "inactive", dtype: utf8)
192 /// └── rhs: Binary(and)
193 /// ├── lhs: Like
194 /// │ ├── child: GetItem(name)
195 /// │ │ └── Root
196 /// │ └── pattern: Literal(value: "%admin%", dtype: utf8)
197 /// └── rhs: Binary(>)
198 /// ├── lhs: Cast(target: f64)
199 /// │ └── GetItem(score)
200 /// │ └── Root
201 /// └── rhs: Literal(value: 75f64, dtype: f64)
202 /// ```
203 pub fn display_tree(&self) -> impl Display {
204 DisplayTreeExpr(self)
205 }
206}
207
208/// The default display implementation for expressions uses the 'SQL'-style format.
209impl Display for Expression {
210 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
211 self.fmt_sql(f)
212 }
213}
214
215/// Iterative drop for expression to avoid stack overflows.
216impl Drop for Expression {
217 fn drop(&mut self) {
218 if let Some(children) = Arc::get_mut(&mut self.children) {
219 let mut children_to_drop = std::mem::take(children);
220
221 while let Some(mut child) = children_to_drop.pop() {
222 if let Some(expr_children) = Arc::get_mut(&mut child.children) {
223 children_to_drop.append(expr_children);
224 }
225 }
226 }
227 }
228}