vortex_array/scalar_fn/vtable.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::expr::traversal::Node;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnRef;
28use crate::scalar_fn::TypedScalarFnInstance;
29
30/// This trait defines the interface for scalar function vtables, including methods for
31/// serialization, deserialization, validation, child naming, return type computation,
32/// and evaluation.
33///
34/// This trait is non-object safe and allows the implementer to make use of associated types
35/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
36/// outputs of each function.
37///
38/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
39/// all instances of the expression. In almost all cases, this struct will be an empty unit
40/// struct, since most expressions do not require any global state.
41pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
42 /// Options for this expression.
43 type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
44
45 /// Returns the ID of the scalar function vtable.
46 fn id(&self) -> ScalarFnId;
47
48 /// Serialize the options for this expression.
49 ///
50 /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
51 /// serializable but has no metadata.
52 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53 _ = options;
54 Ok(None)
55 }
56
57 /// Deserialize the options of this expression.
58 fn deserialize(
59 &self,
60 _metadata: &[u8],
61 _session: &VortexSession,
62 ) -> VortexResult<Self::Options> {
63 vortex_bail!("Expression {} is not deserializable", self.id());
64 }
65
66 /// Returns the arity of this expression.
67 fn arity(&self, options: &Self::Options) -> Arity;
68
69 /// Returns the name of the nth child of the expr.
70 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
71
72 /// Format this expression in a nice human-readable SQL-style format
73 ///
74 /// The implementation should recursively format child expressions by calling
75 /// `expr.child(i).fmt_sql(f)`.
76 fn fmt_sql(
77 &self,
78 options: &Self::Options,
79 expr: &Expression,
80 f: &mut Formatter<'_>,
81 ) -> fmt::Result {
82 write!(f, "{}(", self.id())?;
83 let nchildren = expr.children_count();
84 for (i, child) in expr.children().iter().enumerate() {
85 child.fmt_sql(f)?;
86 if i + 1 < nchildren {
87 write!(f, ", ")?;
88 }
89 }
90 let opts = format!("{}", options);
91 if !opts.is_empty() {
92 write!(f, ", opts={}", opts)?;
93 }
94 write!(f, ")")
95 }
96
97 /// Coerce the arguments of this function.
98 ///
99 /// This is optionally used by Vortex users when performing type coercion over a Vortex
100 /// expression. Note that direct Vortex query engine integrations (e.g. DuckDB, DataFusion,
101 /// etc.) do not perform type coercion and rely on the engine's own logical planner.
102 ///
103 /// Note that the default implementation simply returns the arguments without coercion, and it
104 /// is expected that the [`ScalarFnVTable::return_dtype`] call may still fail.
105 fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
106 let _ = options;
107 Ok(args.to_vec())
108 }
109
110 /// Compute the return [`DType`] of the expression if evaluated over the given input types.
111 ///
112 /// # Preconditions
113 ///
114 /// The length of `args` must match the [`Arity`] of this function. Callers are responsible
115 /// for validating this (e.g., [`Expression::try_new`] checks arity at construction time).
116 /// Implementations may assume correct arity and will panic or return nonsensical results if
117 /// violated.
118 ///
119 /// [`Expression::try_new`]: crate::expr::Expression::try_new
120 fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
121
122 /// Execute the expression over the input arguments.
123 ///
124 /// Implementations are encouraged to check their inputs for constant arrays to perform
125 /// more optimized execution.
126 ///
127 /// If the input arguments cannot be directly used for execution (for example, an expression
128 /// may require canonical input arrays), then the implementation should perform a single
129 /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
130 ///
131 /// This provides maximum opportunities for array-level optimizations using execute_parent
132 /// kernels.
133 fn execute(
134 &self,
135 options: &Self::Options,
136 args: &dyn ExecutionArgs,
137 ctx: &mut ExecutionCtx,
138 ) -> VortexResult<ArrayRef>;
139
140 /// Implement an abstract reduction rule over a tree of scalar functions.
141 ///
142 /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
143 /// construct the result expression.
144 ///
145 /// Return `Ok(None)` if no reduction is possible.
146 fn reduce(
147 &self,
148 options: &Self::Options,
149 node: &dyn ReduceNode,
150 ctx: &dyn ReduceCtx,
151 ) -> VortexResult<Option<ReduceNodeRef>> {
152 _ = options;
153 _ = node;
154 _ = ctx;
155 Ok(None)
156 }
157
158 /// Simplify the expression if possible.
159 fn simplify(
160 &self,
161 options: &Self::Options,
162 expr: &Expression,
163 ctx: &dyn SimplifyCtx,
164 ) -> VortexResult<Option<Expression>> {
165 _ = options;
166 _ = expr;
167 _ = ctx;
168 Ok(None)
169 }
170
171 /// Simplify the expression if possible, without type information.
172 fn simplify_untyped(
173 &self,
174 options: &Self::Options,
175 expr: &Expression,
176 ) -> VortexResult<Option<Expression>> {
177 _ = options;
178 _ = expr;
179 Ok(None)
180 }
181
182 /// See [`Expression::stat_falsification`].
183 fn stat_falsification(
184 &self,
185 options: &Self::Options,
186 expr: &Expression,
187 catalog: &dyn StatsCatalog,
188 ) -> Option<Expression> {
189 _ = options;
190 _ = expr;
191 _ = catalog;
192 None
193 }
194
195 /// See [`Expression::stat_expression`].
196 fn stat_expression(
197 &self,
198 options: &Self::Options,
199 expr: &Expression,
200 stat: Stat,
201 catalog: &dyn StatsCatalog,
202 ) -> Option<Expression> {
203 _ = options;
204 _ = expr;
205 _ = stat;
206 _ = catalog;
207 None
208 }
209
210 /// Returns an expression that evaluates to the validity of the result of this expression.
211 ///
212 /// If a validity expression cannot be constructed, returns `None` and the expression will
213 /// be evaluated as normal before extracting the validity mask from the result.
214 ///
215 /// This is essentially a specialized form of a `reduce_parent`
216 fn validity(
217 &self,
218 options: &Self::Options,
219 expression: &Expression,
220 ) -> VortexResult<Option<Expression>> {
221 _ = (options, expression);
222 Ok(None)
223 }
224
225 /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
226 ///
227 /// An expression is null-sensitive if it directly operates on null values,
228 /// such as `is_null`. Most expressions are not null-sensitive.
229 ///
230 /// The property we are interested in is if the expression (e) distributes over `mask`.
231 /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
232 /// array `a`.
233 ///
234 /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
235 /// `e(mask(a, m)) == mask(e(a), m)`.
236 ///
237 /// This can be extended to an n-ary expression.
238 ///
239 /// This method only checks the expression itself, not its children.
240 fn is_null_sensitive(&self, options: &Self::Options) -> bool {
241 _ = options;
242 true
243 }
244
245 /// Returns whether this expression is semantically fallible. Conservatively defaults to
246 /// `true`.
247 ///
248 /// An expression is semantically fallible if there exists a set of well-typed inputs that
249 /// causes the expression to produce an error as part of its _defined behavior_. For example,
250 /// `checked_add` is fallible because integer overflow is a domain error, and division is
251 /// fallible because of division by zero.
252 ///
253 /// This does **not** include execution errors that are incidental to the implementation, such
254 /// as canonicalization failures, memory allocation errors, or encoding mismatches. Those can
255 /// happen to any expression and are not what this method captures.
256 ///
257 /// This property is used by optimizations that speculatively evaluate an expression over values
258 /// that may not appear in the actual input. For example, pushing a scalar function down to a
259 /// dictionary's values array is only safe when the function is infallible or all values are
260 /// referenced, since a fallible function might error on a value left unreferenced after
261 /// slicing that would never be encountered during normal evaluation.
262 ///
263 /// Note: this is only applicable to expressions that pass type-checking via
264 /// [`ScalarFnVTable::return_dtype`].
265 fn is_fallible(&self, options: &Self::Options) -> bool {
266 _ = options;
267 true
268 }
269}
270
271/// Arguments for reduction rules.
272pub trait ReduceCtx {
273 /// Create a new reduction node from the given scalar function and children.
274 fn new_node(
275 &self,
276 scalar_fn: ScalarFnRef,
277 children: &[ReduceNodeRef],
278 ) -> VortexResult<ReduceNodeRef>;
279}
280
281pub type ReduceNodeRef = Arc<dyn ReduceNode>;
282
283/// A node used for implementing abstract reduction rules.
284pub trait ReduceNode {
285 /// Downcast to Any.
286 fn as_any(&self) -> &dyn Any;
287
288 /// Return the data type of this node.
289 fn node_dtype(&self) -> VortexResult<DType>;
290
291 /// Return this node's scalar function if it is indeed a scalar fn.
292 fn scalar_fn(&self) -> Option<&ScalarFnRef>;
293
294 /// Descend to the child of this handle.
295 fn child(&self, idx: usize) -> ReduceNodeRef;
296
297 /// Returns the number of children of this node.
298 fn child_count(&self) -> usize;
299}
300
/// How many arguments a function accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// The function takes exactly this many arguments.
    Exact(usize),
    /// The function takes at least `min` arguments and, when `max` is `Some`, at most `max`
    /// arguments (inclusive). A `max` of `None` means there is no upper bound.
    Variadic { min: usize, max: Option<usize> },
}
307
308impl Display for Arity {
309 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
310 match self {
311 Arity::Exact(n) => write!(f, "{}", n),
312 Arity::Variadic { min, max } => match max {
313 Some(max) if min == max => write!(f, "{}", min),
314 Some(max) => write!(f, "{}..{}", min, max),
315 None => write!(f, "{}+", min),
316 },
317 }
318 }
319}
320
321impl Arity {
322 /// Whether the given argument count matches this arity.
323 pub fn matches(&self, arg_count: usize) -> bool {
324 match self {
325 Arity::Exact(m) => *m == arg_count,
326 Arity::Variadic { min, max } => {
327 if arg_count < *min {
328 return false;
329 }
330 if let Some(max) = max
331 && arg_count > *max
332 {
333 return false;
334 }
335 true
336 }
337 }
338 }
339}
340
341/// Context for simplification.
342///
343/// Used to lazily compute input data types where simplification requires them.
344pub trait SimplifyCtx {
345 /// Get the data type of the given expression.
346 fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
347}
348
349/// Arguments for expression execution.
350pub trait ExecutionArgs {
351 /// Returns the input array at the given index.
352 fn get(&self, index: usize) -> VortexResult<ArrayRef>;
353
354 /// Returns the number of inputs.
355 fn num_inputs(&self) -> usize;
356
357 /// Returns the row count of the execution scope.
358 fn row_count(&self) -> usize;
359}
360
361/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
362pub struct VecExecutionArgs {
363 inputs: Vec<ArrayRef>,
364 row_count: usize,
365}
366
367impl VecExecutionArgs {
368 /// Create a new `VecExecutionArgs`.
369 pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
370 Self { inputs, row_count }
371 }
372}
373
374impl ExecutionArgs for VecExecutionArgs {
375 fn get(&self, index: usize) -> VortexResult<ArrayRef> {
376 self.inputs.get(index).cloned().ok_or_else(|| {
377 vortex_err!(
378 "Input index {} out of bounds (num_inputs={})",
379 index,
380 self.inputs.len()
381 )
382 })
383 }
384
385 fn num_inputs(&self) -> usize {
386 self.inputs.len()
387 }
388
389 fn row_count(&self) -> usize {
390 self.row_count
391 }
392}
393
/// A unit options type carrying no data.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;
396impl Display for EmptyOptions {
397 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
398 write!(f, "")
399 }
400}
401
402/// Factory functions for vtables.
403pub trait ScalarFnVTableExt: ScalarFnVTable {
404 /// Bind this vtable with the given options into a [`ScalarFnRef`].
405 fn bind(&self, options: Self::Options) -> ScalarFnRef {
406 TypedScalarFnInstance::new(self.clone(), options).erased()
407 }
408
409 /// Create a new expression with this vtable and the given options and children.
410 fn new_expr(
411 &self,
412 options: Self::Options,
413 children: impl IntoIterator<Item = Expression>,
414 ) -> Expression {
415 Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
416 }
417
418 /// Try to create a new expression with this vtable and the given options and children.
419 fn try_new_expr(
420 &self,
421 options: Self::Options,
422 children: impl IntoIterator<Item = Expression>,
423 ) -> VortexResult<Expression> {
424 Expression::try_new(self.bind(options), children)
425 }
426}
427impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
428
429/// A reference to the name of a child expression.
430pub type ChildName = ArcRef<str>;