Skip to main content

vortex_array/scalar_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::scalar_fn::ScalarFn;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnRef;
28
29/// This trait defines the interface for scalar function vtables, including methods for
30/// serialization, deserialization, validation, child naming, return type computation,
31/// and evaluation.
32///
33/// This trait is non-object safe and allows the implementer to make use of associated types
34/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
35/// outputs of each function.
36///
37/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
38/// all instances of the expression. In almost all cases, this struct will be an empty unit
39/// struct, since most expressions do not require any global state.
40pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
41    /// Options for this expression.
42    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
43
44    /// Returns the ID of the scalar function vtable.
45    fn id(&self) -> ScalarFnId;
46
47    /// Serialize the options for this expression.
48    ///
49    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
50    /// serializable but has no metadata.
51    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52        _ = options;
53        Ok(None)
54    }
55
56    /// Deserialize the options of this expression.
57    fn deserialize(
58        &self,
59        _metadata: &[u8],
60        _session: &VortexSession,
61    ) -> VortexResult<Self::Options> {
62        vortex_bail!("Expression {} is not deserializable", self.id());
63    }
64
65    /// Returns the arity of this expression.
66    fn arity(&self, options: &Self::Options) -> Arity;
67
68    /// Returns the name of the nth child of the expr.
69    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
70
71    /// Format this expression in nice human-readable SQL-style format
72    ///
73    /// The implementation should recursively format child expressions by calling
74    /// `expr.child(i).fmt_sql(f)`.
75    fn fmt_sql(
76        &self,
77        options: &Self::Options,
78        expr: &Expression,
79        f: &mut Formatter<'_>,
80    ) -> fmt::Result;
81
82    /// Compute the return [`DType`] of the expression if evaluated over the given input types.
83    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType>;
84
85    /// Execute the expression over the input arguments.
86    ///
87    /// Implementations are encouraged to check their inputs for constant arrays to perform
88    /// more optimized execution.
89    ///
90    /// If the input arguments cannot be directly used for execution (for example, an expression
91    /// may require canonical input arrays), then the implementation should perform a single
92    /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
93    ///
94    /// This provides maximum opportunities for array-level optimizations using execute_parent
95    /// kernels.
96    fn execute(
97        &self,
98        options: &Self::Options,
99        args: &dyn ExecutionArgs,
100        ctx: &mut ExecutionCtx,
101    ) -> VortexResult<ArrayRef>;
102
103    /// Implement an abstract reduction rule over a tree of scalar functions.
104    ///
105    /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
106    /// construct the result expression.
107    ///
108    /// Return `Ok(None)` if no reduction is possible.
109    fn reduce(
110        &self,
111        options: &Self::Options,
112        node: &dyn ReduceNode,
113        ctx: &dyn ReduceCtx,
114    ) -> VortexResult<Option<ReduceNodeRef>> {
115        _ = options;
116        _ = node;
117        _ = ctx;
118        Ok(None)
119    }
120
121    /// Simplify the expression if possible.
122    fn simplify(
123        &self,
124        options: &Self::Options,
125        expr: &Expression,
126        ctx: &dyn SimplifyCtx,
127    ) -> VortexResult<Option<Expression>> {
128        _ = options;
129        _ = expr;
130        _ = ctx;
131        Ok(None)
132    }
133
134    /// Simplify the expression if possible, without type information.
135    fn simplify_untyped(
136        &self,
137        options: &Self::Options,
138        expr: &Expression,
139    ) -> VortexResult<Option<Expression>> {
140        _ = options;
141        _ = expr;
142        Ok(None)
143    }
144
145    /// See [`Expression::stat_falsification`].
146    fn stat_falsification(
147        &self,
148        options: &Self::Options,
149        expr: &Expression,
150        catalog: &dyn StatsCatalog,
151    ) -> Option<Expression> {
152        _ = options;
153        _ = expr;
154        _ = catalog;
155        None
156    }
157
158    /// See [`Expression::stat_expression`].
159    fn stat_expression(
160        &self,
161        options: &Self::Options,
162        expr: &Expression,
163        stat: Stat,
164        catalog: &dyn StatsCatalog,
165    ) -> Option<Expression> {
166        _ = options;
167        _ = expr;
168        _ = stat;
169        _ = catalog;
170        None
171    }
172
173    /// Returns an expression that evaluates to the validity of the result of this expression.
174    ///
175    /// If a validity expression cannot be constructed, returns `None` and the expression will
176    /// be evaluated as normal before extracting the validity mask from the result.
177    ///
178    /// This is essentially a specialized form of a `reduce_parent`
179    fn validity(
180        &self,
181        options: &Self::Options,
182        expression: &Expression,
183    ) -> VortexResult<Option<Expression>> {
184        _ = (options, expression);
185        Ok(None)
186    }
187
188    /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
189    ///
190    /// An expression is null-sensitive if it directly operates on null values,
191    /// such as `is_null`. Most expressions are not null-sensitive.
192    ///
193    /// The property we are interested in is if the expression (e) distributes over `mask`.
194    /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
195    /// array `a`.
196    ///
197    /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
198    /// `e(mask(a, m)) == mask(e(a), m)`.
199    ///
200    /// This can be extended to an n-ary expression.
201    ///
202    /// This method only checks the expression itself, not its children.
203    fn is_null_sensitive(&self, options: &Self::Options) -> bool {
204        _ = options;
205        true
206    }
207
208    /// Returns whether this expression itself is fallible. Conservatively default to *true*.
209    ///
210    /// An expression is runtime fallible is there is an input set that causes the expression to
211    /// panic or return an error, for example checked_add is fallible if there is overflow.
212    ///
213    /// Note: this is only applicable to expressions that pass type-checking
214    /// [`ScalarFnVTable::return_dtype`].
215    fn is_fallible(&self, options: &Self::Options) -> bool {
216        _ = options;
217        true
218    }
219}
220
221/// Arguments for reduction rules.
222pub trait ReduceCtx {
223    /// Create a new reduction node from the given scalar function and children.
224    fn new_node(
225        &self,
226        scalar_fn: ScalarFnRef,
227        children: &[ReduceNodeRef],
228    ) -> VortexResult<ReduceNodeRef>;
229}
230
231pub type ReduceNodeRef = Arc<dyn ReduceNode>;
232
233/// A node used for implementing abstract reduction rules.
234pub trait ReduceNode {
235    /// Downcast to Any.
236    fn as_any(&self) -> &dyn Any;
237
238    /// Return the data type of this node.
239    fn node_dtype(&self) -> VortexResult<DType>;
240
241    /// Return this node's scalar function if it is indeed a scalar fn.
242    fn scalar_fn(&self) -> Option<&ScalarFnRef>;
243
244    /// Descend to the child of this handle.
245    fn child(&self, idx: usize) -> ReduceNodeRef;
246
247    /// Returns the number of children of this node.
248    fn child_count(&self) -> usize;
249}
250
/// The arity (number of arguments) of a function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// The function takes exactly this many arguments.
    Exact(usize),
    /// The function takes at least `min` arguments and, when `max` is `Some`, at most `max`.
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders `n` for an exact arity, `min+` for an unbounded variadic arity, and `min..max`
    /// for a bounded one (collapsed to a single number when both bounds are equal).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(count) => write!(f, "{count}"),
            Arity::Variadic { min, max: None } => write!(f, "{min}+"),
            Arity::Variadic {
                min,
                max: Some(max),
            } if min == max => write!(f, "{min}"),
            Arity::Variadic {
                min,
                max: Some(max),
            } => write!(f, "{min}..{max}"),
        }
    }
}

impl Arity {
    /// Whether the given argument count matches this arity.
    ///
    /// Both bounds of a variadic arity are inclusive; a missing upper bound accepts any count
    /// that reaches the minimum.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => expected == arg_count,
            Arity::Variadic { min, max } => match max {
                // An inverted range (`min > upper`) is empty and therefore matches nothing.
                Some(upper) => (min..=upper).contains(&arg_count),
                None => arg_count >= min,
            },
        }
    }
}
290
291/// Context for simplification.
292///
293/// Used to lazily compute input data types where simplification requires them.
294pub trait SimplifyCtx {
295    /// Get the data type of the given expression.
296    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
297}
298
299/// Arguments for expression execution.
300pub trait ExecutionArgs {
301    /// Returns the input array at the given index.
302    fn get(&self, index: usize) -> VortexResult<ArrayRef>;
303
304    /// Returns the number of inputs.
305    fn num_inputs(&self) -> usize;
306
307    /// Returns the row count of the execution scope.
308    fn row_count(&self) -> usize;
309}
310
311/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
312pub struct VecExecutionArgs {
313    inputs: Vec<ArrayRef>,
314    row_count: usize,
315}
316
317impl VecExecutionArgs {
318    /// Create a new `VecExecutionArgs`.
319    pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
320        Self { inputs, row_count }
321    }
322}
323
324impl ExecutionArgs for VecExecutionArgs {
325    fn get(&self, index: usize) -> VortexResult<ArrayRef> {
326        self.inputs.get(index).cloned().ok_or_else(|| {
327            vortex_err!(
328                "Input index {} out of bounds (num_inputs={})",
329                index,
330                self.inputs.len()
331            )
332        })
333    }
334
335    fn num_inputs(&self) -> usize {
336        self.inputs.len()
337    }
338
339    fn row_count(&self) -> usize {
340        self.row_count
341    }
342}
343
/// A unit options type with no fields, which displays as the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes nothing: there are no options to show.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
351
352/// Factory functions for vtables.
353pub trait ScalarFnVTableExt: ScalarFnVTable {
354    /// Bind this vtable with the given options into a [`ScalarFnRef`].
355    fn bind(&self, options: Self::Options) -> ScalarFnRef {
356        ScalarFn::new(self.clone(), options).erased()
357    }
358
359    /// Create a new expression with this vtable and the given options and children.
360    fn new_expr(
361        &self,
362        options: Self::Options,
363        children: impl IntoIterator<Item = Expression>,
364    ) -> Expression {
365        Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
366    }
367
368    /// Try to create a new expression with this vtable and the given options and children.
369    fn try_new_expr(
370        &self,
371        options: Self::Options,
372        children: impl IntoIterator<Item = Expression>,
373    ) -> VortexResult<Expression> {
374        Expression::try_new(self.bind(options), children)
375    }
376}
377impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
378
379/// A reference to the name of a child expression.
380pub type ChildName = ArcRef<str>;