Skip to main content

vortex_array/scalar_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::dtype::DType;
21use crate::expr::Expression;
22use crate::expr::StatsCatalog;
23use crate::expr::stats::Stat;
24use crate::scalar_fn::ScalarFn;
25use crate::scalar_fn::ScalarFnId;
26use crate::scalar_fn::ScalarFnRef;
27
28/// This trait defines the interface for scalar function vtables, including methods for
29/// serialization, deserialization, validation, child naming, return type computation,
30/// and evaluation.
31///
32/// This trait is non-object safe and allows the implementer to make use of associated types
33/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
34/// outputs of each function.
35///
36/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
37/// all instances of the expression. In almost all cases, this struct will be an empty unit
38/// struct, since most expressions do not require any global state.
39pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
40    /// Options for this expression.
41    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
42
43    /// Returns the ID of the scalar function vtable.
44    fn id(&self) -> ScalarFnId;
45
46    /// Serialize the options for this expression.
47    ///
48    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
49    /// serializable but has no metadata.
50    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51        _ = options;
52        Ok(None)
53    }
54
55    /// Deserialize the options of this expression.
56    fn deserialize(
57        &self,
58        _metadata: &[u8],
59        _session: &VortexSession,
60    ) -> VortexResult<Self::Options> {
61        vortex_bail!("Expression {} is not deserializable", self.id());
62    }
63
64    /// Returns the arity of this expression.
65    fn arity(&self, options: &Self::Options) -> Arity;
66
67    /// Returns the name of the nth child of the expr.
68    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
69
70    /// Format this expression in nice human-readable SQL-style format
71    ///
72    /// The implementation should recursively format child expressions by calling
73    /// `expr.child(i).fmt_sql(f)`.
74    fn fmt_sql(
75        &self,
76        options: &Self::Options,
77        expr: &Expression,
78        f: &mut Formatter<'_>,
79    ) -> fmt::Result;
80
81    /// Compute the return [`DType`] of the expression if evaluated over the given input types.
82    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType>;
83
84    /// Execute the expression over the input arguments.
85    ///
86    /// Implementations are encouraged to check their inputs for constant arrays to perform
87    /// more optimized execution.
88    ///
89    /// If the input arguments cannot be directly used for execution (for example, an expression
90    /// may require canonical input arrays), then the implementation should perform a single
91    /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
92    ///
93    /// This provides maximum opportunities for array-level optimizations using execute_parent
94    /// kernels.
95    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef>;
96
97    /// Implement an abstract reduction rule over a tree of scalar functions.
98    ///
99    /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
100    /// construct the result expression.
101    ///
102    /// Return `Ok(None)` if no reduction is possible.
103    fn reduce(
104        &self,
105        options: &Self::Options,
106        node: &dyn ReduceNode,
107        ctx: &dyn ReduceCtx,
108    ) -> VortexResult<Option<ReduceNodeRef>> {
109        _ = options;
110        _ = node;
111        _ = ctx;
112        Ok(None)
113    }
114
115    /// Simplify the expression if possible.
116    fn simplify(
117        &self,
118        options: &Self::Options,
119        expr: &Expression,
120        ctx: &dyn SimplifyCtx,
121    ) -> VortexResult<Option<Expression>> {
122        _ = options;
123        _ = expr;
124        _ = ctx;
125        Ok(None)
126    }
127
128    /// Simplify the expression if possible, without type information.
129    fn simplify_untyped(
130        &self,
131        options: &Self::Options,
132        expr: &Expression,
133    ) -> VortexResult<Option<Expression>> {
134        _ = options;
135        _ = expr;
136        Ok(None)
137    }
138
139    /// See [`Expression::stat_falsification`].
140    fn stat_falsification(
141        &self,
142        options: &Self::Options,
143        expr: &Expression,
144        catalog: &dyn StatsCatalog,
145    ) -> Option<Expression> {
146        _ = options;
147        _ = expr;
148        _ = catalog;
149        None
150    }
151
152    /// See [`Expression::stat_expression`].
153    fn stat_expression(
154        &self,
155        options: &Self::Options,
156        expr: &Expression,
157        stat: Stat,
158        catalog: &dyn StatsCatalog,
159    ) -> Option<Expression> {
160        _ = options;
161        _ = expr;
162        _ = stat;
163        _ = catalog;
164        None
165    }
166
167    /// Returns an expression that evaluates to the validity of the result of this expression.
168    ///
169    /// If a validity expression cannot be constructed, returns `None` and the expression will
170    /// be evaluated as normal before extracting the validity mask from the result.
171    ///
172    /// This is essentially a specialized form of a `reduce_parent`
173    fn validity(
174        &self,
175        options: &Self::Options,
176        expression: &Expression,
177    ) -> VortexResult<Option<Expression>> {
178        _ = (options, expression);
179        Ok(None)
180    }
181
182    /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
183    ///
184    /// An expression is null-sensitive if it directly operates on null values,
185    /// such as `is_null`. Most expressions are not null-sensitive.
186    ///
187    /// The property we are interested in is if the expression (e) distributes over `mask`.
188    /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
189    /// array `a`.
190    ///
191    /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
192    /// `e(mask(a, m)) == mask(e(a), m)`.
193    ///
194    /// This can be extended to an n-ary expression.
195    ///
196    /// This method only checks the expression itself, not its children.
197    fn is_null_sensitive(&self, options: &Self::Options) -> bool {
198        _ = options;
199        true
200    }
201
202    /// Returns whether this expression itself is fallible. Conservatively default to *true*.
203    ///
204    /// An expression is runtime fallible is there is an input set that causes the expression to
205    /// panic or return an error, for example checked_add is fallible if there is overflow.
206    ///
207    /// Note: this is only applicable to expressions that pass type-checking
208    /// [`ScalarFnVTable::return_dtype`].
209    fn is_fallible(&self, options: &Self::Options) -> bool {
210        _ = options;
211        true
212    }
213}
214
215/// Arguments for reduction rules.
216pub trait ReduceCtx {
217    /// Create a new reduction node from the given scalar function and children.
218    fn new_node(
219        &self,
220        scalar_fn: ScalarFnRef,
221        children: &[ReduceNodeRef],
222    ) -> VortexResult<ReduceNodeRef>;
223}
224
225pub type ReduceNodeRef = Arc<dyn ReduceNode>;
226
227/// A node used for implementing abstract reduction rules.
228pub trait ReduceNode {
229    /// Downcast to Any.
230    fn as_any(&self) -> &dyn Any;
231
232    /// Return the data type of this node.
233    fn node_dtype(&self) -> VortexResult<DType>;
234
235    /// Return this node's scalar function if it is indeed a scalar fn.
236    fn scalar_fn(&self) -> Option<&ScalarFnRef>;
237
238    /// Descend to the child of this handle.
239    fn child(&self, idx: usize) -> ReduceNodeRef;
240
241    /// Returns the number of children of this node.
242    fn child_count(&self) -> usize;
243}
244
/// The arity (number of arguments) of a function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// Exactly this many arguments.
    Exact(usize),
    /// At least `min` arguments and, when `max` is `Some`, at most `max` (inclusive).
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Writes `n` for an exact arity (or a variadic one whose bounds coincide), `min..max`
    /// for a bounded variadic arity, and `min+` for an unbounded one.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(count) => write!(f, "{count}"),
            Arity::Variadic { min, max: None } => write!(f, "{min}+"),
            // Coinciding bounds read better as a single number.
            Arity::Variadic { min, max: Some(max) } if min == max => write!(f, "{min}"),
            Arity::Variadic { min, max: Some(max) } => write!(f, "{min}..{max}"),
        }
    }
}

impl Arity {
    /// Returns `true` when `arg_count` arguments satisfy this arity.
    ///
    /// Both bounds of a variadic arity are inclusive; a missing upper bound accepts any
    /// count that is at least `min`.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => arg_count == expected,
            Arity::Variadic { min, max: None } => arg_count >= min,
            Arity::Variadic { min, max: Some(max) } => (min..=max).contains(&arg_count),
        }
    }
}
284
285/// Context for simplification.
286///
287/// Used to lazily compute input data types where simplification requires them.
288pub trait SimplifyCtx {
289    /// Get the data type of the given expression.
290    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
291}
292
293/// Arguments for expression execution.
294pub struct ExecutionArgs<'a> {
295    /// The inputs for the expression, one per child.
296    pub inputs: Vec<ArrayRef>,
297    /// The row count of the execution scope.
298    pub row_count: usize,
299    /// The execution context.
300    pub ctx: &'a mut ExecutionCtx,
301}
302
/// A unit struct that carries no data.
///
/// It implements the traits required of [`ScalarFnVTable::Options`]; presumably it is meant
/// for scalar functions that take no options.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes the empty string: there is nothing to display.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
310
311/// Factory functions for vtables.
312pub trait ScalarFnVTableExt: ScalarFnVTable {
313    /// Bind this vtable with the given options into a [`ScalarFnRef`].
314    fn bind(&self, options: Self::Options) -> ScalarFnRef {
315        ScalarFn::new(self.clone(), options).erased()
316    }
317
318    /// Create a new expression with this vtable and the given options and children.
319    fn new_expr(
320        &self,
321        options: Self::Options,
322        children: impl IntoIterator<Item = Expression>,
323    ) -> Expression {
324        Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
325    }
326
327    /// Try to create a new expression with this vtable and the given options and children.
328    fn try_new_expr(
329        &self,
330        options: Self::Options,
331        children: impl IntoIterator<Item = Expression>,
332    ) -> VortexResult<Expression> {
333        Expression::try_new(self.bind(options), children)
334    }
335}
336impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
337
338/// A reference to the name of a child expression.
339pub type ChildName = ArcRef<str>;