vortex_array/scalar_fn/vtable.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::scalar_fn::ScalarFn;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnRef;
28
29/// This trait defines the interface for scalar function vtables, including methods for
30/// serialization, deserialization, validation, child naming, return type computation,
31/// and evaluation.
32///
33/// This trait is non-object safe and allows the implementer to make use of associated types
34/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
35/// outputs of each function.
36///
37/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
38/// all instances of the expression. In almost all cases, this struct will be an empty unit
39/// struct, since most expressions do not require any global state.
40pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
41 /// Options for this expression.
42 type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
43
44 /// Returns the ID of the scalar function vtable.
45 fn id(&self) -> ScalarFnId;
46
47 /// Serialize the options for this expression.
48 ///
49 /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
50 /// serializable but has no metadata.
51 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52 _ = options;
53 Ok(None)
54 }
55
56 /// Deserialize the options of this expression.
57 fn deserialize(
58 &self,
59 _metadata: &[u8],
60 _session: &VortexSession,
61 ) -> VortexResult<Self::Options> {
62 vortex_bail!("Expression {} is not deserializable", self.id());
63 }
64
65 /// Returns the arity of this expression.
66 fn arity(&self, options: &Self::Options) -> Arity;
67
68 /// Returns the name of the nth child of the expr.
69 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
70
71 /// Format this expression in a nice human-readable SQL-style format
72 ///
73 /// The implementation should recursively format child expressions by calling
74 /// `expr.child(i).fmt_sql(f)`.
75 fn fmt_sql(
76 &self,
77 options: &Self::Options,
78 expr: &Expression,
79 f: &mut Formatter<'_>,
80 ) -> fmt::Result;
81
82 /// Coerce the arguments of this function.
83 ///
84 /// This is optionally used by Vortex users when performing type coercion over a Vortex
85 /// expression. Note that direct Vortex query engine integrations (e.g. DuckDB, DataFusion,
86 /// etc.) do not perform type coercion and rely on the engine's own logical planner.
87 ///
88 /// Note that the default implementation simply returns the arguments without coercion, and it
89 /// is expected that the [`ScalarFnVTable::return_dtype`] call may still fail.
90 fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
91 let _ = options;
92 Ok(args.to_vec())
93 }
94
95 /// Compute the return [`DType`] of the expression if evaluated over the given input types.
96 fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
97
98 /// Execute the expression over the input arguments.
99 ///
100 /// Implementations are encouraged to check their inputs for constant arrays to perform
101 /// more optimized execution.
102 ///
103 /// If the input arguments cannot be directly used for execution (for example, an expression
104 /// may require canonical input arrays), then the implementation should perform a single
105 /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
106 ///
107 /// This provides maximum opportunities for array-level optimizations using execute_parent
108 /// kernels.
109 fn execute(
110 &self,
111 options: &Self::Options,
112 args: &dyn ExecutionArgs,
113 ctx: &mut ExecutionCtx,
114 ) -> VortexResult<ArrayRef>;
115
116 /// Implement an abstract reduction rule over a tree of scalar functions.
117 ///
118 /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
119 /// construct the result expression.
120 ///
121 /// Return `Ok(None)` if no reduction is possible.
122 fn reduce(
123 &self,
124 options: &Self::Options,
125 node: &dyn ReduceNode,
126 ctx: &dyn ReduceCtx,
127 ) -> VortexResult<Option<ReduceNodeRef>> {
128 _ = options;
129 _ = node;
130 _ = ctx;
131 Ok(None)
132 }
133
134 /// Simplify the expression if possible.
135 fn simplify(
136 &self,
137 options: &Self::Options,
138 expr: &Expression,
139 ctx: &dyn SimplifyCtx,
140 ) -> VortexResult<Option<Expression>> {
141 _ = options;
142 _ = expr;
143 _ = ctx;
144 Ok(None)
145 }
146
147 /// Simplify the expression if possible, without type information.
148 fn simplify_untyped(
149 &self,
150 options: &Self::Options,
151 expr: &Expression,
152 ) -> VortexResult<Option<Expression>> {
153 _ = options;
154 _ = expr;
155 Ok(None)
156 }
157
158 /// See [`Expression::stat_falsification`].
159 fn stat_falsification(
160 &self,
161 options: &Self::Options,
162 expr: &Expression,
163 catalog: &dyn StatsCatalog,
164 ) -> Option<Expression> {
165 _ = options;
166 _ = expr;
167 _ = catalog;
168 None
169 }
170
171 /// See [`Expression::stat_expression`].
172 fn stat_expression(
173 &self,
174 options: &Self::Options,
175 expr: &Expression,
176 stat: Stat,
177 catalog: &dyn StatsCatalog,
178 ) -> Option<Expression> {
179 _ = options;
180 _ = expr;
181 _ = stat;
182 _ = catalog;
183 None
184 }
185
186 /// Returns an expression that evaluates to the validity of the result of this expression.
187 ///
188 /// If a validity expression cannot be constructed, returns `None` and the expression will
189 /// be evaluated as normal before extracting the validity mask from the result.
190 ///
191 /// This is essentially a specialized form of a `reduce_parent`
192 fn validity(
193 &self,
194 options: &Self::Options,
195 expression: &Expression,
196 ) -> VortexResult<Option<Expression>> {
197 _ = (options, expression);
198 Ok(None)
199 }
200
201 /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
202 ///
203 /// An expression is null-sensitive if it directly operates on null values,
204 /// such as `is_null`. Most expressions are not null-sensitive.
205 ///
206 /// The property we are interested in is if the expression (e) distributes over `mask`.
207 /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
208 /// array `a`.
209 ///
210 /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
211 /// `e(mask(a, m)) == mask(e(a), m)`.
212 ///
213 /// This can be extended to an n-ary expression.
214 ///
215 /// This method only checks the expression itself, not its children.
216 fn is_null_sensitive(&self, options: &Self::Options) -> bool {
217 _ = options;
218 true
219 }
220
221 /// Returns whether this expression itself is fallible. Conservatively default to *true*.
222 ///
223 /// An expression is runtime fallible is there is an input set that causes the expression to
224 /// panic or return an error, for example checked_add is fallible if there is overflow.
225 ///
226 /// Note: this is only applicable to expressions that pass type-checking
227 /// [`ScalarFnVTable::return_dtype`].
228 fn is_fallible(&self, options: &Self::Options) -> bool {
229 _ = options;
230 true
231 }
232}
233
234/// Arguments for reduction rules.
235pub trait ReduceCtx {
236 /// Create a new reduction node from the given scalar function and children.
237 fn new_node(
238 &self,
239 scalar_fn: ScalarFnRef,
240 children: &[ReduceNodeRef],
241 ) -> VortexResult<ReduceNodeRef>;
242}
243
244pub type ReduceNodeRef = Arc<dyn ReduceNode>;
245
246/// A node used for implementing abstract reduction rules.
247pub trait ReduceNode {
248 /// Downcast to Any.
249 fn as_any(&self) -> &dyn Any;
250
251 /// Return the data type of this node.
252 fn node_dtype(&self) -> VortexResult<DType>;
253
254 /// Return this node's scalar function if it is indeed a scalar fn.
255 fn scalar_fn(&self) -> Option<&ScalarFnRef>;
256
257 /// Descend to the child of this handle.
258 fn child(&self, idx: usize) -> ReduceNodeRef;
259
260 /// Returns the number of children of this node.
261 fn child_count(&self) -> usize;
262}
263
/// The arity (number of arguments) of a function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// The function takes exactly this many arguments.
    Exact(usize),
    /// The function takes a variable number of arguments.
    Variadic {
        /// Smallest accepted argument count (inclusive).
        min: usize,
        /// Largest accepted argument count (inclusive), or `None` for no upper limit.
        max: Option<usize>,
    },
}
270
271impl Display for Arity {
272 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
273 match self {
274 Arity::Exact(n) => write!(f, "{}", n),
275 Arity::Variadic { min, max } => match max {
276 Some(max) if min == max => write!(f, "{}", min),
277 Some(max) => write!(f, "{}..{}", min, max),
278 None => write!(f, "{}+", min),
279 },
280 }
281 }
282}
283
284impl Arity {
285 /// Whether the given argument count matches this arity.
286 pub fn matches(&self, arg_count: usize) -> bool {
287 match self {
288 Arity::Exact(m) => *m == arg_count,
289 Arity::Variadic { min, max } => {
290 if arg_count < *min {
291 return false;
292 }
293 if let Some(max) = max
294 && arg_count > *max
295 {
296 return false;
297 }
298 true
299 }
300 }
301 }
302}
303
304/// Context for simplification.
305///
306/// Used to lazily compute input data types where simplification requires them.
307pub trait SimplifyCtx {
308 /// Get the data type of the given expression.
309 fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
310}
311
312/// Arguments for expression execution.
313pub trait ExecutionArgs {
314 /// Returns the input array at the given index.
315 fn get(&self, index: usize) -> VortexResult<ArrayRef>;
316
317 /// Returns the number of inputs.
318 fn num_inputs(&self) -> usize;
319
320 /// Returns the row count of the execution scope.
321 fn row_count(&self) -> usize;
322}
323
324/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
325pub struct VecExecutionArgs {
326 inputs: Vec<ArrayRef>,
327 row_count: usize,
328}
329
330impl VecExecutionArgs {
331 /// Create a new `VecExecutionArgs`.
332 pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
333 Self { inputs, row_count }
334 }
335}
336
337impl ExecutionArgs for VecExecutionArgs {
338 fn get(&self, index: usize) -> VortexResult<ArrayRef> {
339 self.inputs.get(index).cloned().ok_or_else(|| {
340 vortex_err!(
341 "Input index {} out of bounds (num_inputs={})",
342 index,
343 self.inputs.len()
344 )
345 })
346 }
347
348 fn num_inputs(&self) -> usize {
349 self.inputs.len()
350 }
351
352 fn row_count(&self) -> usize {
353 self.row_count
354 }
355}
356
/// A unit options type that carries no data.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;
359impl Display for EmptyOptions {
360 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
361 write!(f, "")
362 }
363}
364
365/// Factory functions for vtables.
366pub trait ScalarFnVTableExt: ScalarFnVTable {
367 /// Bind this vtable with the given options into a [`ScalarFnRef`].
368 fn bind(&self, options: Self::Options) -> ScalarFnRef {
369 ScalarFn::new(self.clone(), options).erased()
370 }
371
372 /// Create a new expression with this vtable and the given options and children.
373 fn new_expr(
374 &self,
375 options: Self::Options,
376 children: impl IntoIterator<Item = Expression>,
377 ) -> Expression {
378 Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
379 }
380
381 /// Try to create a new expression with this vtable and the given options and children.
382 fn try_new_expr(
383 &self,
384 options: Self::Options,
385 children: impl IntoIterator<Item = Expression>,
386 ) -> VortexResult<Expression> {
387 Expression::try_new(self.bind(options), children)
388 }
389}
390impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
391
392/// A reference to the name of a child expression.
393pub type ChildName = ArcRef<str>;