vortex_array/scalar_fn/vtable.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::sync::Arc;
11
12use arcref::ArcRef;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::dtype::DType;
21use crate::expr::Expression;
22use crate::expr::StatsCatalog;
23use crate::expr::stats::Stat;
24use crate::scalar_fn::ScalarFn;
25use crate::scalar_fn::ScalarFnId;
26use crate::scalar_fn::ScalarFnRef;
27
28/// This trait defines the interface for scalar function vtables, including methods for
29/// serialization, deserialization, validation, child naming, return type computation,
30/// and evaluation.
31///
32/// This trait is non-object safe and allows the implementer to make use of associated types
33/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
34/// outputs of each function.
35///
36/// The [`ScalarFnVTable`] trait should be implemented for a struct that holds global data across
37/// all instances of the expression. In almost all cases, this struct will be an empty unit
38/// struct, since most expressions do not require any global state.
39pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
40 /// Options for this expression.
41 type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
42
43 /// Returns the ID of the scalar function vtable.
44 fn id(&self) -> ScalarFnId;
45
46 /// Serialize the options for this expression.
47 ///
48 /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
49 /// serializable but has no metadata.
50 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51 _ = options;
52 Ok(None)
53 }
54
55 /// Deserialize the options of this expression.
56 fn deserialize(
57 &self,
58 _metadata: &[u8],
59 _session: &VortexSession,
60 ) -> VortexResult<Self::Options> {
61 vortex_bail!("Expression {} is not deserializable", self.id());
62 }
63
64 /// Returns the arity of this expression.
65 fn arity(&self, options: &Self::Options) -> Arity;
66
67 /// Returns the name of the nth child of the expr.
68 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
69
70 /// Format this expression in nice human-readable SQL-style format
71 ///
72 /// The implementation should recursively format child expressions by calling
73 /// `expr.child(i).fmt_sql(f)`.
74 fn fmt_sql(
75 &self,
76 options: &Self::Options,
77 expr: &Expression,
78 f: &mut Formatter<'_>,
79 ) -> fmt::Result;
80
81 /// Compute the return [`DType`] of the expression if evaluated over the given input types.
82 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType>;
83
84 /// Execute the expression over the input arguments.
85 ///
86 /// Implementations are encouraged to check their inputs for constant arrays to perform
87 /// more optimized execution.
88 ///
89 /// If the input arguments cannot be directly used for execution (for example, an expression
90 /// may require canonical input arrays), then the implementation should perform a single
91 /// child execution and return a new [`crate::arrays::ScalarFnArray`] wrapping up the new child.
92 ///
93 /// This provides maximum opportunities for array-level optimizations using execute_parent
94 /// kernels.
95 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef>;
96
97 /// Implement an abstract reduction rule over a tree of scalar functions.
98 ///
99 /// The [`ReduceNode`] can be used to traverse children, inspect their types, and
100 /// construct the result expression.
101 ///
102 /// Return `Ok(None)` if no reduction is possible.
103 fn reduce(
104 &self,
105 options: &Self::Options,
106 node: &dyn ReduceNode,
107 ctx: &dyn ReduceCtx,
108 ) -> VortexResult<Option<ReduceNodeRef>> {
109 _ = options;
110 _ = node;
111 _ = ctx;
112 Ok(None)
113 }
114
115 /// Simplify the expression if possible.
116 fn simplify(
117 &self,
118 options: &Self::Options,
119 expr: &Expression,
120 ctx: &dyn SimplifyCtx,
121 ) -> VortexResult<Option<Expression>> {
122 _ = options;
123 _ = expr;
124 _ = ctx;
125 Ok(None)
126 }
127
128 /// Simplify the expression if possible, without type information.
129 fn simplify_untyped(
130 &self,
131 options: &Self::Options,
132 expr: &Expression,
133 ) -> VortexResult<Option<Expression>> {
134 _ = options;
135 _ = expr;
136 Ok(None)
137 }
138
139 /// See [`Expression::stat_falsification`].
140 fn stat_falsification(
141 &self,
142 options: &Self::Options,
143 expr: &Expression,
144 catalog: &dyn StatsCatalog,
145 ) -> Option<Expression> {
146 _ = options;
147 _ = expr;
148 _ = catalog;
149 None
150 }
151
152 /// See [`Expression::stat_expression`].
153 fn stat_expression(
154 &self,
155 options: &Self::Options,
156 expr: &Expression,
157 stat: Stat,
158 catalog: &dyn StatsCatalog,
159 ) -> Option<Expression> {
160 _ = options;
161 _ = expr;
162 _ = stat;
163 _ = catalog;
164 None
165 }
166
167 /// Returns an expression that evaluates to the validity of the result of this expression.
168 ///
169 /// If a validity expression cannot be constructed, returns `None` and the expression will
170 /// be evaluated as normal before extracting the validity mask from the result.
171 ///
172 /// This is essentially a specialized form of a `reduce_parent`
173 fn validity(
174 &self,
175 options: &Self::Options,
176 expression: &Expression,
177 ) -> VortexResult<Option<Expression>> {
178 _ = (options, expression);
179 Ok(None)
180 }
181
182 /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
183 ///
184 /// An expression is null-sensitive if it directly operates on null values,
185 /// such as `is_null`. Most expressions are not null-sensitive.
186 ///
187 /// The property we are interested in is if the expression (e) distributes over `mask`.
188 /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
189 /// array `a`.
190 ///
191 /// A unary expression `e` is not null-sensitive iff forall arrays `a` and masks `m`,
192 /// `e(mask(a, m)) == mask(e(a), m)`.
193 ///
194 /// This can be extended to an n-ary expression.
195 ///
196 /// This method only checks the expression itself, not its children.
197 fn is_null_sensitive(&self, options: &Self::Options) -> bool {
198 _ = options;
199 true
200 }
201
202 /// Returns whether this expression itself is fallible. Conservatively default to *true*.
203 ///
204 /// An expression is runtime fallible is there is an input set that causes the expression to
205 /// panic or return an error, for example checked_add is fallible if there is overflow.
206 ///
207 /// Note: this is only applicable to expressions that pass type-checking
208 /// [`ScalarFnVTable::return_dtype`].
209 fn is_fallible(&self, options: &Self::Options) -> bool {
210 _ = options;
211 true
212 }
213}
214
215/// Arguments for reduction rules.
216pub trait ReduceCtx {
217 /// Create a new reduction node from the given scalar function and children.
218 fn new_node(
219 &self,
220 scalar_fn: ScalarFnRef,
221 children: &[ReduceNodeRef],
222 ) -> VortexResult<ReduceNodeRef>;
223}
224
225pub type ReduceNodeRef = Arc<dyn ReduceNode>;
226
227/// A node used for implementing abstract reduction rules.
228pub trait ReduceNode {
229 /// Downcast to Any.
230 fn as_any(&self) -> &dyn Any;
231
232 /// Return the data type of this node.
233 fn node_dtype(&self) -> VortexResult<DType>;
234
235 /// Return this node's scalar function if it is indeed a scalar fn.
236 fn scalar_fn(&self) -> Option<&ScalarFnRef>;
237
238 /// Descend to the child of this handle.
239 fn child(&self, idx: usize) -> ReduceNodeRef;
240
241 /// Returns the number of children of this node.
242 fn child_count(&self) -> usize;
243}
244
/// The arity (number of arguments) of a function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// The function takes exactly this many arguments.
    Exact(usize),
    /// The function takes a variable number of arguments.
    Variadic {
        /// Minimum number of arguments (inclusive).
        min: usize,
        /// Maximum number of arguments (inclusive), or `None` for no upper bound.
        max: Option<usize>,
    },
}
251
252impl Display for Arity {
253 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
254 match self {
255 Arity::Exact(n) => write!(f, "{}", n),
256 Arity::Variadic { min, max } => match max {
257 Some(max) if min == max => write!(f, "{}", min),
258 Some(max) => write!(f, "{}..{}", min, max),
259 None => write!(f, "{}+", min),
260 },
261 }
262 }
263}
264
265impl Arity {
266 /// Whether the given argument count matches this arity.
267 pub fn matches(&self, arg_count: usize) -> bool {
268 match self {
269 Arity::Exact(m) => *m == arg_count,
270 Arity::Variadic { min, max } => {
271 if arg_count < *min {
272 return false;
273 }
274 if let Some(max) = max
275 && arg_count > *max
276 {
277 return false;
278 }
279 true
280 }
281 }
282 }
283}
284
285/// Context for simplification.
286///
287/// Used to lazily compute input data types where simplification requires them.
288pub trait SimplifyCtx {
289 /// Get the data type of the given expression.
290 fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
291}
292
293/// Arguments for expression execution.
294pub struct ExecutionArgs<'a> {
295 /// The inputs for the expression, one per child.
296 pub inputs: Vec<ArrayRef>,
297 /// The row count of the execution scope.
298 pub row_count: usize,
299 /// The execution context.
300 pub ctx: &'a mut ExecutionCtx,
301}
302
/// A unit options type for scalar functions that take no options.
///
/// It satisfies every bound required of [`ScalarFnVTable::Options`]. Its [`Display`]
/// implementation writes nothing.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    fn fmt(&self, _f: &mut Formatter<'_>) -> fmt::Result {
        // There is nothing to render, so succeed without writing to the formatter.
        Ok(())
    }
}
310
311/// Factory functions for vtables.
312pub trait ScalarFnVTableExt: ScalarFnVTable {
313 /// Bind this vtable with the given options into a [`ScalarFnRef`].
314 fn bind(&self, options: Self::Options) -> ScalarFnRef {
315 ScalarFn::new(self.clone(), options).erased()
316 }
317
318 /// Create a new expression with this vtable and the given options and children.
319 fn new_expr(
320 &self,
321 options: Self::Options,
322 children: impl IntoIterator<Item = Expression>,
323 ) -> Expression {
324 Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
325 }
326
327 /// Try to create a new expression with this vtable and the given options and children.
328 fn try_new_expr(
329 &self,
330 options: Self::Options,
331 children: impl IntoIterator<Item = Expression>,
332 ) -> VortexResult<Expression> {
333 Expression::try_new(self.bind(options), children)
334 }
335}
336impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
337
338/// A reference to the name of a child expression.
339pub type ChildName = ArcRef<str>;