Skip to main content

vortex_array/expr/
expression.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_dtype::DType;
14use vortex_error::VortexResult;
15use vortex_error::vortex_ensure;
16
17use crate::expr::Root;
18use crate::expr::ScalarFn;
19use crate::expr::StatsCatalog;
20use crate::expr::display::DisplayTreeExpr;
21use crate::expr::stats::Stat;
22
23/// A node in a Vortex expression tree.
24///
25/// Expressions represent scalar computations that can be performed on data. Each
26/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
27#[derive(Clone, Debug, PartialEq, Eq, Hash)]
28pub struct Expression {
29    /// The scalar fn for this node.
30    scalar_fn: ScalarFn,
31    /// Any children of this expression.
32    children: Arc<Vec<Expression>>,
33}
34
35impl Deref for Expression {
36    type Target = ScalarFn;
37
38    fn deref(&self) -> &Self::Target {
39        &self.scalar_fn
40    }
41}
42
43impl Expression {
44    /// Create a new expression node from a scalar_fn expression and its children.
45    pub fn try_new(
46        scalar_fn: ScalarFn,
47        children: impl IntoIterator<Item = Expression>,
48    ) -> VortexResult<Self> {
49        let children = Vec::from_iter(children);
50
51        vortex_ensure!(
52            scalar_fn.signature().arity().matches(children.len()),
53            "Expression arity mismatch: expected {} children but got {}",
54            scalar_fn.signature().arity(),
55            children.len()
56        );
57
58        Ok(Self {
59            scalar_fn,
60            children: children.into(),
61        })
62    }
63
64    /// Returns the scalar fn vtable for this expression.
65    pub fn scalar_fn(&self) -> &ScalarFn {
66        &self.scalar_fn
67    }
68
69    /// Returns the children of this expression.
70    pub fn children(&self) -> &Arc<Vec<Expression>> {
71        &self.children
72    }
73
74    /// Returns the n'th child of this expression.
75    pub fn child(&self, n: usize) -> &Expression {
76        &self.children[n]
77    }
78
79    /// Replace the children of this expression with the provided new children.
80    pub fn with_children(
81        mut self,
82        children: impl IntoIterator<Item = Expression>,
83    ) -> VortexResult<Self> {
84        let children = Vec::from_iter(children);
85        vortex_ensure!(
86            self.signature().arity().matches(children.len()),
87            "Expression arity mismatch: expected {} children but got {}",
88            self.signature().arity(),
89            children.len()
90        );
91        self.children = Arc::new(children);
92        Ok(self)
93    }
94
95    /// Computes the return dtype of this expression given the input dtype.
96    pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
97        if self.is::<Root>() {
98            return Ok(scope.clone());
99        }
100
101        let dtypes: Vec<_> = self
102            .children
103            .iter()
104            .map(|c| c.return_dtype(scope))
105            .try_collect()?;
106        self.scalar_fn.return_dtype(&dtypes)
107    }
108
109    /// Returns a new expression representing the validity mask output of this expression.
110    ///
111    /// The returned expression evaluates to a non-nullable boolean array.
112    pub fn validity(&self) -> VortexResult<Expression> {
113        self.scalar_fn.validity(self)
114    }
115
116    /// An expression over zone-statistics which implies all records in the zone evaluate to false.
117    ///
118    /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed
119    /// that `e` evaluates to false on all records in the zone. However, the inverse is not
120    /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to
121    /// true on all records.
122    ///
123    /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr.
124    ///
125    /// # Examples
126    ///
127    /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum
128    ///   value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`.
129    /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`.
130    /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x).
131    ///
132    /// Some expressions, in theory, have falsifications but this function does not support them
133    /// such as `x < (y < z)` or `x LIKE "needle%"`.
134    pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
135        self.vtable().as_dyn().stat_falsification(self, catalog)
136    }
137
138    /// Returns an expression representing the zoned statistic for the given stat, if available.
139    ///
140    /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a
141    /// scope. Expressions can implement this function to propagate such statistics through the
142    /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`.
143    ///
144    /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an
145    /// issue to discuss a solution to this.
146    pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option<Expression> {
147        self.vtable().as_dyn().stat_expression(self, stat, catalog)
148    }
149
150    /// Returns an expression representing the zoned maximum statistic, if available.
151    pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
152        self.stat_expression(Stat::Min, catalog)
153    }
154
155    /// Returns an expression representing the zoned maximum statistic, if available.
156    pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
157        self.stat_expression(Stat::Max, catalog)
158    }
159
160    /// Format the expression as a compact string.
161    ///
162    /// Since this is a recursive formatter, it is exposed on the public Expression type.
163    /// See fmt_data that is only implemented on the vtable trait.
164    pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
165        self.vtable().as_dyn().fmt_sql(self, f)
166    }
167
168    /// Display the expression as a formatted tree structure.
169    ///
170    /// This provides a hierarchical view of the expression that shows the relationships
171    /// between parent and child expressions, making complex nested expressions easier
172    /// to understand and debug.
173    ///
174    /// # Example
175    ///
176    /// ```rust
177    /// # use vortex_array::expr::LikeOptions;
178    /// # use vortex_array::expr::VTableExt;
179    /// # use vortex_dtype::{DType, Nullability, PType};
180    /// # use vortex_array::expr::{and, cast, eq, get_item, gt, lit, not, root, select, Like};
181    /// // Build a complex nested expression
182    /// let complex_expr = select(
183    ///     ["result"],
184    ///     and(
185    ///         not(eq(get_item("status", root()), lit("inactive"))),
186    ///         and(
187    ///             Like.new_expr(LikeOptions::default(), [get_item("name", root()), lit("%admin%")]),
188    ///             gt(
189    ///                 cast(get_item("score", root()), DType::Primitive(PType::F64, Nullability::NonNullable)),
190    ///                 lit(75.0)
191    ///             )
192    ///         )
193    ///     )
194    /// );
195    ///
196    /// println!("{}", complex_expr.display_tree());
197    /// ```
198    ///
199    /// This produces output like:
200    ///
201    /// ```text
202    /// Select(include): {result}
203    /// └── Binary(and)
204    ///     ├── lhs: Not
205    ///     │   └── Binary(=)
206    ///     │       ├── lhs: GetItem(status)
207    ///     │       │   └── Root
208    ///     │       └── rhs: Literal(value: "inactive", dtype: utf8)
209    ///     └── rhs: Binary(and)
210    ///         ├── lhs: Like
211    ///         │   ├── child: GetItem(name)
212    ///         │   │   └── Root
213    ///         │   └── pattern: Literal(value: "%admin%", dtype: utf8)
214    ///         └── rhs: Binary(>)
215    ///             ├── lhs: Cast(target: f64)
216    ///             │   └── GetItem(score)
217    ///             │       └── Root
218    ///             └── rhs: Literal(value: 75f64, dtype: f64)
219    /// ```
220    pub fn display_tree(&self) -> impl Display {
221        DisplayTreeExpr(self)
222    }
223}
224
225/// The default display implementation for expressions uses the 'SQL'-style format.
226impl Display for Expression {
227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228        self.fmt_sql(f)
229    }
230}
231
232/// Iterative drop for expression to avoid stack overflows.
233impl Drop for Expression {
234    fn drop(&mut self) {
235        if let Some(children) = Arc::get_mut(&mut self.children) {
236            let mut children_to_drop = std::mem::take(children);
237
238            while let Some(mut child) = children_to_drop.pop() {
239                if let Some(expr_children) = Arc::get_mut(&mut child.children) {
240                    children_to_drop.append(expr_children);
241                }
242            }
243        }
244    }
245}