Skip to main content

vortex_array/scalar_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::traversal::Node;
24use crate::scalar_fn::ScalarFnId;
25use crate::scalar_fn::ScalarFnRef;
26use crate::scalar_fn::TypedScalarFnInstance;
27
28/// This trait defines the interface for scalar function vtables, including methods for
29/// serialization, deserialization, validation, child naming, return type computation,
30/// and evaluation.
31///
32/// This trait is non-object safe and allows the implementer to make use of associated types
33/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
34/// outputs of each function.
35///
36/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
37/// all instances of the expression. In almost all cases, this struct will be an empty unit
38/// struct, since most expressions do not require any global state.
39pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
40    /// Options for this expression.
41    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
42
43    /// Returns the ID of the scalar function vtable.
44    fn id(&self) -> ScalarFnId;
45
46    /// Serialize the options for this expression.
47    ///
48    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
49    /// serializable but has no metadata.
50    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51        _ = options;
52        Ok(None)
53    }
54
55    /// Deserialize the options of this expression.
56    fn deserialize(
57        &self,
58        _metadata: &[u8],
59        _session: &VortexSession,
60    ) -> VortexResult<Self::Options> {
61        vortex_bail!("Expression {} is not deserializable", self.id());
62    }
63
64    /// Returns the arity of this expression.
65    fn arity(&self, options: &Self::Options) -> Arity;
66
67    /// Returns the name of the nth child of the expr.
68    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
69
70    /// Format this expression in a nice human-readable SQL-style format
71    ///
72    /// The implementation should recursively format child expressions by calling
73    /// `expr.child(i).fmt_sql(f)`.
74    fn fmt_sql(
75        &self,
76        options: &Self::Options,
77        expr: &Expression,
78        f: &mut Formatter<'_>,
79    ) -> fmt::Result {
80        write!(f, "{}(", self.id())?;
81        let nchildren = expr.children_count();
82        for (i, child) in expr.children().iter().enumerate() {
83            child.fmt_sql(f)?;
84            if i + 1 < nchildren {
85                write!(f, ", ")?;
86            }
87        }
88        let opts = format!("{}", options);
89        if !opts.is_empty() {
90            write!(f, ", opts={}", opts)?;
91        }
92        write!(f, ")")
93    }
94
95    /// Coerce the arguments of this function.
96    ///
97    /// This is optionally used by Vortex users when performing type coercion over a Vortex
98    /// expression. Note that direct Vortex query engine integrations (e.g. DuckDB, DataFusion,
99    /// etc.) do not perform type coercion and rely on the engine's own logical planner.
100    ///
101    /// Note that the default implementation simply returns the arguments without coercion, and it
102    /// is expected that the [`ScalarFnVTable::return_dtype`] call may still fail.
103    fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
104        let _ = options;
105        Ok(args.to_vec())
106    }
107
108    /// Compute the return [`DType`] of the expression if evaluated over the given input types.
109    ///
110    /// # Preconditions
111    ///
112    /// The length of `args` must match the [`Arity`] of this function. Callers are responsible
113    /// for validating this (e.g., [`Expression::try_new`] checks arity at construction time).
114    /// Implementations may assume correct arity and will panic or return nonsensical results if
115    /// violated.
116    ///
117    /// [`Expression::try_new`]: crate::expr::Expression::try_new
118    fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
119
120    /// Execute the expression over the input arguments.
121    ///
122    /// Implementations are encouraged to check their inputs for constant arrays to perform
123    /// more optimized execution.
124    ///
125    /// If the input arguments cannot be directly used for execution (for example, an expression
126    /// may require canonical input arrays), then the implementation should perform a single
127    /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
128    ///
129    /// This provides maximum opportunities for array-level optimizations using execute_parent
130    /// kernels.
131    fn execute(
132        &self,
133        options: &Self::Options,
134        args: &dyn ExecutionArgs,
135        ctx: &mut ExecutionCtx,
136    ) -> VortexResult<ArrayRef>;
137
138    /// Implement an abstract reduction rule over a tree of scalar functions.
139    ///
140    /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
141    /// construct the result expression.
142    ///
143    /// Return `Ok(None)` if no reduction is possible.
144    fn reduce(
145        &self,
146        options: &Self::Options,
147        node: &dyn ReduceNode,
148        ctx: &dyn ReduceCtx,
149    ) -> VortexResult<Option<ReduceNodeRef>> {
150        _ = options;
151        _ = node;
152        _ = ctx;
153        Ok(None)
154    }
155
156    /// Simplify the expression if possible.
157    fn simplify(
158        &self,
159        options: &Self::Options,
160        expr: &Expression,
161        ctx: &dyn SimplifyCtx,
162    ) -> VortexResult<Option<Expression>> {
163        _ = options;
164        _ = expr;
165        _ = ctx;
166        Ok(None)
167    }
168
169    /// Simplify the expression if possible, without type information.
170    fn simplify_untyped(
171        &self,
172        options: &Self::Options,
173        expr: &Expression,
174    ) -> VortexResult<Option<Expression>> {
175        _ = options;
176        _ = expr;
177        Ok(None)
178    }
179
180    /// Returns an expression that evaluates to the validity of the result of this expression.
181    ///
182    /// If a validity expression cannot be constructed, returns `None` and the expression will
183    /// be evaluated as normal before extracting the validity mask from the result.
184    ///
185    /// This is essentially a specialized form of a `reduce_parent`
186    fn validity(
187        &self,
188        options: &Self::Options,
189        expression: &Expression,
190    ) -> VortexResult<Option<Expression>> {
191        _ = (options, expression);
192        Ok(None)
193    }
194
195    /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
196    ///
197    /// An expression is null-sensitive if it directly operates on null values,
198    /// such as `is_null`. Most expressions are not null-sensitive.
199    ///
200    /// The property we are interested in is if the expression (e) distributes over `mask`.
201    /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
202    /// array `a`.
203    ///
204    /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
205    /// `e(mask(a, m)) == mask(e(a), m)`.
206    ///
207    /// This can be extended to an n-ary expression.
208    ///
209    /// This method only checks the expression itself, not its children.
210    fn is_null_sensitive(&self, options: &Self::Options) -> bool {
211        _ = options;
212        true
213    }
214
215    /// Returns whether this expression is semantically fallible. Conservatively defaults to
216    /// `true`.
217    ///
218    /// An expression is semantically fallible if there exists a set of well-typed inputs that
219    /// causes the expression to produce an error as part of its _defined behavior_. For example,
220    /// `checked_add` is fallible because integer overflow is a domain error, and division is
221    /// fallible because of division by zero.
222    ///
223    /// This does **not** include execution errors that are incidental to the implementation, such
224    /// as canonicalization failures, memory allocation errors, or encoding mismatches. Those can
225    /// happen to any expression and are not what this method captures.
226    ///
227    /// This property is used by optimizations that speculatively evaluate an expression over values
228    /// that may not appear in the actual input. For example, pushing a scalar function down to a
229    /// dictionary's values array is only safe when the function is infallible or all values are
230    /// referenced, since a fallible function might error on a value left unreferenced after
231    /// slicing that would never be encountered during normal evaluation.
232    ///
233    /// Note: this is only applicable to expressions that pass type-checking via
234    /// [`ScalarFnVTable::return_dtype`].
235    fn is_fallible(&self, options: &Self::Options) -> bool {
236        _ = options;
237        true
238    }
239}
240
241/// Arguments for reduction rules.
242pub trait ReduceCtx {
243    /// Create a new reduction node from the given scalar function and children.
244    fn new_node(
245        &self,
246        scalar_fn: ScalarFnRef,
247        children: &[ReduceNodeRef],
248    ) -> VortexResult<ReduceNodeRef>;
249}
250
251pub type ReduceNodeRef = Arc<dyn ReduceNode>;
252
253/// A node used for implementing abstract reduction rules.
254pub trait ReduceNode {
255    /// Downcast to Any.
256    fn as_any(&self) -> &dyn Any;
257
258    /// Return the data type of this node.
259    fn node_dtype(&self) -> VortexResult<DType>;
260
261    /// Return this node's scalar function if it is indeed a scalar fn.
262    fn scalar_fn(&self) -> Option<&ScalarFnRef>;
263
264    /// Descend to the child of this handle.
265    fn child(&self, idx: usize) -> ReduceNodeRef;
266
267    /// Returns the number of children of this node.
268    fn child_count(&self) -> usize;
269}
270
/// How many arguments a function accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// Exactly this many arguments.
    Exact(usize),
    /// At least `min` arguments and, when `max` is `Some`, at most `max` (inclusive).
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders `n` for an exact arity (or a variadic one whose bounds coincide), `min..max` for
    /// a bounded variadic arity, and `min+` for an unbounded one.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(count) => write!(f, "{}", count),
            Arity::Variadic { min, max: None } => write!(f, "{}+", min),
            Arity::Variadic {
                min,
                max: Some(upper),
            } if min == upper => write!(f, "{}", min),
            Arity::Variadic {
                min,
                max: Some(upper),
            } => write!(f, "{}..{}", min, upper),
        }
    }
}

impl Arity {
    /// Returns `true` when a call with `arg_count` arguments satisfies this arity.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => arg_count == expected,
            // An inverted bound (`min > max`) yields an empty range, so nothing matches.
            Arity::Variadic {
                min,
                max: Some(upper),
            } => (min..=upper).contains(&arg_count),
            Arity::Variadic { min, max: None } => arg_count >= min,
        }
    }
}
310
311/// Context for simplification.
312///
313/// Used to lazily compute input data types where simplification requires them.
314pub trait SimplifyCtx {
315    /// Get the data type of the given expression.
316    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
317}
318
319/// Arguments for expression execution.
320pub trait ExecutionArgs {
321    /// Returns the input array at the given index.
322    fn get(&self, index: usize) -> VortexResult<ArrayRef>;
323
324    /// Returns the number of inputs.
325    fn num_inputs(&self) -> usize;
326
327    /// Returns the row count of the execution scope.
328    fn row_count(&self) -> usize;
329}
330
331/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
332pub struct VecExecutionArgs {
333    inputs: Vec<ArrayRef>,
334    row_count: usize,
335}
336
337impl VecExecutionArgs {
338    /// Create a new `VecExecutionArgs`.
339    pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
340        Self { inputs, row_count }
341    }
342}
343
344impl ExecutionArgs for VecExecutionArgs {
345    fn get(&self, index: usize) -> VortexResult<ArrayRef> {
346        self.inputs.get(index).cloned().ok_or_else(|| {
347            vortex_err!(
348                "Input index {} out of bounds (num_inputs={})",
349                index,
350                self.inputs.len()
351            )
352        })
353    }
354
355    fn num_inputs(&self) -> usize {
356        self.inputs.len()
357    }
358
359    fn row_count(&self) -> usize {
360        self.row_count
361    }
362}
363
/// A unit options type that carries no data.
///
/// It displays as the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes nothing: there are no options to show.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
371
372/// Factory functions for vtables.
373pub trait ScalarFnVTableExt: ScalarFnVTable {
374    /// Bind this vtable with the given options into a [`ScalarFnRef`].
375    fn bind(&self, options: Self::Options) -> ScalarFnRef {
376        TypedScalarFnInstance::new(self.clone(), options).erased()
377    }
378
379    /// Create a new expression with this vtable and the given options and children.
380    fn new_expr(
381        &self,
382        options: Self::Options,
383        children: impl IntoIterator<Item = Expression>,
384    ) -> Expression {
385        Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
386    }
387
388    /// Try to create a new expression with this vtable and the given options and children.
389    fn try_new_expr(
390        &self,
391        options: Self::Options,
392        children: impl IntoIterator<Item = Expression>,
393    ) -> VortexResult<Expression> {
394        Expression::try_new(self.bind(options), children)
395    }
396}
397impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
398
399/// A reference to the name of a child expression.
400pub type ChildName = ArcRef<str>;