Skip to main content

vortex_array/scalar_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::scalar_fn::ScalarFn;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnRef;
28
29/// This trait defines the interface for scalar function vtables, including methods for
30/// serialization, deserialization, validation, child naming, return type computation,
31/// and evaluation.
32///
33/// This trait is non-object safe and allows the implementer to make use of associated types
34/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
35/// outputs of each function.
36///
37/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
38/// all instances of the expression. In almost all cases, this struct will be an empty unit
39/// struct, since most expressions do not require any global state.
40pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
41    /// Options for this expression.
42    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
43
44    /// Returns the ID of the scalar function vtable.
45    fn id(&self) -> ScalarFnId;
46
47    /// Serialize the options for this expression.
48    ///
49    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
50    /// serializable but has no metadata.
51    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52        _ = options;
53        Ok(None)
54    }
55
56    /// Deserialize the options of this expression.
57    fn deserialize(
58        &self,
59        _metadata: &[u8],
60        _session: &VortexSession,
61    ) -> VortexResult<Self::Options> {
62        vortex_bail!("Expression {} is not deserializable", self.id());
63    }
64
65    /// Returns the arity of this expression.
66    fn arity(&self, options: &Self::Options) -> Arity;
67
68    /// Returns the name of the nth child of the expr.
69    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
70
71    /// Format this expression in a nice human-readable SQL-style format
72    ///
73    /// The implementation should recursively format child expressions by calling
74    /// `expr.child(i).fmt_sql(f)`.
75    fn fmt_sql(
76        &self,
77        options: &Self::Options,
78        expr: &Expression,
79        f: &mut Formatter<'_>,
80    ) -> fmt::Result;
81
82    /// Coerce the arguments of this function.
83    ///
84    /// This is optionally used by Vortex users when performing type coercion over a Vortex
85    /// expression. Note that direct Vortex query engine integrations (e.g. DuckDB, DataFusion,
86    /// etc.) do not perform type coercion and rely on the engine's own logical planner.
87    ///
88    /// Note that the default implementation simply returns the arguments without coercion, and it
89    /// is expected that the [`ScalarFnVTable::return_dtype`] call may still fail.
90    fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
91        let _ = options;
92        Ok(args.to_vec())
93    }
94
95    /// Compute the return [`DType`] of the expression if evaluated over the given input types.
96    fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
97
98    /// Execute the expression over the input arguments.
99    ///
100    /// Implementations are encouraged to check their inputs for constant arrays to perform
101    /// more optimized execution.
102    ///
103    /// If the input arguments cannot be directly used for execution (for example, an expression
104    /// may require canonical input arrays), then the implementation should perform a single
105    /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
106    ///
107    /// This provides maximum opportunities for array-level optimizations using execute_parent
108    /// kernels.
109    fn execute(
110        &self,
111        options: &Self::Options,
112        args: &dyn ExecutionArgs,
113        ctx: &mut ExecutionCtx,
114    ) -> VortexResult<ArrayRef>;
115
116    /// Implement an abstract reduction rule over a tree of scalar functions.
117    ///
118    /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
119    /// construct the result expression.
120    ///
121    /// Return `Ok(None)` if no reduction is possible.
122    fn reduce(
123        &self,
124        options: &Self::Options,
125        node: &dyn ReduceNode,
126        ctx: &dyn ReduceCtx,
127    ) -> VortexResult<Option<ReduceNodeRef>> {
128        _ = options;
129        _ = node;
130        _ = ctx;
131        Ok(None)
132    }
133
134    /// Simplify the expression if possible.
135    fn simplify(
136        &self,
137        options: &Self::Options,
138        expr: &Expression,
139        ctx: &dyn SimplifyCtx,
140    ) -> VortexResult<Option<Expression>> {
141        _ = options;
142        _ = expr;
143        _ = ctx;
144        Ok(None)
145    }
146
147    /// Simplify the expression if possible, without type information.
148    fn simplify_untyped(
149        &self,
150        options: &Self::Options,
151        expr: &Expression,
152    ) -> VortexResult<Option<Expression>> {
153        _ = options;
154        _ = expr;
155        Ok(None)
156    }
157
158    /// See [`Expression::stat_falsification`].
159    fn stat_falsification(
160        &self,
161        options: &Self::Options,
162        expr: &Expression,
163        catalog: &dyn StatsCatalog,
164    ) -> Option<Expression> {
165        _ = options;
166        _ = expr;
167        _ = catalog;
168        None
169    }
170
171    /// See [`Expression::stat_expression`].
172    fn stat_expression(
173        &self,
174        options: &Self::Options,
175        expr: &Expression,
176        stat: Stat,
177        catalog: &dyn StatsCatalog,
178    ) -> Option<Expression> {
179        _ = options;
180        _ = expr;
181        _ = stat;
182        _ = catalog;
183        None
184    }
185
186    /// Returns an expression that evaluates to the validity of the result of this expression.
187    ///
188    /// If a validity expression cannot be constructed, returns `None` and the expression will
189    /// be evaluated as normal before extracting the validity mask from the result.
190    ///
191    /// This is essentially a specialized form of a `reduce_parent`
192    fn validity(
193        &self,
194        options: &Self::Options,
195        expression: &Expression,
196    ) -> VortexResult<Option<Expression>> {
197        _ = (options, expression);
198        Ok(None)
199    }
200
201    /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
202    ///
203    /// An expression is null-sensitive if it directly operates on null values,
204    /// such as `is_null`. Most expressions are not null-sensitive.
205    ///
206    /// The property we are interested in is if the expression (e) distributes over `mask`.
207    /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
208    /// array `a`.
209    ///
210    /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
211    /// `e(mask(a, m)) == mask(e(a), m)`.
212    ///
213    /// This can be extended to an n-ary expression.
214    ///
215    /// This method only checks the expression itself, not its children.
216    fn is_null_sensitive(&self, options: &Self::Options) -> bool {
217        _ = options;
218        true
219    }
220
221    /// Returns whether this expression itself is fallible. Conservatively default to *true*.
222    ///
223    /// An expression is runtime fallible is there is an input set that causes the expression to
224    /// panic or return an error, for example checked_add is fallible if there is overflow.
225    ///
226    /// Note: this is only applicable to expressions that pass type-checking
227    /// [`ScalarFnVTable::return_dtype`].
228    fn is_fallible(&self, options: &Self::Options) -> bool {
229        _ = options;
230        true
231    }
232}
233
/// Context passed to reduction rules (see [`ScalarFnVTable::reduce`]).
///
/// It is the factory through which a rule constructs the nodes of its result.
pub trait ReduceCtx {
    /// Create a new reduction node from the given scalar function and children.
    ///
    /// # Errors
    ///
    /// Returns an error if the node cannot be created; the exact conditions are up to the
    /// implementation.
    fn new_node(
        &self,
        scalar_fn: ScalarFnRef,
        children: &[ReduceNodeRef],
    ) -> VortexResult<ReduceNodeRef>;
}
243
/// A shared, reference-counted handle to a type-erased [`ReduceNode`].
pub type ReduceNodeRef = Arc<dyn ReduceNode>;
245
/// A node used for implementing abstract reduction rules.
///
/// This is the tree abstraction handed to [`ScalarFnVTable::reduce`]: it exposes a node's data
/// type, its scalar function (if it has one) and its children.
pub trait ReduceNode {
    /// Downcast to [`Any`], so callers can recover the concrete node type.
    fn as_any(&self) -> &dyn Any;

    /// Return the data type of this node.
    ///
    /// # Errors
    ///
    /// Returns an error if the data type cannot be determined.
    fn node_dtype(&self) -> VortexResult<DType>;

    /// Return this node's scalar function if it is indeed a scalar fn, and `None` otherwise.
    fn scalar_fn(&self) -> Option<&ScalarFnRef>;

    /// Descend to the child of this handle at position `idx`.
    ///
    /// NOTE(review): the return type cannot signal an out-of-range `idx`; presumably callers
    /// must keep `idx < child_count()` — confirm against the implementations.
    fn child(&self, idx: usize) -> ReduceNodeRef;

    /// Returns the number of children of this node.
    fn child_count(&self) -> usize;
}
263
/// The arity (number of arguments) of a function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// The function takes exactly this many arguments.
    Exact(usize),
    /// The function takes at least `min` arguments and, when `max` is `Some`, at most `max`
    /// (inclusive). A `max` of `None` means there is no upper bound.
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders the arity compactly: `n` for an exact count (or a variadic range whose bounds
    /// coincide), `min..max` for a bounded range, and `min+` for an unbounded one.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(count) => write!(f, "{count}"),
            Arity::Variadic { min, max: None } => write!(f, "{min}+"),
            Arity::Variadic {
                min,
                max: Some(max),
            } if min == max => write!(f, "{min}"),
            Arity::Variadic {
                min,
                max: Some(max),
            } => write!(f, "{min}..{max}"),
        }
    }
}

impl Arity {
    /// Whether the given argument count matches this arity.
    ///
    /// Both bounds of a variadic arity are inclusive. A variadic arity whose `min` exceeds its
    /// `max` matches no argument count at all.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => expected == arg_count,
            Arity::Variadic { min, max: None } => arg_count >= min,
            Arity::Variadic {
                min,
                max: Some(max),
            } => (min..=max).contains(&arg_count),
        }
    }
}
303
/// Context for simplification (see [`ScalarFnVTable::simplify`]).
///
/// Used to lazily compute input data types where simplification requires them.
pub trait SimplifyCtx {
    /// Get the data type of the given expression.
    ///
    /// # Errors
    ///
    /// Returns an error if the data type of `expr` cannot be computed.
    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
}
311
/// Arguments for expression execution (see [`ScalarFnVTable::execute`]).
pub trait ExecutionArgs {
    /// Returns the input array at the given index.
    ///
    /// # Errors
    ///
    /// Implementations may fail; [`VecExecutionArgs`], for instance, returns an error when
    /// `index` is out of bounds.
    fn get(&self, index: usize) -> VortexResult<ArrayRef>;

    /// Returns the number of inputs.
    fn num_inputs(&self) -> usize;

    /// Returns the row count of the execution scope.
    fn row_count(&self) -> usize;
}
323
324/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
325pub struct VecExecutionArgs {
326    inputs: Vec<ArrayRef>,
327    row_count: usize,
328}
329
330impl VecExecutionArgs {
331    /// Create a new `VecExecutionArgs`.
332    pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
333        Self { inputs, row_count }
334    }
335}
336
337impl ExecutionArgs for VecExecutionArgs {
338    fn get(&self, index: usize) -> VortexResult<ArrayRef> {
339        self.inputs.get(index).cloned().ok_or_else(|| {
340            vortex_err!(
341                "Input index {} out of bounds (num_inputs={})",
342                index,
343                self.inputs.len()
344            )
345        })
346    }
347
348    fn num_inputs(&self) -> usize {
349        self.inputs.len()
350    }
351
352    fn row_count(&self) -> usize {
353        self.row_count
354    }
355}
356
/// An options type carrying no data, for expressions that need no configuration.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes the empty string: there is nothing to display.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
364
365/// Factory functions for vtables.
366pub trait ScalarFnVTableExt: ScalarFnVTable {
367    /// Bind this vtable with the given options into a [`ScalarFnRef`].
368    fn bind(&self, options: Self::Options) -> ScalarFnRef {
369        ScalarFn::new(self.clone(), options).erased()
370    }
371
372    /// Create a new expression with this vtable and the given options and children.
373    fn new_expr(
374        &self,
375        options: Self::Options,
376        children: impl IntoIterator<Item = Expression>,
377    ) -> Expression {
378        Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
379    }
380
381    /// Try to create a new expression with this vtable and the given options and children.
382    fn try_new_expr(
383        &self,
384        options: Self::Options,
385        children: impl IntoIterator<Item = Expression>,
386    ) -> VortexResult<Expression> {
387        Expression::try_new(self.bind(options), children)
388    }
389}
390impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
391
/// A reference to the name of a child expression.
///
/// This is the type returned by [`ScalarFnVTable::child_name`].
pub type ChildName = ArcRef<str>;