Skip to main content

vortex_array/scalar_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::scalar_fn::ScalarFn;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnRef;
28
29/// This trait defines the interface for scalar function vtables, including methods for
30/// serialization, deserialization, validation, child naming, return type computation,
31/// and evaluation.
32///
33/// This trait is non-object safe and allows the implementer to make use of associated types
34/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
35/// outputs of each function.
36///
37/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
38/// all instances of the expression. In almost all cases, this struct will be an empty unit
39/// struct, since most expressions do not require any global state.
40pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
41    /// Options for this expression.
42    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
43
44    /// Returns the ID of the scalar function vtable.
45    fn id(&self) -> ScalarFnId;
46
47    /// Serialize the options for this expression.
48    ///
49    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
50    /// serializable but has no metadata.
51    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52        _ = options;
53        Ok(None)
54    }
55
56    /// Deserialize the options of this expression.
57    fn deserialize(
58        &self,
59        _metadata: &[u8],
60        _session: &VortexSession,
61    ) -> VortexResult<Self::Options> {
62        vortex_bail!("Expression {} is not deserializable", self.id());
63    }
64
65    /// Returns the arity of this expression.
66    fn arity(&self, options: &Self::Options) -> Arity;
67
68    /// Returns the name of the nth child of the expr.
69    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
70
71    /// Format this expression in a nice human-readable SQL-style format
72    ///
73    /// The implementation should recursively format child expressions by calling
74    /// `expr.child(i).fmt_sql(f)`.
75    fn fmt_sql(
76        &self,
77        options: &Self::Options,
78        expr: &Expression,
79        f: &mut Formatter<'_>,
80    ) -> fmt::Result;
81
82    /// Coerce the arguments of this function.
83    ///
84    /// This is optionally used by Vortex users when performing type coercion over a Vortex
85    /// expression. Note that direct Vortex query engine integrations (e.g. DuckDB, DataFusion,
86    /// etc.) do not perform type coercion and rely on the engine's own logical planner.
87    ///
88    /// Note that the default implementation simply returns the arguments without coercion, and it
89    /// is expected that the [`ScalarFnVTable::return_dtype`] call may still fail.
90    fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
91        let _ = options;
92        Ok(args.to_vec())
93    }
94
95    /// Compute the return [`DType`] of the expression if evaluated over the given input types.
96    ///
97    /// # Preconditions
98    ///
99    /// The length of `args` must match the [`Arity`] of this function. Callers are responsible
100    /// for validating this (e.g., [`Expression::try_new`] checks arity at construction time).
101    /// Implementations may assume correct arity and will panic or return nonsensical results if
102    /// violated.
103    ///
104    /// [`Expression::try_new`]: crate::expr::Expression::try_new
105    fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
106
107    /// Execute the expression over the input arguments.
108    ///
109    /// Implementations are encouraged to check their inputs for constant arrays to perform
110    /// more optimized execution.
111    ///
112    /// If the input arguments cannot be directly used for execution (for example, an expression
113    /// may require canonical input arrays), then the implementation should perform a single
114    /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
115    ///
116    /// This provides maximum opportunities for array-level optimizations using execute_parent
117    /// kernels.
118    fn execute(
119        &self,
120        options: &Self::Options,
121        args: &dyn ExecutionArgs,
122        ctx: &mut ExecutionCtx,
123    ) -> VortexResult<ArrayRef>;
124
125    /// Implement an abstract reduction rule over a tree of scalar functions.
126    ///
127    /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
128    /// construct the result expression.
129    ///
130    /// Return `Ok(None)` if no reduction is possible.
131    fn reduce(
132        &self,
133        options: &Self::Options,
134        node: &dyn ReduceNode,
135        ctx: &dyn ReduceCtx,
136    ) -> VortexResult<Option<ReduceNodeRef>> {
137        _ = options;
138        _ = node;
139        _ = ctx;
140        Ok(None)
141    }
142
143    /// Simplify the expression if possible.
144    fn simplify(
145        &self,
146        options: &Self::Options,
147        expr: &Expression,
148        ctx: &dyn SimplifyCtx,
149    ) -> VortexResult<Option<Expression>> {
150        _ = options;
151        _ = expr;
152        _ = ctx;
153        Ok(None)
154    }
155
156    /// Simplify the expression if possible, without type information.
157    fn simplify_untyped(
158        &self,
159        options: &Self::Options,
160        expr: &Expression,
161    ) -> VortexResult<Option<Expression>> {
162        _ = options;
163        _ = expr;
164        Ok(None)
165    }
166
167    /// See [`Expression::stat_falsification`].
168    fn stat_falsification(
169        &self,
170        options: &Self::Options,
171        expr: &Expression,
172        catalog: &dyn StatsCatalog,
173    ) -> Option<Expression> {
174        _ = options;
175        _ = expr;
176        _ = catalog;
177        None
178    }
179
180    /// See [`Expression::stat_expression`].
181    fn stat_expression(
182        &self,
183        options: &Self::Options,
184        expr: &Expression,
185        stat: Stat,
186        catalog: &dyn StatsCatalog,
187    ) -> Option<Expression> {
188        _ = options;
189        _ = expr;
190        _ = stat;
191        _ = catalog;
192        None
193    }
194
195    /// Returns an expression that evaluates to the validity of the result of this expression.
196    ///
197    /// If a validity expression cannot be constructed, returns `None` and the expression will
198    /// be evaluated as normal before extracting the validity mask from the result.
199    ///
200    /// This is essentially a specialized form of a `reduce_parent`
201    fn validity(
202        &self,
203        options: &Self::Options,
204        expression: &Expression,
205    ) -> VortexResult<Option<Expression>> {
206        _ = (options, expression);
207        Ok(None)
208    }
209
210    /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
211    ///
212    /// An expression is null-sensitive if it directly operates on null values,
213    /// such as `is_null`. Most expressions are not null-sensitive.
214    ///
215    /// The property we are interested in is if the expression (e) distributes over `mask`.
216    /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
217    /// array `a`.
218    ///
219    /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
220    /// `e(mask(a, m)) == mask(e(a), m)`.
221    ///
222    /// This can be extended to an n-ary expression.
223    ///
224    /// This method only checks the expression itself, not its children.
225    fn is_null_sensitive(&self, options: &Self::Options) -> bool {
226        _ = options;
227        true
228    }
229
230    /// Returns whether this expression is semantically fallible. Conservatively defaults to
231    /// `true`.
232    ///
233    /// An expression is semantically fallible if there exists a set of well-typed inputs that
234    /// causes the expression to produce an error as part of its _defined behavior_. For example,
235    /// `checked_add` is fallible because integer overflow is a domain error, and division is
236    /// fallible because of division by zero.
237    ///
238    /// This does **not** include execution errors that are incidental to the implementation, such
239    /// as canonicalization failures, memory allocation errors, or encoding mismatches. Those can
240    /// happen to any expression and are not what this method captures.
241    ///
242    /// This property is used by optimizations that speculatively evaluate an expression over values
243    /// that may not appear in the actual input. For example, pushing a scalar function down to a
244    /// dictionary's values array is only safe when the function is infallible or all values are
245    /// referenced, since a fallible function might error on a value left unreferenced after
246    /// slicing that would never be encountered during normal evaluation.
247    ///
248    /// Note: this is only applicable to expressions that pass type-checking via
249    /// [`ScalarFnVTable::return_dtype`].
250    fn is_fallible(&self, options: &Self::Options) -> bool {
251        _ = options;
252        true
253    }
254}
255
256/// Arguments for reduction rules.
257pub trait ReduceCtx {
258    /// Create a new reduction node from the given scalar function and children.
259    fn new_node(
260        &self,
261        scalar_fn: ScalarFnRef,
262        children: &[ReduceNodeRef],
263    ) -> VortexResult<ReduceNodeRef>;
264}
265
266pub type ReduceNodeRef = Arc<dyn ReduceNode>;
267
268/// A node used for implementing abstract reduction rules.
269pub trait ReduceNode {
270    /// Downcast to Any.
271    fn as_any(&self) -> &dyn Any;
272
273    /// Return the data type of this node.
274    fn node_dtype(&self) -> VortexResult<DType>;
275
276    /// Return this node's scalar function if it is indeed a scalar fn.
277    fn scalar_fn(&self) -> Option<&ScalarFnRef>;
278
279    /// Descend to the child of this handle.
280    fn child(&self, idx: usize) -> ReduceNodeRef;
281
282    /// Returns the number of children of this node.
283    fn child_count(&self) -> usize;
284}
285
/// The arity (number of arguments) of a function.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Arity {
    /// Exactly this many arguments.
    Exact(usize),
    /// At least `min` arguments and, when `max` is `Some`, at most `max` arguments.
    ///
    /// Both bounds are inclusive. `max: None` means there is no upper bound. A `max` below
    /// `min` describes an empty range, which no argument count matches.
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders `Exact(n)` as `n`, a bounded range as `min..max` (or just `min` when both bounds
    /// are equal), and an unbounded range as `min+`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(n) => write!(f, "{n}"),
            Arity::Variadic { min, max: Some(max) } if min == max => write!(f, "{min}"),
            Arity::Variadic { min, max: Some(max) } => write!(f, "{min}..{max}"),
            Arity::Variadic { min, max: None } => write!(f, "{min}+"),
        }
    }
}

impl Arity {
    /// Whether the given argument count matches this arity.
    ///
    /// `Exact(n)` matches only `n`. `Variadic` matches every count from `min` up to and
    /// including `max`, or every count from `min` upwards when `max` is `None`.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(n) => arg_count == n,
            Arity::Variadic { min, max: None } => arg_count >= min,
            // An inverted range (`max < min`) is empty, so nothing matches it.
            Arity::Variadic { min, max: Some(max) } => (min..=max).contains(&arg_count),
        }
    }
}
325
326/// Context for simplification.
327///
328/// Used to lazily compute input data types where simplification requires them.
329pub trait SimplifyCtx {
330    /// Get the data type of the given expression.
331    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
332}
333
334/// Arguments for expression execution.
335pub trait ExecutionArgs {
336    /// Returns the input array at the given index.
337    fn get(&self, index: usize) -> VortexResult<ArrayRef>;
338
339    /// Returns the number of inputs.
340    fn num_inputs(&self) -> usize;
341
342    /// Returns the row count of the execution scope.
343    fn row_count(&self) -> usize;
344}
345
346/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
347pub struct VecExecutionArgs {
348    inputs: Vec<ArrayRef>,
349    row_count: usize,
350}
351
352impl VecExecutionArgs {
353    /// Create a new `VecExecutionArgs`.
354    pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
355        Self { inputs, row_count }
356    }
357}
358
359impl ExecutionArgs for VecExecutionArgs {
360    fn get(&self, index: usize) -> VortexResult<ArrayRef> {
361        self.inputs.get(index).cloned().ok_or_else(|| {
362            vortex_err!(
363                "Input index {} out of bounds (num_inputs={})",
364                index,
365                self.inputs.len()
366            )
367        })
368    }
369
370    fn num_inputs(&self) -> usize {
371        self.inputs.len()
372    }
373
374    fn row_count(&self) -> usize {
375        self.row_count
376    }
377}
378
/// A unit options type carrying no data, whose `Display` output is the empty string.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes nothing: there are no options to show.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
386
387/// Factory functions for vtables.
388pub trait ScalarFnVTableExt: ScalarFnVTable {
389    /// Bind this vtable with the given options into a [`ScalarFnRef`].
390    fn bind(&self, options: Self::Options) -> ScalarFnRef {
391        ScalarFn::new(self.clone(), options).erased()
392    }
393
394    /// Create a new expression with this vtable and the given options and children.
395    fn new_expr(
396        &self,
397        options: Self::Options,
398        children: impl IntoIterator<Item = Expression>,
399    ) -> Expression {
400        Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
401    }
402
403    /// Try to create a new expression with this vtable and the given options and children.
404    fn try_new_expr(
405        &self,
406        options: Self::Options,
407        children: impl IntoIterator<Item = Expression>,
408    ) -> VortexResult<Expression> {
409        Expression::try_new(self.bind(options), children)
410    }
411}
412impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
413
414/// A reference to the name of a child expression.
415pub type ChildName = ArcRef<str>;