Skip to main content

vortex_array/scalar_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::expr::traversal::Node;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnRef;
28use crate::scalar_fn::TypedScalarFnInstance;
29
30/// This trait defines the interface for scalar function vtables, including methods for
31/// serialization, deserialization, validation, child naming, return type computation,
32/// and evaluation.
33///
34/// This trait is non-object safe and allows the implementer to make use of associated types
35/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
36/// outputs of each function.
37///
38/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
39/// all instances of the expression. In almost all cases, this struct will be an empty unit
40/// struct, since most expressions do not require any global state.
41pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
42    /// Options for this expression.
43    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
44
45    /// Returns the ID of the scalar function vtable.
46    fn id(&self) -> ScalarFnId;
47
48    /// Serialize the options for this expression.
49    ///
50    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
51    /// serializable but has no metadata.
52    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53        _ = options;
54        Ok(None)
55    }
56
57    /// Deserialize the options of this expression.
58    fn deserialize(
59        &self,
60        _metadata: &[u8],
61        _session: &VortexSession,
62    ) -> VortexResult<Self::Options> {
63        vortex_bail!("Expression {} is not deserializable", self.id());
64    }
65
66    /// Returns the arity of this expression.
67    fn arity(&self, options: &Self::Options) -> Arity;
68
69    /// Returns the name of the nth child of the expr.
70    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
71
72    /// Format this expression in a nice human-readable SQL-style format
73    ///
74    /// The implementation should recursively format child expressions by calling
75    /// `expr.child(i).fmt_sql(f)`.
76    fn fmt_sql(
77        &self,
78        options: &Self::Options,
79        expr: &Expression,
80        f: &mut Formatter<'_>,
81    ) -> fmt::Result {
82        write!(f, "{}(", self.id())?;
83        let nchildren = expr.children_count();
84        for (i, child) in expr.children().iter().enumerate() {
85            child.fmt_sql(f)?;
86            if i + 1 < nchildren {
87                write!(f, ", ")?;
88            }
89        }
90        let opts = format!("{}", options);
91        if !opts.is_empty() {
92            write!(f, ", opts={}", opts)?;
93        }
94        write!(f, ")")
95    }
96
97    /// Compute the return [`DType`] of the expression if evaluated over the given input types.
98    ///
99    /// # Preconditions
100    ///
101    /// The length of `args` must match the [`Arity`] of this function. Callers are responsible
102    /// for validating this (e.g., [`Expression::try_new`] checks arity at construction time).
103    /// Implementations may assume correct arity and will panic or return nonsensical results if
104    /// violated.
105    ///
106    /// [`Expression::try_new`]: crate::expr::Expression::try_new
107    fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
108
109    /// Execute the expression over the input arguments.
110    ///
111    /// Implementations are encouraged to check their inputs for constant arrays to perform
112    /// more optimized execution.
113    ///
114    /// If the input arguments cannot be directly used for execution (for example, an expression
115    /// may require canonical input arrays), then the implementation should perform a single
116    /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
117    ///
118    /// This provides maximum opportunities for array-level optimizations using execute_parent
119    /// kernels.
120    fn execute(
121        &self,
122        options: &Self::Options,
123        args: &dyn ExecutionArgs,
124        ctx: &mut ExecutionCtx,
125    ) -> VortexResult<ArrayRef>;
126
127    /// Implement an abstract reduction rule over a tree of scalar functions.
128    ///
129    /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
130    /// construct the result expression.
131    ///
132    /// Return `Ok(None)` if no reduction is possible.
133    fn reduce(
134        &self,
135        options: &Self::Options,
136        node: &dyn ReduceNode,
137        ctx: &dyn ReduceCtx,
138    ) -> VortexResult<Option<ReduceNodeRef>> {
139        _ = options;
140        _ = node;
141        _ = ctx;
142        Ok(None)
143    }
144
145    /// Simplify the expression if possible.
146    fn simplify(
147        &self,
148        options: &Self::Options,
149        expr: &Expression,
150        ctx: &dyn SimplifyCtx,
151    ) -> VortexResult<Option<Expression>> {
152        _ = options;
153        _ = expr;
154        _ = ctx;
155        Ok(None)
156    }
157
158    /// Simplify the expression if possible, without type information.
159    fn simplify_untyped(
160        &self,
161        options: &Self::Options,
162        expr: &Expression,
163    ) -> VortexResult<Option<Expression>> {
164        _ = options;
165        _ = expr;
166        Ok(None)
167    }
168
169    /// See [`Expression::stat_falsification`].
170    fn stat_falsification(
171        &self,
172        options: &Self::Options,
173        expr: &Expression,
174        catalog: &dyn StatsCatalog,
175    ) -> Option<Expression> {
176        _ = options;
177        _ = expr;
178        _ = catalog;
179        None
180    }
181
182    /// See [`Expression::stat_expression`].
183    fn stat_expression(
184        &self,
185        options: &Self::Options,
186        expr: &Expression,
187        stat: Stat,
188        catalog: &dyn StatsCatalog,
189    ) -> Option<Expression> {
190        _ = options;
191        _ = expr;
192        _ = stat;
193        _ = catalog;
194        None
195    }
196
197    /// Returns an expression that evaluates to the validity of the result of this expression.
198    ///
199    /// If a validity expression cannot be constructed, returns `None` and the expression will
200    /// be evaluated as normal before extracting the validity mask from the result.
201    ///
202    /// This is essentially a specialized form of a `reduce_parent`
203    fn validity(
204        &self,
205        options: &Self::Options,
206        expression: &Expression,
207    ) -> VortexResult<Option<Expression>> {
208        _ = (options, expression);
209        Ok(None)
210    }
211
212    /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
213    ///
214    /// An expression is null-sensitive if it directly operates on null values,
215    /// such as `is_null`. Most expressions are not null-sensitive.
216    ///
217    /// The property we are interested in is if the expression (e) distributes over `mask`.
218    /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
219    /// array `a`.
220    ///
221    /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
222    /// `e(mask(a, m)) == mask(e(a), m)`.
223    ///
224    /// This can be extended to an n-ary expression.
225    ///
226    /// This method only checks the expression itself, not its children.
227    fn is_null_sensitive(&self, options: &Self::Options) -> bool {
228        _ = options;
229        true
230    }
231
232    /// Returns whether this expression is semantically fallible. Conservatively defaults to
233    /// `true`.
234    ///
235    /// An expression is semantically fallible if there exists a set of well-typed inputs that
236    /// causes the expression to produce an error as part of its _defined behavior_. For example,
237    /// `checked_add` is fallible because integer overflow is a domain error, and division is
238    /// fallible because of division by zero.
239    ///
240    /// This does **not** include execution errors that are incidental to the implementation, such
241    /// as canonicalization failures, memory allocation errors, or encoding mismatches. Those can
242    /// happen to any expression and are not what this method captures.
243    ///
244    /// This property is used by optimizations that speculatively evaluate an expression over values
245    /// that may not appear in the actual input. For example, pushing a scalar function down to a
246    /// dictionary's values array is only safe when the function is infallible or all values are
247    /// referenced, since a fallible function might error on a value left unreferenced after
248    /// slicing that would never be encountered during normal evaluation.
249    ///
250    /// Note: this is only applicable to expressions that pass type-checking via
251    /// [`ScalarFnVTable::return_dtype`].
252    fn is_fallible(&self, options: &Self::Options) -> bool {
253        _ = options;
254        true
255    }
256}
257
/// Arguments for reduction rules.
///
/// A `&dyn ReduceCtx` is handed to [`ScalarFnVTable::reduce`] so that a rule can construct the
/// node it wants to return.
pub trait ReduceCtx {
    /// Create a new reduction node from the given scalar function and children.
    ///
    /// The call is fallible; the conditions under which it fails are defined by the
    /// implementation.
    fn new_node(
        &self,
        scalar_fn: ScalarFnRef,
        children: &[ReduceNodeRef],
    ) -> VortexResult<ReduceNodeRef>;
}
267
/// A shared, reference-counted handle to a type-erased [`ReduceNode`].
pub type ReduceNodeRef = Arc<dyn ReduceNode>;
269
/// A node used for implementing abstract reduction rules.
///
/// Reduction rules receive nodes as `&dyn ReduceNode` and hand them around as
/// [`ReduceNodeRef`]s.
pub trait ReduceNode {
    /// Downcast to Any.
    fn as_any(&self) -> &dyn Any;

    /// Return the data type of this node.
    fn node_dtype(&self) -> VortexResult<DType>;

    /// Return this node's scalar function if it is indeed a scalar fn, and `None` otherwise.
    fn scalar_fn(&self) -> Option<&ScalarFnRef>;

    /// Descend to the child of this handle at position `idx`.
    ///
    /// NOTE(review): the behavior for `idx >= child_count()` is not specified here; presumably
    /// implementations panic, since the return type is infallible — confirm with implementors.
    fn child(&self, idx: usize) -> ReduceNodeRef;

    /// Returns the number of children of this node.
    fn child_count(&self) -> usize;
}
287
/// The arity (number of arguments) of a function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// The function takes exactly this many arguments.
    Exact(usize),
    /// The function takes at least `min` arguments and, when `max` is set, at most `max`
    /// (inclusive).
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders the arity as `n` for an exact count (or a variadic range whose bounds coincide),
    /// `min..max` for a bounded range and `min+` for an unbounded one.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(count) => write!(f, "{}", count),
            Arity::Variadic { min, max: None } => write!(f, "{}+", min),
            Arity::Variadic {
                min,
                max: Some(max),
            } if min == max => write!(f, "{}", min),
            Arity::Variadic {
                min,
                max: Some(max),
            } => write!(f, "{}..{}", min, max),
        }
    }
}

impl Arity {
    /// Whether the given argument count matches this arity.
    ///
    /// Both bounds of a variadic arity are inclusive; a missing upper bound accepts any count
    /// of at least `min`.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => arg_count == expected,
            Arity::Variadic {
                min,
                max: Some(upper),
            } => (min..=upper).contains(&arg_count),
            Arity::Variadic { min, max: None } => arg_count >= min,
        }
    }
}
327
/// Context for simplification.
///
/// Used to lazily compute input data types where simplification requires them. A
/// `&dyn SimplifyCtx` is passed to [`ScalarFnVTable::simplify`].
pub trait SimplifyCtx {
    /// Get the data type of the given expression.
    ///
    /// The lookup is fallible; failure conditions are defined by the implementation.
    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
}
335
/// Arguments for expression execution.
///
/// A `&dyn ExecutionArgs` is passed to [`ScalarFnVTable::execute`].
pub trait ExecutionArgs {
    /// Returns the input array at the given index.
    ///
    /// The call is fallible, so implementations can report an invalid `index` as an error
    /// rather than panicking.
    fn get(&self, index: usize) -> VortexResult<ArrayRef>;

    /// Returns the number of inputs.
    fn num_inputs(&self) -> usize;

    /// Returns the row count of the execution scope.
    fn row_count(&self) -> usize;
}
347
/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
pub struct VecExecutionArgs {
    /// Input arrays, addressed by their position in the vector.
    inputs: Vec<ArrayRef>,
    /// Row count reported for the execution scope; stored exactly as supplied to `new`.
    row_count: usize,
}
353
impl VecExecutionArgs {
    /// Create a new `VecExecutionArgs`.
    ///
    /// `inputs` are the argument arrays in positional order and `row_count` is the value later
    /// reported by `row_count()`. Neither value is validated here; in particular `row_count` is
    /// not checked against the inputs.
    pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
        Self { inputs, row_count }
    }
}
360
361impl ExecutionArgs for VecExecutionArgs {
362    fn get(&self, index: usize) -> VortexResult<ArrayRef> {
363        self.inputs.get(index).cloned().ok_or_else(|| {
364            vortex_err!(
365                "Input index {} out of bounds (num_inputs={})",
366                index,
367                self.inputs.len()
368            )
369        })
370    }
371
372    fn num_inputs(&self) -> usize {
373        self.inputs.len()
374    }
375
376    fn row_count(&self) -> usize {
377        self.row_count
378    }
379}
380
/// Options type for scalar functions that take no options.
///
/// Its [`Display`] output is the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes nothing: there are no options to show.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
388
389/// Factory functions for vtables.
390pub trait ScalarFnVTableExt: ScalarFnVTable {
391    /// Bind this vtable with the given options into a [`ScalarFnRef`].
392    fn bind(&self, options: Self::Options) -> ScalarFnRef {
393        TypedScalarFnInstance::new(self.clone(), options).erased()
394    }
395
396    /// Create a new expression with this vtable and the given options and children.
397    fn new_expr(
398        &self,
399        options: Self::Options,
400        children: impl IntoIterator<Item = Expression>,
401    ) -> Expression {
402        Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
403    }
404
405    /// Try to create a new expression with this vtable and the given options and children.
406    fn try_new_expr(
407        &self,
408        options: Self::Options,
409        children: impl IntoIterator<Item = Expression>,
410    ) -> VortexResult<Expression> {
411        Expression::try_new(self.bind(options), children)
412    }
413}
414impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
415
/// A reference to the name of a child expression.
///
/// This is the type returned by [`ScalarFnVTable::child_name`].
pub type ChildName = ArcRef<str>;