1use itertools::Itertools;
5use vortex_error::{VortexResult, vortex_err};
6use vortex_proto::expr as pb;
7
8use crate::registry::ExprRegistry;
9use crate::{ExprRef, VortexExpr};
10
11pub trait ExprSerializeProtoExt {
12 fn serialize_proto(&self) -> VortexResult<pb::Expr>;
14}
15
16impl ExprSerializeProtoExt for dyn VortexExpr + '_ {
17 fn serialize_proto(&self) -> VortexResult<pb::Expr> {
18 let children = self
19 .children()
20 .into_iter()
21 .map(|child| child.serialize_proto())
22 .try_collect()?;
23
24 let metadata = self.metadata().ok_or_else(|| {
25 vortex_err!("Expression '{}' is not serializable: {}", self.id(), self)
26 })?;
27
28 Ok(pb::Expr {
29 id: self.id().to_string(),
30 children,
31 metadata: Some(metadata),
32 })
33 }
34}
35
36pub fn deserialize_expr_proto(expr: &pb::Expr, registry: &ExprRegistry) -> VortexResult<ExprRef> {
38 let expr_id = expr.id.as_str();
39 let encoding = registry
40 .get(expr_id)
41 .ok_or_else(|| vortex_err!("unknown expression id: {}", expr_id))?;
42
43 let children = expr
44 .children
45 .iter()
46 .map(|e| deserialize_expr_proto(e, registry))
47 .collect::<VortexResult<Vec<_>>>()?;
48
49 encoding.build(expr.metadata(), children)
50}
51
#[cfg(test)]
mod tests {
    use prost::Message;
    use vortex_array::compute::{BetweenOptions, StrictComparison};
    use vortex_proto::expr as pb;

    use crate::proto::{ExprSerializeProtoExt, deserialize_expr_proto};
    use crate::registry::ExprRegistryExt;
    use crate::{ExprRef, ExprRegistry, and, between, eq, get_item, lit, or, root};

    /// Round-trip a nested expression through protobuf bytes and check that the
    /// deserialized tree equals the original.
    #[test]
    fn expression_serde() {
        let registry = ExprRegistry::default();
        let expr: ExprRef = or(
            and(
                between(
                    lit(1),
                    root(),
                    get_item("a", root()),
                    BetweenOptions {
                        lower_strict: StrictComparison::Strict,
                        upper_strict: StrictComparison::Strict,
                    },
                ),
                lit(1),
            ),
            eq(lit(1), root()),
        );

        let s_expr = expr.serialize_proto().unwrap();
        let buf = s_expr.encode_to_vec();
        let s_expr = pb::Expr::decode(buf.as_slice()).unwrap();
        // Previously garbled as `®istry` (HTML-entity artifact), which did not compile.
        let deser_expr = deserialize_expr_proto(&s_expr, &registry).unwrap();

        assert_eq!(&deser_expr, &expr);
    }

    /// An id that is not present in the registry must be reported as an error
    /// rather than panicking.
    #[test]
    fn unknown_expression_id_is_rejected() {
        let registry = ExprRegistry::default();
        let bogus = pb::Expr {
            id: "does.not.exist".to_string(),
            children: vec![],
            metadata: None,
        };

        assert!(deserialize_expr_proto(&bogus, &registry).is_err());
    }
}