Skip to main content

vortex_array/aggregate_fn/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arcref::ArcRef;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10use vortex_proto::expr as pb;
11use vortex_session::VortexSession;
12
13use crate::aggregate_fn::AggregateFnId;
14use crate::aggregate_fn::AggregateFnRef;
15use crate::aggregate_fn::session::AggregateFnSessionExt;
16
17impl AggregateFnRef {
18    /// Serialize this aggregate function to its protobuf representation.
19    ///
20    /// Note: the serialization format is not stable and may change between versions.
21    pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
22        let metadata = self
23            .options()
24            .serialize()?
25            .ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;
26
27        Ok(pb::AggregateFn {
28            id: self.id().to_string(),
29            metadata: Some(metadata),
30        })
31    }
32
33    /// Deserialize an aggregate function from its protobuf representation.
34    ///
35    /// Looks up the aggregate function plugin by ID in the session's registry
36    /// and delegates deserialization to it.
37    ///
38    /// Note: the serialization format is not stable and may change between versions.
39    pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
40        let agg_fn_id: AggregateFnId = ArcRef::new_arc(Arc::from(proto.id.as_str()));
41        let plugin = session
42            .aggregate_fns()
43            .registry()
44            .find(&agg_fn_id)
45            .ok_or_else(|| vortex_err!("unknown aggregate function id: {}", proto.id))?;
46        let agg_fn = plugin.deserialize(proto.metadata(), session)?;
47
48        if agg_fn.id() != agg_fn_id {
49            vortex_bail!(
50                "Aggregate function ID mismatch: expected {}, got {}",
51                agg_fn_id,
52                agg_fn.id()
53            );
54        }
55
56        Ok(agg_fn)
57    }
58}
59
#[cfg(test)]
mod tests {
    use prost::Message;
    use vortex_proto::expr as pb;
    use vortex_session::VortexSession;

    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTableExt;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::aggregate_fn::session::AggregateFnSessionExt;

    /// Round-trips a bound `Sum` through protobuf bytes and checks the decoded function is
    /// equal to the original.
    #[test]
    fn aggregate_fn_serde() {
        let session = VortexSession::empty().with::<AggregateFnSession>();
        session.aggregate_fns().register(Sum);

        let agg_fn = Sum.bind(EmptyOptions);

        let serialized = agg_fn
            .serialize_proto()
            .expect("Sum with EmptyOptions should serialize");
        let buf = serialized.encode_to_vec();
        let deserialized_proto =
            pb::AggregateFn::decode(buf.as_slice()).expect("encoded proto should decode");
        let deserialized = AggregateFnRef::from_proto(&deserialized_proto, &session)
            .expect("registered Sum should deserialize");

        assert_eq!(deserialized, agg_fn);
    }

    /// Decoding must fail with an "unknown id" error when the message names an aggregate
    /// function that is not registered in the session.
    #[test]
    fn aggregate_fn_from_proto_unknown_id() {
        let session = VortexSession::empty().with::<AggregateFnSession>();
        session.aggregate_fns().register(Sum);

        let proto = pb::AggregateFn {
            id: "vortex.test.not_registered".to_string(),
            metadata: None,
        };

        let Err(err) = AggregateFnRef::from_proto(&proto, &session) else {
            panic!("decoding an unregistered aggregate function id should fail");
        };
        assert!(
            err.to_string().contains("unknown aggregate function id"),
            "unexpected error: {err}"
        );
    }
}