Skip to main content

vortex_array/expr/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11
12use crate::expr::Expression;
13use crate::scalar_fn::ForeignScalarFnVTable;
14use crate::scalar_fn::ScalarFnId;
15use crate::scalar_fn::session::ScalarFnSessionExt;
16
17pub trait ExprSerializeProtoExt {
18    /// Serialize the expression to its protobuf representation.
19    fn serialize_proto(&self) -> VortexResult<pb::Expr>;
20}
21
22impl ExprSerializeProtoExt for Expression {
23    fn serialize_proto(&self) -> VortexResult<pb::Expr> {
24        let children = self
25            .children()
26            .iter()
27            .map(|child| child.serialize_proto())
28            .try_collect()?;
29
30        let metadata = self.options().serialize()?.ok_or_else(|| {
31            vortex_err!("Expression '{}' is not serializable: {}", self.id(), self)
32        })?;
33
34        Ok(pb::Expr {
35            id: self.id().to_string(),
36            children,
37            metadata: Some(metadata),
38        })
39    }
40}
41
42impl Expression {
43    pub fn from_proto(expr: &pb::Expr, session: &VortexSession) -> VortexResult<Expression> {
44        let expr_id = ScalarFnId::new_arc(Arc::from(expr.id.to_string()));
45        let children = expr
46            .children
47            .iter()
48            .map(|e| Expression::from_proto(e, session))
49            .collect::<VortexResult<Vec<_>>>()?;
50
51        let scalar_fn = if let Some(vtable) = session.scalar_fns().registry().find(&expr_id) {
52            vtable.deserialize(expr.metadata(), session)?
53        } else if session.allows_unknown() {
54            ForeignScalarFnVTable::make_scalar_fn(expr_id, expr.metadata().to_vec(), children.len())
55        } else {
56            return Err(vortex_err!("unknown expression id: {}", expr_id));
57        };
58
59        Expression::try_new(scalar_fn, children)
60    }
61}
62
63/// Deserialize a [`Expression`] from the protobuf representation.
64#[deprecated(note = "Use Expression::from_proto instead")]
65pub fn deserialize_expr_proto(
66    expr: &pb::Expr,
67    session: &VortexSession,
68) -> VortexResult<Expression> {
69    Expression::from_proto(expr, session)
70}
71
#[cfg(test)]
mod tests {
    use prost::Message;
    use vortex_proto::expr as pb;
    use vortex_session::VortexSession;

    use super::ExprSerializeProtoExt;
    use crate::LEGACY_SESSION;
    use crate::expr::Expression;
    use crate::expr::and;
    use crate::expr::between;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::or;
    use crate::expr::root;
    use crate::scalar_fn::fns::between::BetweenOptions;
    use crate::scalar_fn::fns::between::StrictComparison;
    use crate::scalar_fn::session::ScalarFnSession;

    /// A nested expression survives encode -> bytes -> decode unchanged.
    #[test]
    fn expression_serde() {
        let expr: Expression = or(
            and(
                between(
                    lit(1),
                    root(),
                    get_item("a", root()),
                    BetweenOptions {
                        lower_strict: StrictComparison::Strict,
                        upper_strict: StrictComparison::Strict,
                    },
                ),
                lit(1),
            ),
            eq(lit(1), root()),
        );

        let s_expr = expr.serialize_proto().unwrap();
        let buf = s_expr.encode_to_vec();
        let s_expr = pb::Expr::decode(buf.as_slice()).unwrap();
        let deser_expr = Expression::from_proto(&s_expr, &LEGACY_SESSION).unwrap();

        assert_eq!(&deser_expr, &expr);
    }

    /// With `allow_unknown`, an unregistered id is kept as a foreign scalar function and
    /// re-serializes with the same id, metadata, and child count.
    #[test]
    fn unknown_expression_id_allow_unknown() {
        let session = VortexSession::empty()
            .with::<ScalarFnSession>()
            .allow_unknown();

        let expr_proto = pb::Expr {
            id: "vortex.test.foreign_scalar_fn".to_string(),
            metadata: Some(vec![1, 2, 3, 4]),
            children: vec![root().serialize_proto().unwrap()],
        };

        let expr = Expression::from_proto(&expr_proto, &session).unwrap();
        assert_eq!(expr.id().as_ref(), "vortex.test.foreign_scalar_fn");

        let roundtrip = expr.serialize_proto().unwrap();
        assert_eq!(roundtrip.id, expr_proto.id);
        assert_eq!(roundtrip.metadata(), expr_proto.metadata());
        assert_eq!(roundtrip.children.len(), 1);
    }

    /// Without `allow_unknown`, an unregistered id is rejected, and the error names it.
    #[test]
    fn unknown_expression_id_rejected_by_default() {
        let session = VortexSession::empty().with::<ScalarFnSession>();

        let expr_proto = pb::Expr {
            id: "vortex.test.foreign_scalar_fn".to_string(),
            metadata: Some(vec![1, 2, 3, 4]),
            children: vec![],
        };

        let err = Expression::from_proto(&expr_proto, &session).unwrap_err();
        assert!(
            err.to_string().contains("vortex.test.foreign_scalar_fn"),
            "unexpected error: {err}"
        );
    }
}