vortex_array/scalar_fn/vtable.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_session::VortexSession;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::traversal::Node;
24use crate::scalar_fn::ScalarFnId;
25use crate::scalar_fn::ScalarFnRef;
26use crate::scalar_fn::TypedScalarFnInstance;
27
28/// This trait defines the interface for scalar function vtables, including methods for
29/// serialization, deserialization, validation, child naming, return type computation,
30/// and evaluation.
31///
32/// This trait is non-object safe and allows the implementer to make use of associated types
33/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
34/// outputs of each function.
35///
36/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
37/// all instances of the expression. In almost all cases, this struct will be an empty unit
38/// struct, since most expressions do not require any global state.
39pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
40 /// Options for this expression.
41 type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
42
43 /// Returns the ID of the scalar function vtable.
44 fn id(&self) -> ScalarFnId;
45
46 /// Serialize the options for this expression.
47 ///
48 /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
49 /// serializable but has no metadata.
50 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51 _ = options;
52 Ok(None)
53 }
54
55 /// Deserialize the options of this expression.
56 fn deserialize(
57 &self,
58 _metadata: &[u8],
59 _session: &VortexSession,
60 ) -> VortexResult<Self::Options> {
61 vortex_bail!("Expression {} is not deserializable", self.id());
62 }
63
64 /// Returns the arity of this expression.
65 fn arity(&self, options: &Self::Options) -> Arity;
66
67 /// Returns the name of the nth child of the expr.
68 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
69
70 /// Format this expression in a nice human-readable SQL-style format
71 ///
72 /// The implementation should recursively format child expressions by calling
73 /// `expr.child(i).fmt_sql(f)`.
74 fn fmt_sql(
75 &self,
76 options: &Self::Options,
77 expr: &Expression,
78 f: &mut Formatter<'_>,
79 ) -> fmt::Result {
80 write!(f, "{}(", self.id())?;
81 let nchildren = expr.children_count();
82 for (i, child) in expr.children().iter().enumerate() {
83 child.fmt_sql(f)?;
84 if i + 1 < nchildren {
85 write!(f, ", ")?;
86 }
87 }
88 let opts = format!("{}", options);
89 if !opts.is_empty() {
90 write!(f, ", opts={}", opts)?;
91 }
92 write!(f, ")")
93 }
94
95 /// Coerce the arguments of this function.
96 ///
97 /// This is optionally used by Vortex users when performing type coercion over a Vortex
98 /// expression. Note that direct Vortex query engine integrations (e.g. DuckDB, DataFusion,
99 /// etc.) do not perform type coercion and rely on the engine's own logical planner.
100 ///
101 /// Note that the default implementation simply returns the arguments without coercion, and it
102 /// is expected that the [`ScalarFnVTable::return_dtype`] call may still fail.
103 fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
104 let _ = options;
105 Ok(args.to_vec())
106 }
107
108 /// Compute the return [`DType`] of the expression if evaluated over the given input types.
109 ///
110 /// # Preconditions
111 ///
112 /// The length of `args` must match the [`Arity`] of this function. Callers are responsible
113 /// for validating this (e.g., [`Expression::try_new`] checks arity at construction time).
114 /// Implementations may assume correct arity and will panic or return nonsensical results if
115 /// violated.
116 ///
117 /// [`Expression::try_new`]: crate::expr::Expression::try_new
118 fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
119
120 /// Execute the expression over the input arguments.
121 ///
122 /// Implementations are encouraged to check their inputs for constant arrays to perform
123 /// more optimized execution.
124 ///
125 /// If the input arguments cannot be directly used for execution (for example, an expression
126 /// may require canonical input arrays), then the implementation should perform a single
127 /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
128 ///
129 /// This provides maximum opportunities for array-level optimizations using execute_parent
130 /// kernels.
131 fn execute(
132 &self,
133 options: &Self::Options,
134 args: &dyn ExecutionArgs,
135 ctx: &mut ExecutionCtx,
136 ) -> VortexResult<ArrayRef>;
137
138 /// Implement an abstract reduction rule over a tree of scalar functions.
139 ///
140 /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
141 /// construct the result expression.
142 ///
143 /// Return `Ok(None)` if no reduction is possible.
144 fn reduce(
145 &self,
146 options: &Self::Options,
147 node: &dyn ReduceNode,
148 ctx: &dyn ReduceCtx,
149 ) -> VortexResult<Option<ReduceNodeRef>> {
150 _ = options;
151 _ = node;
152 _ = ctx;
153 Ok(None)
154 }
155
156 /// Simplify the expression if possible.
157 fn simplify(
158 &self,
159 options: &Self::Options,
160 expr: &Expression,
161 ctx: &dyn SimplifyCtx,
162 ) -> VortexResult<Option<Expression>> {
163 _ = options;
164 _ = expr;
165 _ = ctx;
166 Ok(None)
167 }
168
169 /// Simplify the expression if possible, without type information.
170 fn simplify_untyped(
171 &self,
172 options: &Self::Options,
173 expr: &Expression,
174 ) -> VortexResult<Option<Expression>> {
175 _ = options;
176 _ = expr;
177 Ok(None)
178 }
179
180 /// Returns an expression that evaluates to the validity of the result of this expression.
181 ///
182 /// If a validity expression cannot be constructed, returns `None` and the expression will
183 /// be evaluated as normal before extracting the validity mask from the result.
184 ///
185 /// This is essentially a specialized form of a `reduce_parent`
186 fn validity(
187 &self,
188 options: &Self::Options,
189 expression: &Expression,
190 ) -> VortexResult<Option<Expression>> {
191 _ = (options, expression);
192 Ok(None)
193 }
194
195 /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
196 ///
197 /// An expression is null-sensitive if it directly operates on null values,
198 /// such as `is_null`. Most expressions are not null-sensitive.
199 ///
200 /// The property we are interested in is if the expression (e) distributes over `mask`.
201 /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
202 /// array `a`.
203 ///
204 /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
205 /// `e(mask(a, m)) == mask(e(a), m)`.
206 ///
207 /// This can be extended to an n-ary expression.
208 ///
209 /// This method only checks the expression itself, not its children.
210 fn is_null_sensitive(&self, options: &Self::Options) -> bool {
211 _ = options;
212 true
213 }
214
215 /// Returns whether this expression is semantically fallible. Conservatively defaults to
216 /// `true`.
217 ///
218 /// An expression is semantically fallible if there exists a set of well-typed inputs that
219 /// causes the expression to produce an error as part of its _defined behavior_. For example,
220 /// `checked_add` is fallible because integer overflow is a domain error, and division is
221 /// fallible because of division by zero.
222 ///
223 /// This does **not** include execution errors that are incidental to the implementation, such
224 /// as canonicalization failures, memory allocation errors, or encoding mismatches. Those can
225 /// happen to any expression and are not what this method captures.
226 ///
227 /// This property is used by optimizations that speculatively evaluate an expression over values
228 /// that may not appear in the actual input. For example, pushing a scalar function down to a
229 /// dictionary's values array is only safe when the function is infallible or all values are
230 /// referenced, since a fallible function might error on a value left unreferenced after
231 /// slicing that would never be encountered during normal evaluation.
232 ///
233 /// Note: this is only applicable to expressions that pass type-checking via
234 /// [`ScalarFnVTable::return_dtype`].
235 fn is_fallible(&self, options: &Self::Options) -> bool {
236 _ = options;
237 true
238 }
239}
240
/// Arguments for reduction rules.
///
/// Passed to [`ScalarFnVTable::reduce`] so a reduction rule can build replacement nodes.
pub trait ReduceCtx {
    /// Create a new reduction node that applies `scalar_fn` to `children`.
    ///
    /// Returns an error if the node cannot be constructed.
    fn new_node(
        &self,
        scalar_fn: ScalarFnRef,
        children: &[ReduceNodeRef],
    ) -> VortexResult<ReduceNodeRef>;
}
250
/// A shared, reference-counted handle to a dynamically typed [`ReduceNode`].
pub type ReduceNodeRef = Arc<dyn ReduceNode>;
252
/// A node used for implementing abstract reduction rules.
///
/// [`ScalarFnVTable::reduce`] receives the node being reduced as a `&dyn ReduceNode` and uses
/// these methods to inspect it and walk its children.
pub trait ReduceNode {
    /// Downcast to Any.
    ///
    /// Lets callers recover the concrete node type behind a `dyn ReduceNode`.
    fn as_any(&self) -> &dyn Any;

    /// Return the data type of this node.
    fn node_dtype(&self) -> VortexResult<DType>;

    /// Return this node's scalar function, or `None` if this node is not a scalar function
    /// application.
    fn scalar_fn(&self) -> Option<&ScalarFnRef>;

    /// Return the child of this node at position `idx`.
    ///
    /// The signature cannot report an invalid index as an error, so callers should keep `idx`
    /// below [`ReduceNode::child_count`].
    fn child(&self, idx: usize) -> ReduceNodeRef;

    /// Returns the number of children of this node.
    fn child_count(&self) -> usize;
}
270
/// The arity (number of arguments) of a function.
///
/// Both bounds of [`Arity::Variadic`] are inclusive.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// Exactly this many arguments.
    Exact(usize),
    /// A range of argument counts.
    Variadic {
        /// Smallest accepted argument count.
        min: usize,
        /// Largest accepted argument count, or `None` for no upper bound.
        max: Option<usize>,
    },
}

impl Display for Arity {
    /// Renders `n` for a fixed count and `min+` for an open-ended range. A bounded range renders
    /// as `min..max`, collapsed to `min` when both bounds are equal.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(count) => write!(f, "{}", count),
            Arity::Variadic { min, max: None } => write!(f, "{}+", min),
            Arity::Variadic { min, max: Some(max) } if min == max => write!(f, "{}", min),
            Arity::Variadic { min, max: Some(max) } => write!(f, "{}..{}", min, max),
        }
    }
}

impl Arity {
    /// Whether the given argument count matches this arity.
    ///
    /// For [`Arity::Variadic`], `arg_count` must lie within `min..=max`; a missing `max` means
    /// there is no upper limit.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => arg_count == expected,
            Arity::Variadic { min, max } => {
                arg_count >= min && max.is_none_or(|upper| arg_count <= upper)
            }
        }
    }
}
310
/// Context for simplification.
///
/// Used to lazily compute input data types where simplification requires them. It is passed to
/// [`ScalarFnVTable::simplify`]; [`ScalarFnVTable::simplify_untyped`] receives no context.
pub trait SimplifyCtx {
    /// Get the data type of the given expression.
    ///
    /// Returns an error if the data type cannot be computed.
    fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
}
318
/// Arguments for expression execution.
///
/// Passed to [`ScalarFnVTable::execute`]. Inputs are addressed by position.
pub trait ExecutionArgs {
    /// Returns the input array at the given index.
    ///
    /// An out-of-range `index` should produce an error rather than a panic, as
    /// [`VecExecutionArgs`] does.
    fn get(&self, index: usize) -> VortexResult<ArrayRef>;

    /// Returns the number of inputs.
    fn num_inputs(&self) -> usize;

    /// Returns the row count of the execution scope.
    ///
    /// This is stored separately from the inputs, so it is available even when
    /// `num_inputs() == 0`. Presumably it is the length the result should have; confirm with
    /// callers.
    fn row_count(&self) -> usize;
}
330
331/// A concrete [`ExecutionArgs`] backed by a `Vec<ArrayRef>`.
332pub struct VecExecutionArgs {
333 inputs: Vec<ArrayRef>,
334 row_count: usize,
335}
336
337impl VecExecutionArgs {
338 /// Create a new `VecExecutionArgs`.
339 pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
340 Self { inputs, row_count }
341 }
342}
343
344impl ExecutionArgs for VecExecutionArgs {
345 fn get(&self, index: usize) -> VortexResult<ArrayRef> {
346 self.inputs.get(index).cloned().ok_or_else(|| {
347 vortex_err!(
348 "Input index {} out of bounds (num_inputs={})",
349 index,
350 self.inputs.len()
351 )
352 })
353 }
354
355 fn num_inputs(&self) -> usize {
356 self.inputs.len()
357 }
358
359 fn row_count(&self) -> usize {
360 self.row_count
361 }
362}
363
/// Options type for scalar functions that take no options.
///
/// Its `Display` rendering is the empty string, which the default
/// [`ScalarFnVTable::fmt_sql`] treats as "no options to print".
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Intentionally renders nothing.
        f.write_str("")
    }
}
371
372/// Factory functions for vtables.
373pub trait ScalarFnVTableExt: ScalarFnVTable {
374 /// Bind this vtable with the given options into a [`ScalarFnRef`].
375 fn bind(&self, options: Self::Options) -> ScalarFnRef {
376 TypedScalarFnInstance::new(self.clone(), options).erased()
377 }
378
379 /// Create a new expression with this vtable and the given options and children.
380 fn new_expr(
381 &self,
382 options: Self::Options,
383 children: impl IntoIterator<Item = Expression>,
384 ) -> Expression {
385 Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
386 }
387
388 /// Try to create a new expression with this vtable and the given options and children.
389 fn try_new_expr(
390 &self,
391 options: Self::Options,
392 children: impl IntoIterator<Item = Expression>,
393 ) -> VortexResult<Expression> {
394 Expression::try_new(self.bind(options), children)
395 }
396}
397impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
398
/// A reference to the name of a child expression.
///
/// This is the type returned by [`ScalarFnVTable::child_name`].
pub type ChildName = ArcRef<str>;