// vortex_expr/vtable.rs
//
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::fmt::Debug;
5use std::hash::Hash;
6use std::ops::Deref;
7
8use vortex_array::{ArrayRef, DeserializeMetadata, SerializeMetadata};
9use vortex_dtype::DType;
10use vortex_error::VortexResult;
11
12use crate::display::DisplayAs;
13use crate::{
14    AnalysisExpr, ExprEncoding, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VortexExpr,
15};
16
17pub trait VTable: 'static + Sized + Send + Sync + Debug {
18    type Expr: 'static
19        + Send
20        + Sync
21        + Clone
22        + Debug
23        + DisplayAs
24        + PartialEq
25        + Eq
26        + Hash
27        + Deref<Target = dyn VortexExpr>
28        + IntoExpr
29        + AnalysisExpr;
30    type Encoding: 'static + Send + Sync + Deref<Target = dyn ExprEncoding>;
31    type Metadata: SerializeMetadata + DeserializeMetadata + Debug;
32
33    /// Returns the ID of the expr encoding.
34    fn id(encoding: &Self::Encoding) -> ExprId;
35
36    /// Returns the encoding for the expr.
37    fn encoding(expr: &Self::Expr) -> ExprEncodingRef;
38
39    /// Returns the serialize-able metadata for the expr, or `None` if serialization is not
40    /// supported.
41    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata>;
42
43    /// Returns the children of the expr.
44    fn children(expr: &Self::Expr) -> Vec<&ExprRef>;
45
46    /// Return a new instance of the expression with the children replaced.
47    ///
48    /// ## Preconditions
49    ///
50    /// The number of children will match the current number of children in the expression.
51    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr>;
52
53    /// Construct a new [`VortexExpr`] from the provided parts.
54    fn build(
55        encoding: &Self::Encoding,
56        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
57        children: Vec<ExprRef>,
58    ) -> VortexResult<Self::Expr>;
59
60    /// Evaluate the expression in the given scope.
61    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef>;
62
63    /// Compute the return [`DType`] of the expression if evaluated in the given scope.
64    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType>;
65}
66
/// Generates the glue that connects an expression kind to the [`VTable`] machinery.
///
/// Invoked as `vtable!(Foo)`, it declares a unit struct `FooVTable` and, for the
/// caller-defined types `FooExpr` and `FooExprEncoding`, implements:
///
/// - `Deref<Target = dyn VortexExpr>` and `AsRef<dyn VortexExpr>` for `FooExpr`,
/// - `IntoExpr` for `FooExpr` and `From<FooExpr>` for `ExprRef`,
/// - `Deref<Target = dyn ExprEncoding>` and `AsRef<dyn ExprEncoding>` for `FooExprEncoding`.
///
/// # Safety
///
/// The generated code reinterprets `FooExpr` as `ExprAdapter<FooVTable>` and
/// `FooExprEncoding` as `ExprEncodingAdapter<FooVTable>`. That is only sound if each adapter
/// is a `#[repr(transparent)]` wrapper around the corresponding type. The adapters are defined
/// outside this file — NOTE(review): confirm. Compile-time size/alignment assertions in the
/// expansion reject the most obvious layout mismatches.
#[macro_export]
macro_rules! vtable {
    ($V:ident) => {
        $crate::aliases::paste::paste! {
            #[derive(Debug)]
            pub struct [<$V VTable>];

            // Layout guards for the pointer casts below: a reference may only be reinterpreted
            // as a reference to the adapter type if both types share size and alignment.
            const _: () = {
                assert!(
                    ::core::mem::size_of::<[<$V Expr>]>()
                        == ::core::mem::size_of::<$crate::ExprAdapter<[<$V VTable>]>>(),
                    "ExprAdapter must have the same size as the expression type"
                );
                assert!(
                    ::core::mem::align_of::<[<$V Expr>]>()
                        == ::core::mem::align_of::<$crate::ExprAdapter<[<$V VTable>]>>(),
                    "ExprAdapter must have the same alignment as the expression type"
                );
                assert!(
                    ::core::mem::size_of::<[<$V ExprEncoding>]>()
                        == ::core::mem::size_of::<$crate::ExprEncodingAdapter<[<$V VTable>]>>(),
                    "ExprEncodingAdapter must have the same size as the encoding type"
                );
                assert!(
                    ::core::mem::align_of::<[<$V ExprEncoding>]>()
                        == ::core::mem::align_of::<$crate::ExprEncodingAdapter<[<$V VTable>]>>(),
                    "ExprEncodingAdapter must have the same alignment as the encoding type"
                );
            };

            impl ::core::ops::Deref for [<$V Expr>] {
                type Target = dyn $crate::VortexExpr;

                fn deref(&self) -> &Self::Target {
                    let ptr = self as *const [<$V Expr>]
                        as *const $crate::ExprAdapter<[<$V VTable>]>;
                    // SAFETY: `ExprAdapter<V>` is required to be a `#[repr(transparent)]`
                    // wrapper around `V::Expr` (size/alignment checked above), so `ptr` points
                    // to a valid adapter that lives as long as `&self`.
                    unsafe { &*ptr }
                }
            }

            impl ::core::convert::AsRef<dyn $crate::VortexExpr> for [<$V Expr>] {
                fn as_ref(&self) -> &dyn $crate::VortexExpr {
                    // Delegate so the unsafe reinterpretation lives in exactly one place.
                    ::core::ops::Deref::deref(self)
                }
            }

            impl $crate::IntoExpr for [<$V Expr>] {
                fn into_expr(self) -> $crate::ExprRef {
                    // SAFETY: same layout requirement as in `Deref`; `transmute` additionally
                    // refuses to compile if the two sizes differ.
                    let adapter = unsafe {
                        ::core::mem::transmute::<[<$V Expr>], $crate::ExprAdapter<[<$V VTable>]>>(
                            self,
                        )
                    };
                    ::std::sync::Arc::new(adapter)
                }
            }

            impl ::core::convert::From<[<$V Expr>]> for $crate::ExprRef {
                fn from(value: [<$V Expr>]) -> $crate::ExprRef {
                    <[<$V Expr>] as $crate::IntoExpr>::into_expr(value)
                }
            }

            impl ::core::ops::Deref for [<$V ExprEncoding>] {
                type Target = dyn $crate::ExprEncoding;

                fn deref(&self) -> &Self::Target {
                    let ptr = self as *const [<$V ExprEncoding>]
                        as *const $crate::ExprEncodingAdapter<[<$V VTable>]>;
                    // SAFETY: `ExprEncodingAdapter<V>` is required to be a
                    // `#[repr(transparent)]` wrapper around the encoding type (size/alignment
                    // checked above), so `ptr` points to a valid adapter that lives as long as
                    // `&self`.
                    unsafe { &*ptr }
                }
            }

            impl ::core::convert::AsRef<dyn $crate::ExprEncoding> for [<$V ExprEncoding>] {
                fn as_ref(&self) -> &dyn $crate::ExprEncoding {
                    // Delegate so the unsafe reinterpretation lives in exactly one place.
                    ::core::ops::Deref::deref(self)
                }
            }
        }
    };
}

#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_array::compute::{BetweenOptions, StrictComparison};
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::proto::{ExprSerializeProtoExt, deserialize_expr_proto};
    use crate::*;

    /// Default expression registry, built once and shared by every test case.
    #[fixture]
    #[once]
    fn registry() -> ExprRegistry {
        ExprRegistry::default()
    }

    /// Every expression must compare equal to itself after a protobuf serialize/deserialize
    /// round trip through the default registry.
    #[rstest]
    // Root and selection expressions
    #[case(root())]
    #[case(select(["hello", "world"], root()))]
    #[case(select_exclude(["world", "hello"], root()))]
    // Literal expressions
    #[case(lit(42i32))]
    #[case(lit(std::f64::consts::PI))]
    #[case(lit(true))]
    #[case(lit("hello"))]
    // Column access expressions
    #[case(col("column_name"))]
    #[case(get_item("field", root()))]
    // Binary comparison expressions
    #[case(eq(col("a"), lit(10)))]
    #[case(not_eq(col("a"), lit(10)))]
    #[case(gt(col("a"), lit(10)))]
    #[case(gt_eq(col("a"), lit(10)))]
    #[case(lt(col("a"), lit(10)))]
    #[case(lt_eq(col("a"), lit(10)))]
    // Logical expressions
    #[case(and(col("a"), col("b")))]
    #[case(or(col("a"), col("b")))]
    #[case(not(col("a")))]
    // Arithmetic expressions
    #[case(checked_add(col("a"), lit(5)))]
    // Null check expressions
    #[case(is_null(col("nullable_col")))]
    // Type casting expressions
    #[case(cast(col("a"), DType::Primitive(PType::I64, Nullability::NonNullable)))]
    // Between expressions
    #[case(between(
        col("a"),
        lit(10),
        lit(20),
        BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        }
    ))]
    // List contains expressions
    #[case(list_contains(col("list_col"), lit("item")))]
    // Pack expressions - creating struct from fields
    #[case(pack([("field1", col("a")), ("field2", col("b"))], Nullability::NonNullable))]
    // Merge expressions - merging struct expressions
    #[case(merge([col("struct1"), col("struct2")], Nullability::NonNullable))]
    // Complex nested expressions
    #[case(and(gt(col("a"), lit(0)), lt(col("a"), lit(100))))]
    #[case(or(is_null(col("a")), eq(col("a"), lit(0))))]
    #[case(not(and(eq(col("status"), lit("active")), gt(col("age"), lit(18)))))]
    fn test_expr_serde_round_trip(
        registry: &ExprRegistry,
        #[case] expr: ExprRef,
    ) -> anyhow::Result<()> {
        let serialized_pb = expr.serialize_proto()?;
        let deserialized_expr = deserialize_expr_proto(&serialized_pb, registry)?;

        assert_eq!(&expr, &deserialized_expr);

        Ok(())
    }
}