Skip to main content

vortex_array/expr/
expression.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_ensure;
15use vortex_session::VortexSession;
16
17use crate::dtype::DType;
18use crate::expr::display::DisplayTreeExpr;
19use crate::scalar_fn::ScalarFnRef;
20use crate::scalar_fn::fns::root::Root;
21
22/// A node in a Vortex expression tree.
23///
24/// Expressions represent scalar computations that can be performed on data. Each
25/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
26#[derive(Clone, Debug, PartialEq, Eq, Hash)]
27pub struct Expression {
28    /// The scalar fn for this node.
29    scalar_fn: ScalarFnRef,
30    /// Any children of this expression.
31    children: Arc<Vec<Expression>>,
32}
33
34impl Deref for Expression {
35    type Target = ScalarFnRef;
36
37    fn deref(&self) -> &Self::Target {
38        &self.scalar_fn
39    }
40}
41
42impl Expression {
43    /// Create a new expression node from a scalar_fn expression and its children.
44    pub fn try_new(
45        scalar_fn: ScalarFnRef,
46        children: impl IntoIterator<Item = Expression>,
47    ) -> VortexResult<Self> {
48        let children = Vec::from_iter(children);
49
50        vortex_ensure!(
51            scalar_fn.signature().arity().matches(children.len()),
52            "Expression arity mismatch: expected {} children but got {}",
53            scalar_fn.signature().arity(),
54            children.len()
55        );
56
57        Ok(Self {
58            scalar_fn,
59            children: children.into(),
60        })
61    }
62
63    /// Returns the scalar fn vtable for this expression.
64    pub fn scalar_fn(&self) -> &ScalarFnRef {
65        &self.scalar_fn
66    }
67
68    /// Returns the children of this expression.
69    pub fn children(&self) -> &Arc<Vec<Expression>> {
70        &self.children
71    }
72
73    /// Returns the n'th child of this expression.
74    pub fn child(&self, n: usize) -> &Expression {
75        &self.children[n]
76    }
77
78    /// Replace the children of this expression with the provided new children.
79    pub fn with_children(
80        mut self,
81        children: impl IntoIterator<Item = Expression>,
82    ) -> VortexResult<Self> {
83        let children = Vec::from_iter(children);
84        vortex_ensure!(
85            self.signature().arity().matches(children.len()),
86            "Expression arity mismatch: expected {} children but got {}",
87            self.signature().arity(),
88            children.len()
89        );
90        self.children = Arc::new(children);
91        Ok(self)
92    }
93
94    /// Computes the return dtype of this expression given the input dtype.
95    pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
96        if self.is::<Root>() {
97            return Ok(scope.clone());
98        }
99
100        let dtypes: Vec<_> = self
101            .children
102            .iter()
103            .map(|c| c.return_dtype(scope))
104            .try_collect()?;
105        self.scalar_fn.return_dtype(&dtypes)
106    }
107
108    /// Returns a new expression representing the validity mask output of this expression.
109    ///
110    /// The returned expression evaluates to a non-nullable boolean array.
111    pub fn validity(&self) -> VortexResult<Expression> {
112        self.scalar_fn.validity(self)
113    }
114
115    /// Returns an expression that proves this predicate is definitely false from stats.
116    ///
117    /// `scope` is the dtype of the row this expression evaluates over.
118    ///
119    /// If the returned expression evaluates to `true` for a stats scope, this expression is
120    /// guaranteed to be false for every row in that scope. `false` and `null` are unknown.
121    pub fn falsify(
122        &self,
123        scope: &DType,
124        session: &VortexSession,
125    ) -> VortexResult<Option<Expression>> {
126        crate::stats::rewrite::StatsRewriteCtx::new(session, scope).falsify(self)
127    }
128
129    /// Returns an expression that proves this predicate is definitely true from stats.
130    ///
131    /// `scope` is the dtype of the row this expression evaluates over.
132    ///
133    /// If the returned expression evaluates to `true` for a stats scope, this expression is
134    /// guaranteed to be true for every row in that scope. `false` and `null` are unknown.
135    pub fn satisfy(
136        &self,
137        scope: &DType,
138        session: &VortexSession,
139    ) -> VortexResult<Option<Expression>> {
140        crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self)
141    }
142
143    /// Format the expression as a compact string.
144    ///
145    /// Since this is a recursive formatter, it is exposed on the public Expression type.
146    /// See fmt_data that is only implemented on the vtable trait.
147    pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
148        self.scalar_fn().fmt_sql(self, f)
149    }
150
151    /// Display the expression as a formatted tree structure.
152    ///
153    /// This provides a hierarchical view of the expression that shows the relationships
154    /// between parent and child expressions, making complex nested expressions easier
155    /// to understand and debug.
156    ///
157    /// # Example
158    ///
159    /// ```rust
160    /// # use vortex_array::dtype::{DType, Nullability, PType};
161    /// # use vortex_array::scalar_fn::fns::like::{Like, LikeOptions};
162    /// # use vortex_array::scalar_fn::ScalarFnVTableExt;
163    /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select};
164    /// // Build a complex nested expression
165    /// let complex_expr = select(
166    ///     ["result"],
167    ///     and(
168    ///         not(eq(get_item("status", root()), lit("inactive"))),
169    ///         and(
170    ///             Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
171    ///             gt(
172    ///                 cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
173    ///                 lit(75.0)
174    ///             )
175    ///         )
176    ///     )
177    /// );
178    ///
179    /// println!("{}", complex_expr.display_tree());
180    /// ```
181    ///
182    /// This produces output like:
183    ///
184    /// ```text
185    /// Select(include): {result}
186    /// └── Binary(and)
187    ///     ├── lhs: Not
188    ///     │   └── Binary(=)
189    ///     │       ├── lhs: GetItem(status)
190    ///     │       │   └── Root
191    ///     │       └── rhs: Literal(value: "inactive", dtype: utf8)
192    ///     └── rhs: Binary(and)
193    ///         ├── lhs: Like
194    ///         │   ├── child: GetItem(name)
195    ///         │   │   └── Root
196    ///         │   └── pattern: Literal(value: "%admin%", dtype: utf8)
197    ///         └── rhs: Binary(>)
198    ///             ├── lhs: Cast(target: f64)
199    ///             │   └── GetItem(score)
200    ///             │       └── Root
201    ///             └── rhs: Literal(value: 75f64, dtype: f64)
202    /// ```
203    pub fn display_tree(&self) -> impl Display {
204        DisplayTreeExpr(self)
205    }
206}
207
208/// The default display implementation for expressions uses the 'SQL'-style format.
209impl Display for Expression {
210    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
211        self.fmt_sql(f)
212    }
213}
214
215/// Iterative drop for expression to avoid stack overflows.
216impl Drop for Expression {
217    fn drop(&mut self) {
218        if let Some(children) = Arc::get_mut(&mut self.children) {
219            let mut children_to_drop = std::mem::take(children);
220
221            while let Some(mut child) = children_to_drop.pop() {
222                if let Some(expr_children) = Arc::get_mut(&mut child.children) {
223                    children_to_drop.append(expr_children);
224                }
225            }
226        }
227    }
228}