vortex_array/expr/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt;
6use std::fmt::{Debug, Display, Formatter};
7use std::hash::{Hash, Hasher};
8use std::sync::Arc;
9
10use arcref::ArcRef;
11use vortex_dtype::{DType, FieldPath};
12use vortex_error::{VortexExpect, VortexResult, vortex_err};
13
14use crate::ArrayRef;
15use crate::expr::expression::Expression;
16use crate::expr::{ExprId, ExpressionView, StatsCatalog};
17
18///
19/// This trait defines the interface for expression vtables, including methods for
20/// serialization, deserialization, validation, child naming, return type computation,
21/// and evaluation.
22///
23/// This trait is non-object safe and allows the implementer to make use of associated types
24/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
25/// outputs of each function.
26///
27/// The [`VTable`] trait should be implemented for a struct that holds global data across
28/// all instances of the expression. In almost all cases, this struct will be an empty unit
29/// struct, since most expressions do not require any global state.
30pub trait VTable: 'static + Sized + Send + Sync {
31    /// Instance data for this expression.
32    type Instance: 'static + Send + Sync + Debug + PartialEq + Eq + Hash;
33
34    /// Returns the ID of the expr vtable.
35    fn id(&self) -> ExprId;
36
37    /// Serialize the metadata for the expression.
38    ///
39    /// Should return `Ok(None)` if the expression is not serializable, and `Ok(vec![])` if it is
40    /// serializable but has no metadata.
41    fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
42        Ok(None)
43    }
44
45    /// Deserialize an instance of this expression.
46    ///
47    /// Returns `Ok(None)` if the expression is not serializable.
48    fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
49        Ok(None)
50    }
51
52    /// Validate the metadata and children for the expression.
53    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()>;
54
55    /// Returns the name of the nth child of the expr.
56    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName;
57
58    /// Format this expression in nice human-readable SQL-style format
59    ///
60    /// The implementation should recursively format child expressions by calling
61    /// `expr.child(i).fmt_sql(f)`.
62    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> fmt::Result;
63
64    /// Format only the instance data for this expression.
65    ///
66    /// Defaults to a debug representation of the instance data.
67    #[allow(clippy::use_debug)]
68    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> fmt::Result {
69        write!(f, "{:?}", instance)
70    }
71
72    /// Compute the return [`DType`] of the expression if evaluated in the given scope.
73    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType>;
74
75    /// Evaluate the expression in the given scope.
76    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef>;
77
78    /// See [`crate::expr::Expression::stat_falsification`].
79    fn stat_falsification(
80        &self,
81        _expr: &ExpressionView<Self>,
82        _catalog: &mut dyn StatsCatalog,
83    ) -> Option<Expression> {
84        None
85    }
86
87    /// See [`crate::expr::Expression::stat_max`].
88    fn stat_max(
89        &self,
90        _expr: &ExpressionView<Self>,
91        _catalog: &mut dyn StatsCatalog,
92    ) -> Option<Expression> {
93        None
94    }
95
96    /// See [`crate::expr::Expression::stat_min`].
97    fn stat_min(
98        &self,
99        _expr: &ExpressionView<Self>,
100        _catalog: &mut dyn StatsCatalog,
101    ) -> Option<Expression> {
102        None
103    }
104
105    /// See [`crate::expr::Expression::stat_nan_count`].
106    fn stat_nan_count(
107        &self,
108        _expr: &ExpressionView<Self>,
109        _catalog: &mut dyn StatsCatalog,
110    ) -> Option<Expression> {
111        None
112    }
113
114    /// See [`crate::expr::Expression::stat_field_path`].
115    fn stat_field_path(&self, _expr: &ExpressionView<Self>) -> Option<FieldPath> {
116        None
117    }
118}
119
120/// Factory functions for static vtables.
121pub trait VTableExt: VTable {
122    fn new_expr(
123        &'static self,
124        instance: Self::Instance,
125        children: impl Into<Arc<[Expression]>>,
126    ) -> Expression {
127        Self::try_new_expr(self, instance, children)
128            .vortex_expect("Failed to create expression instance")
129    }
130
131    fn try_new_expr(
132        &'static self,
133        instance: Self::Instance,
134        children: impl Into<Arc<[Expression]>>,
135    ) -> VortexResult<Expression> {
136        Expression::try_new(
137            ExprVTable::from_static(self),
138            Arc::new(instance),
139            children.into(),
140        )
141    }
142}
143impl<V: VTable> VTableExt for V {}
144
/// A reference to the name of a child expression.
///
/// This is the type returned by the `child_name` methods of the expression vtable traits.
pub type ChildName = ArcRef<str>;
147
/// A placeholder vtable implementation for unsupported optional functionality of an expression.
///
/// A unit struct with no fields; nothing else in this file refers to it.
pub struct NotSupported;
150
/// An object-safe trait for dynamic dispatch of Vortex expression vtables.
///
/// This trait is automatically implemented via the [`VTableAdapter`] for any type that
/// implements [`VTable`], and lifts the associated types into dynamic trait objects.
///
/// The trait is sealed through `private::Sealed`, which is only implemented for
/// [`VTableAdapter`], so the adapter is the sole implementor.
///
/// Methods that take `instance: &dyn Any` receive the type-erased instance data; the
/// [`VTableAdapter`] implementation expects it to downcast to the vtable's
/// `VTable::Instance` type.
pub trait DynExprVTable: 'static + Send + Sync + private::Sealed {
    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete adapter type.
    fn as_any(&self) -> &dyn Any;
    /// Returns the ID of the expression vtable.
    fn id(&self) -> ExprId;
    /// Type-erased counterpart of [`VTable::serialize`].
    fn serialize(&self, instance: &dyn Any) -> VortexResult<Option<Vec<u8>>>;
    /// Type-erased counterpart of [`VTable::deserialize`]; the instance data is returned
    /// behind an `Arc<dyn Any + Send + Sync>`.
    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Arc<dyn Any + Send + Sync>>>;
    /// Type-erased counterpart of [`VTable::child_name`].
    fn child_name(&self, instance: &dyn Any, child_idx: usize) -> ChildName;
    /// Type-erased counterpart of [`VTable::validate`].
    fn validate(&self, expression: &Expression) -> VortexResult<()>;
    /// Type-erased counterpart of [`VTable::fmt_sql`].
    fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result;
    /// Type-erased counterpart of [`VTable::fmt_data`].
    fn fmt_data(&self, instance: &dyn Any, f: &mut Formatter<'_>) -> fmt::Result;
    /// Type-erased counterpart of [`VTable::return_dtype`].
    fn return_dtype(&self, expression: &Expression, scope: &DType) -> VortexResult<DType>;
    /// Type-erased counterpart of [`VTable::evaluate`].
    fn evaluate(&self, expression: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef>;

    /// Type-erased counterpart of [`VTable::stat_falsification`].
    fn stat_falsification(
        &self,
        expression: &Expression,
        catalog: &mut dyn StatsCatalog,
    ) -> Option<Expression>;
    /// Type-erased counterpart of [`VTable::stat_max`].
    fn stat_max(
        &self,
        expression: &Expression,
        catalog: &mut dyn StatsCatalog,
    ) -> Option<Expression>;
    /// Type-erased counterpart of [`VTable::stat_min`].
    fn stat_min(
        &self,
        expression: &Expression,
        catalog: &mut dyn StatsCatalog,
    ) -> Option<Expression>;
    /// Type-erased counterpart of [`VTable::stat_nan_count`].
    fn stat_nan_count(
        &self,
        expression: &Expression,
        catalog: &mut dyn StatsCatalog,
    ) -> Option<Expression>;
    /// Type-erased counterpart of [`VTable::stat_field_path`].
    fn stat_field_path(&self, expression: &Expression) -> Option<FieldPath>;

    /// Compares two type-erased instances for equality using the instance type's
    /// `PartialEq` implementation.
    fn dyn_eq(&self, instance: &dyn Any, other: &dyn Any) -> bool;
    /// Feeds a type-erased instance into `state` using the instance type's `Hash`
    /// implementation.
    fn dyn_hash(&self, instance: &dyn Any, state: &mut dyn Hasher);
}
192
/// Newtype wrapper that implements [`DynExprVTable`] for any `V: VTable`.
///
/// The `#[repr(transparent)]` attribute is load-bearing: `ExprVTable::from_static`
/// reinterprets a `&'static V` as a `&'static VTableAdapter<V>` with a pointer cast, which
/// relies on both types having the same layout.
#[repr(transparent)]
pub struct VTableAdapter<V>(V);
195
196impl<V: VTable> DynExprVTable for VTableAdapter<V> {
197    #[inline(always)]
198    fn as_any(&self) -> &dyn Any {
199        self
200    }
201
202    #[inline(always)]
203    fn id(&self) -> ExprId {
204        V::id(&self.0)
205    }
206
207    fn serialize(&self, instance: &dyn Any) -> VortexResult<Option<Vec<u8>>> {
208        let instance = instance
209            .downcast_ref::<V::Instance>()
210            .vortex_expect("Failed to downcast expression instance to expected type");
211        V::serialize(&self.0, instance)
212    }
213
214    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Arc<dyn Any + Send + Sync>>> {
215        Ok(V::deserialize(&self.0, metadata)?
216            .map(|data| Arc::new(data) as Arc<dyn Any + Send + Sync>))
217    }
218
219    fn child_name(&self, instance: &dyn Any, child_idx: usize) -> ChildName {
220        let instance = instance
221            .downcast_ref::<V::Instance>()
222            .vortex_expect("Failed to downcast expression instance to expected type");
223        V::child_name(&self.0, instance, child_idx)
224    }
225
226    fn validate(&self, expression: &Expression) -> VortexResult<()> {
227        let expr = ExpressionView::new(expression);
228        V::validate(&self.0, &expr)
229    }
230
231    fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result {
232        let expr = ExpressionView::new(expression);
233        V::fmt_sql(&self.0, &expr, f)
234    }
235
236    fn fmt_data(&self, instance: &dyn Any, f: &mut Formatter<'_>) -> fmt::Result {
237        let instance = instance
238            .downcast_ref::<V::Instance>()
239            .vortex_expect("Failed to downcast expression instance to expected type");
240        V::fmt_data(&self.0, instance, f)
241    }
242
243    fn return_dtype(&self, expression: &Expression, scope: &DType) -> VortexResult<DType> {
244        let expr = ExpressionView::new(expression);
245        V::return_dtype(&self.0, &expr, scope)
246    }
247
248    fn evaluate(&self, expression: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef> {
249        let expr = ExpressionView::new(expression);
250        V::evaluate(&self.0, &expr, scope)
251    }
252
253    fn stat_falsification(
254        &self,
255        expression: &Expression,
256        catalog: &mut dyn StatsCatalog,
257    ) -> Option<Expression> {
258        let expr = ExpressionView::new(expression);
259        V::stat_falsification(&self.0, &expr, catalog)
260    }
261
262    fn stat_max(
263        &self,
264        expression: &Expression,
265        catalog: &mut dyn StatsCatalog,
266    ) -> Option<Expression> {
267        let expr = ExpressionView::new(expression);
268        V::stat_max(&self.0, &expr, catalog)
269    }
270
271    fn stat_min(
272        &self,
273        expression: &Expression,
274        catalog: &mut dyn StatsCatalog,
275    ) -> Option<Expression> {
276        let expr = ExpressionView::new(expression);
277        V::stat_min(&self.0, &expr, catalog)
278    }
279
280    fn stat_nan_count(
281        &self,
282        expression: &Expression,
283        catalog: &mut dyn StatsCatalog,
284    ) -> Option<Expression> {
285        let expr = ExpressionView::new(expression);
286        V::stat_nan_count(&self.0, &expr, catalog)
287    }
288
289    fn stat_field_path(&self, expression: &Expression) -> Option<FieldPath> {
290        let expr = ExpressionView::new(expression);
291        V::stat_field_path(&self.0, &expr)
292    }
293
294    fn dyn_eq(&self, instance: &dyn Any, other: &dyn Any) -> bool {
295        let this_instance = instance
296            .downcast_ref::<V::Instance>()
297            .vortex_expect("Failed to downcast expression instance to expected type");
298        let other_instance = other
299            .downcast_ref::<V::Instance>()
300            .vortex_expect("Failed to downcast expression instance to expected type");
301        this_instance == other_instance
302    }
303
304    fn dyn_hash(&self, instance: &dyn Any, mut state: &mut dyn Hasher) {
305        let this_instance = instance
306            .downcast_ref::<V::Instance>()
307            .vortex_expect("Failed to downcast expression instance to expected type");
308        this_instance.hash(&mut state);
309    }
310}
311
312mod private {
313    use crate::expr::{VTable, VTableAdapter};
314
315    pub trait Sealed {}
316    impl<V: VTable> Sealed for VTableAdapter<V> {}
317}
318
/// A Vortex expression vtable, used to deserialize or instantiate expressions dynamically.
///
/// Wraps a type-erased [`DynExprVTable`] held through an [`ArcRef`]. Equality between two
/// values is decided by comparing their vtable IDs.
#[derive(Clone)]
pub struct ExprVTable(ArcRef<dyn DynExprVTable>);
322
323impl ExprVTable {
324    /// Only the vortex-array crate can actually invoke the vtable methods.
325    /// All other users must go via session extensions.
326    pub(crate) fn as_dyn(&self) -> &dyn DynExprVTable {
327        self.0.as_ref()
328    }
329
330    /// Creates a new [`ExprVTable`] from a static reference to a vtable.
331    pub const fn from_static<V: VTable>(vtable: &'static V) -> Self {
332        // SAFETY: We can safely cast the vtable to a VTableAdapter since it has the same layout.
333        let adapted: &'static VTableAdapter<V> =
334            unsafe { &*(vtable as *const V as *const VTableAdapter<V>) };
335        Self(ArcRef::new_ref(adapted as &'static dyn DynExprVTable))
336    }
337
338    /// Returns the ID of this vtable.
339    pub fn id(&self) -> ExprId {
340        self.0.id()
341    }
342
343    /// Returns whether this vtable is of a given type.
344    pub fn is<V: VTable>(&self) -> bool {
345        self.0.as_any().is::<VTableAdapter<V>>()
346    }
347
348    /// Returns the typed VTable for this expression.
349    pub fn as_opt<V: VTable>(&self) -> Option<&V> {
350        self.0
351            .as_any()
352            .downcast_ref::<VTableAdapter<V>>()
353            .map(|adapter| &adapter.0)
354    }
355
356    /// Deserialize an instance of this expression vtable from metadata.
357    pub fn deserialize(
358        &self,
359        metadata: &[u8],
360        children: Arc<[Expression]>,
361    ) -> VortexResult<Expression> {
362        let instance_data = self.as_dyn().deserialize(metadata)?.ok_or_else(|| {
363            vortex_err!(
364                "Expression vtable {} is not deserializable",
365                self.as_dyn().id()
366            )
367        })?;
368        Expression::try_new(self.clone(), instance_data, children)
369    }
370}
371
372impl PartialEq for ExprVTable {
373    fn eq(&self, other: &Self) -> bool {
374        self.0.id() == other.0.id()
375    }
376}
377impl Eq for ExprVTable {}
378
379impl Display for ExprVTable {
380    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
381        write!(f, "{}", self.as_dyn().id())
382    }
383}
384
385impl Debug for ExprVTable {
386    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
387        write!(f, "{}", self.as_dyn().id())
388    }
389}
390
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};

    use super::*;
    use crate::expr::exprs::between::between;
    use crate::expr::exprs::binary::{and, checked_add, eq, gt, gt_eq, lt, lt_eq, not_eq, or};
    use crate::expr::exprs::cast::cast;
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::is_null::is_null;
    use crate::expr::exprs::list_contains::list_contains;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::{select, select_exclude};
    use crate::expr::proto::{ExprSerializeProtoExt, deserialize_expr_proto};
    use crate::expr::session::{ExprRegistry, ExprSession};

    /// Registry taken from a default [`ExprSession`]; `#[once]` shares it across cases.
    #[fixture]
    #[once]
    fn registry() -> ExprRegistry {
        let session = ExprSession::default();
        session.registry().clone()
    }

    /// Serializes each case to protobuf, deserializes it against the registry, and checks
    /// that the result equals the original expression.
    #[rstest]
    // Root and selection
    #[case(root())]
    #[case(select(["hello", "world"], root()))]
    #[case(select_exclude(["world", "hello"], root()))]
    // Literals
    #[case(lit(42i32))]
    #[case(lit(std::f64::consts::PI))]
    #[case(lit(true))]
    #[case(lit("hello"))]
    // Column access
    #[case(col("column_name"))]
    #[case(get_item("field", root()))]
    // Binary comparisons
    #[case(eq(col("a"), lit(10)))]
    #[case(not_eq(col("a"), lit(10)))]
    #[case(gt(col("a"), lit(10)))]
    #[case(gt_eq(col("a"), lit(10)))]
    #[case(lt(col("a"), lit(10)))]
    #[case(lt_eq(col("a"), lit(10)))]
    // Logical operators
    #[case(and(col("a"), col("b")))]
    #[case(or(col("a"), col("b")))]
    #[case(not(col("a")))]
    // Arithmetic
    #[case(checked_add(col("a"), lit(5)))]
    // Null checks
    #[case(is_null(col("nullable_col")))]
    // Type casts
    #[case(cast(
        col("a"),
        DType::Primitive(vortex_dtype::PType::I64, vortex_dtype::Nullability::NonNullable)
    ))]
    // Between
    #[case(between(col("a"), lit(10), lit(20), crate::compute::BetweenOptions { lower_strict: crate::compute::StrictComparison::NonStrict, upper_strict: crate::compute::StrictComparison::NonStrict }))]
    // List contains
    #[case(list_contains(col("list_col"), lit("item")))]
    // Pack: build a struct from named fields
    #[case(pack([("field1", col("a")), ("field2", col("b"))], vortex_dtype::Nullability::NonNullable))]
    // Merge: combine struct expressions
    #[case(merge([col("struct1"), col("struct2")]))]
    // Nested combinations
    #[case(and(gt(col("a"), lit(0)), lt(col("a"), lit(100))))]
    #[case(or(is_null(col("a")), eq(col("a"), lit(0))))]
    #[case(not(and(eq(col("status"), lit("active")), gt(col("age"), lit(18)))))]
    fn text_expr_serde_round_trip(
        registry: &ExprRegistry,
        #[case] expr: Expression,
    ) -> VortexResult<()> {
        let proto = (&expr).serialize_proto()?;
        let round_tripped = deserialize_expr_proto(&proto, registry)?;

        assert_eq!(&expr, &round_tripped);

        Ok(())
    }
}
473}