vortex_array/expr/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::ops::Deref;
12use std::sync::Arc;
13
14use arcref::ArcRef;
15use vortex_dtype::DType;
16use vortex_error::VortexExpect;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20use vortex_vector::Vector;
21use vortex_vector::VectorOps;
22
23use crate::ArrayRef;
24use crate::expr::ExprId;
25use crate::expr::ExpressionView;
26use crate::expr::StatsCatalog;
27use crate::expr::expression::Expression;
28use crate::expr::stats::Stat;
29
30///
31/// This trait defines the interface for expression vtables, including methods for
32/// serialization, deserialization, validation, child naming, return type computation,
33/// and evaluation.
34///
35/// This trait is non-object safe and allows the implementer to make use of associated types
36/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
37/// outputs of each function.
38///
39/// The [`VTable`] trait should be implemented for a struct that holds global data across
40/// all instances of the expression. In almost all cases, this struct will be an empty unit
41/// struct, since most expressions do not require any global state.
42pub trait VTable: 'static + Sized + Send + Sync {
43    /// Instance data for this expression.
44    type Instance: 'static + Send + Sync + Debug + PartialEq + Eq + Hash;
45
46    /// Returns the ID of the expr vtable.
47    fn id(&self) -> ExprId;
48
49    /// Serialize the metadata for the expression.
50    ///
51    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
52    /// serializable but has no metadata.
53    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
54        _ = instance;
55        Ok(None)
56    }
57
58    /// Deserialize an instance of this expression.
59    ///
60    /// Returns `Ok(None)` if the expression is not serializable.
61    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
62        _ = metadata;
63        Ok(None)
64    }
65
66    /// Validate the metadata and children for the expression.
67    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()>;
68
69    /// Returns the name of the nth child of the expr.
70    fn child_name(&self, instance: &Self::Instance, child_idx: usize) -> ChildName;
71
72    /// Format this expression in nice human-readable SQL-style format
73    ///
74    /// The implementation should recursively format child expressions by calling
75    /// `expr.child(i).fmt_sql(f)`.
76    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> fmt::Result;
77
78    /// Format only the instance data for this expression.
79    ///
80    /// Defaults to a debug representation of the instance data.
81    #[allow(clippy::use_debug)]
82    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> fmt::Result {
83        write!(f, "{:?}", instance)
84    }
85
86    /// Compute the return [`DType`] of the expression if evaluated in the given scope.
87    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType>;
88
89    /// Evaluate the expression in the given scope.
90    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef>;
91
92    /// Execute the expression on the given vector with the given dtype.
93    fn execute(&self, data: &Self::Instance, args: ExecutionArgs) -> VortexResult<Vector> {
94        _ = data;
95        let _args = args;
96        // TODO(ngates): remove this once we port to vector execution
97        // TODO(ngates): I think we should take/return an enum of Vector/Scalar.
98        vortex_bail!("Expression {} does not support execution", self.id());
99    }
100
101    /// See [`Expression::stat_falsification`].
102    fn stat_falsification(
103        &self,
104        expr: &ExpressionView<Self>,
105        catalog: &dyn StatsCatalog,
106    ) -> Option<Expression> {
107        _ = expr;
108        _ = catalog;
109        None
110    }
111
112    /// See [`Expression::stat_expression`].
113    fn stat_expression(
114        &self,
115        expr: &ExpressionView<Self>,
116        stat: Stat,
117        catalog: &dyn StatsCatalog,
118    ) -> Option<Expression> {
119        _ = expr;
120        _ = stat;
121        _ = catalog;
122        None
123    }
124
125    /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*.
126    ///
127    /// An expression is null-sensitive if it directly operates on null values,
128    /// such as `is_null`. Most expressions are not null-sensitive.
129    ///
130    /// The property we are interested in is if the expression (e) distributes over
131    /// mask.
132    /// Define a `mask(a, m)` expression that applies the boolean array `m` to the validity of the
133    /// array `a`.
134    /// An unary expression `e` to be null-sensitive iff forall arrays `a` and masks `m`.
135    /// `e(mask(a, m)) == mask(e(a), m)`.
136    /// This can be extended to an n-ary expression.
137    ///
138    /// This method only checks the expression itself, not its children. To check
139    /// if an expression or any of its descendants are null-sensitive.
140    fn is_null_sensitive(&self, instance: &Self::Instance) -> bool {
141        _ = instance;
142        true
143    }
144
145    /// Returns whether this expression itself is fallible. Conservatively default to *true*.
146    ///
147    /// An expression is runtime fallible is there is an input set that causes the expression to
148    /// panic or return an error, for example checked_add is fallible if there is overflow.
149    ///
150    /// Note: this is only applicable to expressions that pass type-checking
151    /// [`VTable::return_dtype`].
152    fn is_fallible(&self, instance: &Self::Instance) -> bool {
153        _ = instance;
154        true
155    }
156
157    /// **For internal usage**. This will return an Expression that is part of the
158    /// expression -> new_expression migration.
159    #[doc(hidden)]
160    fn expr_v2(&self, expr: &ExpressionView<Self>) -> VortexResult<Expression> {
161        Ok(expr.deref().clone())
162    }
163}
164
165/// Arguments for expression execution.
166pub struct ExecutionArgs {
167    /// The input vectors for the expression, one per child.
168    pub vectors: Vec<Vector>,
169    /// The input dtypes for the expression, one per child.
170    pub dtypes: Vec<DType>,
171    /// The row count of the execution scope.
172    pub row_count: usize,
173    /// The expected return dtype of the expression, as computed by [`Expression::return_dtype`].
174    pub return_dtype: DType,
175}
176
177/// Factory functions for static vtables.
178pub trait VTableExt: VTable {
179    fn new_expr(
180        &'static self,
181        instance: Self::Instance,
182        children: impl Into<Arc<[Expression]>>,
183    ) -> Expression {
184        Self::try_new_expr(self, instance, children)
185            .vortex_expect("Failed to create expression instance")
186    }
187
188    fn try_new_expr(
189        &'static self,
190        instance: Self::Instance,
191        children: impl Into<Arc<[Expression]>>,
192    ) -> VortexResult<Expression> {
193        Expression::try_new_erased(
194            ExprVTable::new_static(self),
195            Arc::new(instance),
196            children.into(),
197        )
198    }
199}
200impl<V: VTable> VTableExt for V {}
201
202/// A reference to the name of a child expression.
203pub type ChildName = ArcRef<str>;
204
205/// A placeholder vtable implementation for unsupported optional functionality of an expression.
206pub struct NotSupported;
207
208/// An object-safe trait for dynamic dispatch of Vortex expression vtables.
209///
210/// This trait is automatically implemented via the [`VTableAdapter`] for any type that
211/// implements [`VTable`], and lifts the associated types into dynamic trait objects.
212pub trait DynExprVTable: 'static + Send + Sync + private::Sealed {
213    fn as_any(&self) -> &dyn Any;
214    fn id(&self) -> ExprId;
215    fn serialize(&self, instance: &dyn Any) -> VortexResult<Option<Vec<u8>>>;
216    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Arc<dyn Any + Send + Sync>>>;
217    fn child_name(&self, instance: &dyn Any, child_idx: usize) -> ChildName;
218    fn validate(&self, expression: &Expression) -> VortexResult<()>;
219    fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result;
220    fn fmt_data(&self, instance: &dyn Any, f: &mut Formatter<'_>) -> fmt::Result;
221    fn return_dtype(&self, expression: &Expression, scope: &DType) -> VortexResult<DType>;
222    fn evaluate(&self, expression: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef>;
223    fn execute(&self, data: &dyn Any, args: ExecutionArgs) -> VortexResult<Vector>;
224
225    fn stat_falsification(
226        &self,
227        expression: &Expression,
228        catalog: &dyn StatsCatalog,
229    ) -> Option<Expression>;
230    fn stat_expression(
231        &self,
232        expression: &Expression,
233        stat: Stat,
234        catalog: &dyn StatsCatalog,
235    ) -> Option<Expression>;
236
237    /// See [`VTable::is_null_sensitive`].
238    fn is_null_sensitive(&self, instance: &dyn Any) -> bool;
239    /// See [`VTable::is_fallible`].
240    fn is_fallible(&self, instance: &dyn Any) -> bool;
241
242    fn dyn_eq(&self, instance: &dyn Any, other: &dyn Any) -> bool;
243    fn dyn_hash(&self, instance: &dyn Any, state: &mut dyn Hasher);
244}
245
/// Wraps a statically-typed [`VTable`] so it can be used as a [`DynExprVTable`] trait object.
///
/// `#[repr(transparent)]` guarantees that `VTableAdapter<V>` has the same layout as `V`.
/// [`ExprVTable::new_static`] relies on this to reinterpret a `&'static V` as a
/// `&'static VTableAdapter<V>`, so the attribute must not be removed.
#[repr(transparent)]
pub struct VTableAdapter<V>(V);
248
249impl<V: VTable> DynExprVTable for VTableAdapter<V> {
250    #[inline(always)]
251    fn as_any(&self) -> &dyn Any {
252        self
253    }
254
255    #[inline(always)]
256    fn id(&self) -> ExprId {
257        V::id(&self.0)
258    }
259
260    fn serialize(&self, instance: &dyn Any) -> VortexResult<Option<Vec<u8>>> {
261        let instance = instance
262            .downcast_ref::<V::Instance>()
263            .vortex_expect("Failed to downcast expression instance to expected type");
264        V::serialize(&self.0, instance)
265    }
266
267    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Arc<dyn Any + Send + Sync>>> {
268        Ok(V::deserialize(&self.0, metadata)?
269            .map(|data| Arc::new(data) as Arc<dyn Any + Send + Sync>))
270    }
271
272    fn child_name(&self, instance: &dyn Any, child_idx: usize) -> ChildName {
273        let instance = instance
274            .downcast_ref::<V::Instance>()
275            .vortex_expect("Failed to downcast expression instance to expected type");
276        V::child_name(&self.0, instance, child_idx)
277    }
278
279    fn validate(&self, expression: &Expression) -> VortexResult<()> {
280        let expr = ExpressionView::new(expression);
281        V::validate(&self.0, &expr)
282    }
283
284    fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result {
285        let expr = ExpressionView::new(expression);
286        V::fmt_sql(&self.0, &expr, f)
287    }
288
289    fn fmt_data(&self, instance: &dyn Any, f: &mut Formatter<'_>) -> fmt::Result {
290        let instance = instance
291            .downcast_ref::<V::Instance>()
292            .vortex_expect("Failed to downcast expression instance to expected type");
293        V::fmt_data(&self.0, instance, f)
294    }
295
296    fn return_dtype(&self, expression: &Expression, scope: &DType) -> VortexResult<DType> {
297        let expr = ExpressionView::new(expression);
298        V::return_dtype(&self.0, &expr, scope)
299    }
300
301    fn evaluate(&self, expression: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef> {
302        let expr = ExpressionView::new(expression);
303        V::evaluate(&self.0, &expr, scope)
304    }
305
306    fn execute(&self, data: &dyn Any, args: ExecutionArgs) -> VortexResult<Vector> {
307        let data = data
308            .downcast_ref::<V::Instance>()
309            .vortex_expect("Failed to downcast expression instance to expected type");
310
311        let expected_row_count = args.row_count;
312        #[cfg(debug_assertions)]
313        let expected_dtype = args.return_dtype.clone();
314
315        let result = V::execute(&self.0, data, args)?;
316
317        assert_eq!(
318            result.len(),
319            expected_row_count,
320            "Expression execution returned vector of length {}, but expected {}",
321            result.len(),
322            expected_row_count,
323        );
324
325        // In debug mode, validate that the output dtype matches the expected return dtype.
326        #[cfg(debug_assertions)]
327        {
328            use vortex_error::vortex_ensure;
329            use vortex_vector::vector_matches_dtype;
330            vortex_ensure!(
331                vector_matches_dtype(&result, &expected_dtype),
332                "Expression execution invalid for dtype {}",
333                expected_dtype
334            );
335        }
336
337        Ok(result)
338    }
339
340    fn stat_falsification(
341        &self,
342        expression: &Expression,
343        catalog: &dyn StatsCatalog,
344    ) -> Option<Expression> {
345        let expr = ExpressionView::new(expression);
346        V::stat_falsification(&self.0, &expr, catalog)
347    }
348
349    fn stat_expression(
350        &self,
351        expression: &Expression,
352        stat: Stat,
353        catalog: &dyn StatsCatalog,
354    ) -> Option<Expression> {
355        let expr = ExpressionView::new(expression);
356        V::stat_expression(&self.0, &expr, stat, catalog)
357    }
358
359    fn is_null_sensitive(&self, instance: &dyn Any) -> bool {
360        let instance = instance
361            .downcast_ref::<V::Instance>()
362            .vortex_expect("Failed to downcast expression instance to expected type");
363        V::is_null_sensitive(&self.0, instance)
364    }
365
366    fn is_fallible(&self, instance: &dyn Any) -> bool {
367        let instance = instance
368            .downcast_ref::<V::Instance>()
369            .vortex_expect("Failed to downcast expression instance to expected type");
370        V::is_fallible(&self.0, instance)
371    }
372
373    fn dyn_eq(&self, instance: &dyn Any, other: &dyn Any) -> bool {
374        let this_instance = instance
375            .downcast_ref::<V::Instance>()
376            .vortex_expect("Failed to downcast expression instance to expected type");
377        let other_instance = other
378            .downcast_ref::<V::Instance>()
379            .vortex_expect("Failed to downcast expression instance to expected type");
380        this_instance == other_instance
381    }
382
383    fn dyn_hash(&self, instance: &dyn Any, mut state: &mut dyn Hasher) {
384        let this_instance = instance
385            .downcast_ref::<V::Instance>()
386            .vortex_expect("Failed to downcast expression instance to expected type");
387        this_instance.hash(&mut state);
388    }
389}
390
391mod private {
392    use crate::expr::VTable;
393    use crate::expr::VTableAdapter;
394
395    pub trait Sealed {}
396    impl<V: VTable> Sealed for VTableAdapter<V> {}
397}
398
399/// A Vortex expression vtable, used to deserialize or instantiate expressions dynamically.
400#[derive(Clone)]
401pub struct ExprVTable(ArcRef<dyn DynExprVTable>);
402
403impl ExprVTable {
404    /// Only the vortex-array crate can actually invoke the vtable methods.
405    /// All other users must go via session extensions.
406    pub(crate) fn as_dyn(&self) -> &dyn DynExprVTable {
407        self.0.as_ref()
408    }
409
410    /// Creates a new [`ExprVTable`] from a vtable.
411    pub fn new<V: VTable>(vtable: V) -> Self {
412        Self(ArcRef::new_arc(Arc::new(VTableAdapter(vtable))))
413    }
414
415    /// Creates a new [`ExprVTable`] from a static reference to a vtable.
416    pub const fn new_static<V: VTable>(vtable: &'static V) -> Self {
417        // SAFETY: We can safely cast the vtable to a VTableAdapter since it has the same layout.
418        let adapted: &'static VTableAdapter<V> =
419            unsafe { &*(vtable as *const V as *const VTableAdapter<V>) };
420        Self(ArcRef::new_ref(adapted as &'static dyn DynExprVTable))
421    }
422
423    /// Returns the ID of this vtable.
424    pub fn id(&self) -> ExprId {
425        self.0.id()
426    }
427
428    /// Returns whether this vtable is of a given type.
429    pub fn is<V: VTable>(&self) -> bool {
430        self.0.as_any().is::<VTableAdapter<V>>()
431    }
432
433    /// Returns the typed VTable for this expression.
434    pub fn as_opt<V: VTable>(&self) -> Option<&V> {
435        self.0
436            .as_any()
437            .downcast_ref::<VTableAdapter<V>>()
438            .map(|adapter| &adapter.0)
439    }
440
441    /// Deserialize an instance of this expression vtable from metadata.
442    pub fn deserialize(
443        &self,
444        metadata: &[u8],
445        children: Arc<[Expression]>,
446    ) -> VortexResult<Expression> {
447        let instance_data = self.as_dyn().deserialize(metadata)?.ok_or_else(|| {
448            vortex_err!(
449                "Expression vtable {} is not deserializable",
450                self.as_dyn().id()
451            )
452        })?;
453        Expression::try_new_erased(self.clone(), instance_data, children)
454    }
455}
456
457impl PartialEq for ExprVTable {
458    fn eq(&self, other: &Self) -> bool {
459        self.0.id() == other.0.id()
460    }
461}
462impl Eq for ExprVTable {}
463
464impl Display for ExprVTable {
465    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
466        write!(f, "{}", self.as_dyn().id())
467    }
468}
469
470impl Debug for ExprVTable {
471    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
472        write!(f, "{}", self.as_dyn().id())
473    }
474}
475
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;

    use super::*;
    use crate::expr::exprs::between::between;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::gt_eq;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::lt_eq;
    use crate::expr::exprs::binary::not_eq;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::cast::cast;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::is_null::is_null;
    use crate::expr::exprs::list_contains::list_contains;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::exprs::select::select_exclude;
    use crate::expr::proto::ExprSerializeProtoExt;
    use crate::expr::proto::deserialize_expr_proto;
    use crate::expr::session::ExprRegistry;
    use crate::expr::session::ExprSession;

    // The registry of a default `ExprSession`, built once and shared by every test case.
    #[fixture]
    #[once]
    fn registry() -> ExprRegistry {
        ExprSession::default().registry().clone()
    }

    // Serializing an expression to protobuf and deserializing it through the registry must
    // yield an expression equal to the original.
    #[rstest]
    // Root and selection expressions
    #[case(root())]
    #[case(select(["hello", "world"], root()))]
    #[case(select_exclude(["world", "hello"], root()))]
    // Literal expressions
    #[case(lit(42i32))]
    #[case(lit(std::f64::consts::PI))]
    #[case(lit(true))]
    #[case(lit("hello"))]
    // Column access expressions
    #[case(col("column_name"))]
    #[case(get_item("field", root()))]
    // Binary comparison expressions
    #[case(eq(col("a"), lit(10)))]
    #[case(not_eq(col("a"), lit(10)))]
    #[case(gt(col("a"), lit(10)))]
    #[case(gt_eq(col("a"), lit(10)))]
    #[case(lt(col("a"), lit(10)))]
    #[case(lt_eq(col("a"), lit(10)))]
    // Logical expressions
    #[case(and(col("a"), col("b")))]
    #[case(or(col("a"), col("b")))]
    #[case(not(col("a")))]
    // Arithmetic expressions
    #[case(checked_add(col("a"), lit(5)))]
    // Null check expressions
    #[case(is_null(col("nullable_col")))]
    // Type casting expressions
    #[case(cast(
        col("a"),
        DType::Primitive(vortex_dtype::PType::I64, vortex_dtype::Nullability::NonNullable)
    ))]
    // Between expressions
    #[case(between(
        col("a"),
        lit(10),
        lit(20),
        crate::compute::BetweenOptions {
            lower_strict: crate::compute::StrictComparison::NonStrict,
            upper_strict: crate::compute::StrictComparison::NonStrict,
        }
    ))]
    // List contains expressions
    #[case(list_contains(col("list_col"), lit("item")))]
    // Pack expressions - creating struct from fields
    #[case(pack(
        [("field1", col("a")), ("field2", col("b"))],
        vortex_dtype::Nullability::NonNullable
    ))]
    // Merge expressions - merging struct expressions
    #[case(merge([col("struct1"), col("struct2")]))]
    // Complex nested expressions
    #[case(and(gt(col("a"), lit(0)), lt(col("a"), lit(100))))]
    #[case(or(is_null(col("a")), eq(col("a"), lit(0))))]
    #[case(not(and(eq(col("status"), lit("active")), gt(col("age"), lit(18)))))]
    fn test_expr_serde_round_trip(
        registry: &ExprRegistry,
        #[case] expr: Expression,
    ) -> VortexResult<()> {
        let serialized_pb = (&expr).serialize_proto()?;
        let deserialized_expr = deserialize_expr_proto(&serialized_pb, registry)?;

        assert_eq!(&expr, &deserialized_expr);

        Ok(())
    }
}