Skip to main content

vortex_array/scalar_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::expr::traversal::Node;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnRef;
28use crate::scalar_fn::TypedScalarFnInstance;
29
30/// This trait defines the interface for scalar function vtables, including methods for
31/// serialization, deserialization, validation, child naming, return type computation,
32/// and evaluation.
33///
34/// This trait is non-object safe and allows the implementer to make use of associated types
35/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
36/// outputs of each function.
37///
38/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
39/// all instances of the expression. In almost all cases, this struct will be an empty unit
40/// struct, since most expressions do not require any global state.
41pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
42    /// Options for this expression.
43    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
44
45    /// Returns the ID of the scalar function vtable.
46    fn id(&self) -> ScalarFnId;
47
48    /// Serialize the options for this expression.
49    ///
50    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
51    /// serializable but has no metadata.
52    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53        _ = options;
54        Ok(None)
55    }
56
57    /// Deserialize the options of this expression.
58    fn deserialize(
59        &self,
60        _metadata: &[u8],
61        _session: &VortexSession,
62    ) -> VortexResult<Self::Options> {
63        vortex_bail!("Expression {} is not deserializable", self.id());
64    }
65
66    /// Returns the arity of this expression.
67    fn arity(&self, options: &Self::Options) -> Arity;
68
69    /// Returns the name of the nth child of the expr.
70    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
71
72    /// Format this expression in a nice human-readable SQL-style format
73    ///
74    /// The implementation should recursively format child expressions by calling
75    /// `expr.child(i).fmt_sql(f)`.
76    fn fmt_sql(
77        &self,
78        options: &Self::Options,
79        expr: &Expression,
80        f: &mut Formatter<'_>,
81    ) -> fmt::Result {
82        write!(f, "{}(", self.id())?;
83        let nchildren = expr.children_count();
84        for (i, child) in expr.children().iter().enumerate() {
85            child.fmt_sql(f)?;
86            if i + 1 < nchildren {
87                write!(f, ", ")?;
88            }
89        }
90        let opts = format!("{}", options);
91        if !opts.is_empty() {
92            write!(f, ", opts={}", opts)?;
93        }
94        write!(f, ")")
95    }
96
97    /// Coerce the arguments of this function.
98    ///
99    /// This is optionally used by Vortex users when performing type coercion over a Vortex
100    /// expression. Note that direct Vortex query engine integrations (e.g. DuckDB, DataFusion,
101    /// etc.) do not perform type coercion and rely on the engine's own logical planner.
102    ///
103    /// Note that the default implementation simply returns the arguments without coercion, and it
104    /// is expected that the [`ScalarFnVTable::return_dtype`] call may still fail.
105    fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
106        let _ = options;
107        Ok(args.to_vec())
108    }
109
110    /// Compute the return [`DType`] of the expression if evaluated over the given input types.
111    ///
112    /// # Preconditions
113    ///
114    /// The length of `args` must match the [`Arity`] of this function. Callers are responsible
115    /// for validating this (e.g., [`Expression::try_new`] checks arity at construction time).
116    /// Implementations may assume correct arity and will panic or return nonsensical results if
117    /// violated.
118    ///
119    /// [`Expression::try_new`]: crate::expr::Expression::try_new
120    fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
121
122    /// Execute the expression over the input arguments.
123    ///
124    /// Implementations are encouraged to check their inputs for constant arrays to perform
125    /// more optimized execution.
126    ///
127    /// If the input arguments cannot be directly used for execution (for example, an expression
128    /// may require canonical input arrays), then the implementation should perform a single
129    /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
130    ///
131    /// This provides maximum opportunities for array-level optimizations using execute_parent
132    /// kernels.
133    fn execute(
134        &self,
135        options: &Self::Options,
136        args: &dyn ExecutionArgs,
137        ctx: &mut ExecutionCtx,
138    ) -> VortexResult<ArrayRef>;
139
140    /// Implement an abstract reduction rule over a tree of scalar functions.
141    ///
142    /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
143    /// construct the result expression.
144    ///
145    /// Return `Ok(None)` if no reduction is possible.
146    fn reduce(
147        &self,
148        options: &Self::Options,
149        node: &dyn ReduceNode,
150        ctx: &dyn ReduceCtx,
151    ) -> VortexResult<Option<ReduceNodeRef>> {
152        _ = options;
153        _ = node;
154        _ = ctx;
155        Ok(None)
156    }
157
158    /// Simplify the expression if possible.
159    fn simplify(
160        &self,
161        options: &Self::Options,
162        expr: &Expression,
163        ctx: &dyn SimplifyCtx,
164    ) -> VortexResult<Option<Expression>> {
165        _ = options;
166        _ = expr;
167        _ = ctx;
168        Ok(None)
169    }
170
171    /// Simplify the expression if possible, without type information.
172    fn simplify_untyped(
173        &self,
174        options: &Self::Options,
175        expr: &Expression,
176    ) -> VortexResult<Option<Expression>> {
177        _ = options;
178        _ = expr;
179        Ok(None)
180    }
181
182    /// See [`Expression::stat_falsification`].
183    fn stat_falsification(
184        &self,
185        options: &Self::Options,
186        expr: &Expression,
187        catalog: &dyn StatsCatalog,
188    ) -> Option<Expression> {
189        _ = options;
190        _ = expr;
191        _ = catalog;
192        None
193    }
194
195    /// See [`Expression::stat_expression`].
196    fn stat_expression(
197        &self,
198        options: &Self::Options,
199        expr: &Expression,
200        stat: Stat,
201        catalog: &dyn StatsCatalog,
202    ) -> Option<Expression> {
203        _ = options;
204        _ = expr;
205        _ = stat;
206        _ = catalog;
207        None
208    }
209
210    /// Returns an expression that evaluates to the validity of the result of this expression.
211    ///
212    /// If a validity expression cannot be constructed, returns `None` and the expression will
213    /// be evaluated as normal before extracting the validity mask from the result.
214    ///
215    /// This is essentially a specialized form of a `reduce_parent`
216    fn validity(
217        &self,
218        options: &Self::Options,
219        expression: &Expression,
220    ) -> VortexResult<Option<Expression>> {
221        _ = (options, expression);
222        Ok(None)
223    }
224
225    /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
226    ///
227    /// An expression is null-sensitive if it directly operates on null values,
228    /// such as `is_null`. Most expressions are not null-sensitive.
229    ///
230    /// The property we are interested in is if the expression (e) distributes over `mask`.
231    /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
232    /// array `a`.
233    ///
234    /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
235    /// `e(mask(a, m)) == mask(e(a), m)`.
236    ///
237    /// This can be extended to an n-ary expression.
238    ///
239    /// This method only checks the expression itself, not its children.
240    fn is_null_sensitive(&self, options: &Self::Options) -> bool {
241        _ = options;
242        true
243    }
244
245    /// Returns whether this expression is semantically fallible. Conservatively defaults to
246    /// `true`.
247    ///
248    /// An expression is semantically fallible if there exists a set of well-typed inputs that
249    /// causes the expression to produce an error as part of its _defined behavior_. For example,
250    /// `checked_add` is fallible because integer overflow is a domain error, and division is
251    /// fallible because of division by zero.
252    ///
253    /// This does **not** include execution errors that are incidental to the implementation, such
254    /// as canonicalization failures, memory allocation errors, or encoding mismatches. Those can
255    /// happen to any expression and are not what this method captures.
256    ///
257    /// This property is used by optimizations that speculatively evaluate an expression over values
258    /// that may not appear in the actual input. For example, pushing a scalar function down to a
259    /// dictionary's values array is only safe when the function is infallible or all values are
260    /// referenced, since a fallible function might error on a value left unreferenced after
261    /// slicing that would never be encountered during normal evaluation.
262    ///
263    /// Note: this is only applicable to expressions that pass type-checking via
264    /// [`ScalarFnVTable::return_dtype`].
265    fn is_fallible(&self, options: &Self::Options) -> bool {
266        _ = options;
267        true
268    }
269}
270
271/// Arguments for reduction rules.
272pub trait ReduceCtx {
273    /// Create a new reduction node from the given scalar function and children.
274    fn new_node(
275        &self,
276        scalar_fn: ScalarFnRef,
277        children: &[ReduceNodeRef],
278    ) -> VortexResult<ReduceNodeRef>;
279}
280
281pub type ReduceNodeRef = Arc<dyn ReduceNode>;
282
283/// A node used for implementing abstract reduction rules.
284pub trait ReduceNode {
285    /// Downcast to Any.
286    fn as_any(&self) -> &dyn Any;
287
288    /// Return the data type of this node.
289    fn node_dtype(&self) -> VortexResult<DType>;
290
291    /// Return this node's scalar function if it is indeed a scalar fn.
292    fn scalar_fn(&self) -> Option<&ScalarFnRef>;
293
294    /// Descend to the child of this handle.
295    fn child(&self, idx: usize) -> ReduceNodeRef;
296
297    /// Returns the number of children of this node.
298    fn child_count(&self) -> usize;
299}
300
/// How many arguments a function accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// Exactly this many arguments.
    Exact(usize),
    /// At least `min` arguments and, when `max` is `Some`, at most `max` (inclusive).
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders an exact arity as `n`, an unbounded one as `min+`, and a bounded one as
    /// `min..max` (collapsed to a single number when both bounds are equal).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Exact(count) => write!(f, "{}", count),
            Self::Variadic { min, max: None } => write!(f, "{}+", min),
            Self::Variadic {
                min,
                max: Some(max),
            } if min == max => write!(f, "{}", min),
            Self::Variadic {
                min,
                max: Some(max),
            } => write!(f, "{}..{}", min, max),
        }
    }
}

impl Arity {
    /// Reports whether a call with `arg_count` arguments satisfies this arity.
    ///
    /// Both bounds of a variadic arity are inclusive; a variadic arity whose `min` exceeds its
    /// `max` accepts no argument count at all.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Self::Exact(expected) => arg_count == expected,
            Self::Variadic { min, max: None } => arg_count >= min,
            Self::Variadic {
                min,
                max: Some(max),
            } => (min..=max).contains(&arg_count),
        }
    }
}
340
341/// Context for simplification.
342///
343/// Used to lazily compute input data types where simplification requires them.
344pub trait SimplifyCtx {
345    /// Get the data type of the given expression.
346    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
347}
348
349/// Arguments for expression execution.
350pub trait ExecutionArgs {
351    /// Returns the input array at the given index.
352    fn get(&self, index: usize) -> VortexResult<ArrayRef>;
353
354    /// Returns the number of inputs.
355    fn num_inputs(&self) -> usize;
356
357    /// Returns the row count of the execution scope.
358    fn row_count(&self) -> usize;
359}
360
361/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
362pub struct VecExecutionArgs {
363    inputs: Vec<ArrayRef>,
364    row_count: usize,
365}
366
367impl VecExecutionArgs {
368    /// Create a new `VecExecutionArgs`.
369    pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
370        Self { inputs, row_count }
371    }
372}
373
374impl ExecutionArgs for VecExecutionArgs {
375    fn get(&self, index: usize) -> VortexResult<ArrayRef> {
376        self.inputs.get(index).cloned().ok_or_else(|| {
377            vortex_err!(
378                "Input index {} out of bounds (num_inputs={})",
379                index,
380                self.inputs.len()
381            )
382        })
383    }
384
385    fn num_inputs(&self) -> usize {
386        self.inputs.len()
387    }
388
389    fn row_count(&self) -> usize {
390        self.row_count
391    }
392}
393
/// A unit options type whose [`Display`] output is the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes nothing: empty options have no textual representation.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
401
402/// Factory functions for vtables.
403pub trait ScalarFnVTableExt: ScalarFnVTable {
404    /// Bind this vtable with the given options into a [`ScalarFnRef`].
405    fn bind(&self, options: Self::Options) -> ScalarFnRef {
406        TypedScalarFnInstance::new(self.clone(), options).erased()
407    }
408
409    /// Create a new expression with this vtable and the given options and children.
410    fn new_expr(
411        &self,
412        options: Self::Options,
413        children: impl IntoIterator<Item = Expression>,
414    ) -> Expression {
415        Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
416    }
417
418    /// Try to create a new expression with this vtable and the given options and children.
419    fn try_new_expr(
420        &self,
421        options: Self::Options,
422        children: impl IntoIterator<Item = Expression>,
423    ) -> VortexResult<Expression> {
424        Expression::try_new(self.bind(options), children)
425    }
426}
427impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
428
429/// A reference to the name of a child expression.
430pub type ChildName = ArcRef<str>;