Skip to main content

vortex_array/aggregate_fn/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arcref::ArcRef;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10use vortex_proto::expr as pb;
11use vortex_session::VortexSession;
12
13use crate::aggregate_fn::AggregateFnId;
14use crate::aggregate_fn::AggregateFnRef;
15use crate::aggregate_fn::new_foreign_aggregate_fn;
16use crate::aggregate_fn::session::AggregateFnSessionExt;
17
18impl AggregateFnRef {
19    /// Serialize this aggregate function to its protobuf representation.
20    ///
21    /// Note: the serialization format is not stable and may change between versions.
22    pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
23        let metadata = self
24            .options()
25            .serialize()?
26            .ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;
27
28        Ok(pb::AggregateFn {
29            id: self.id().to_string(),
30            metadata: Some(metadata),
31        })
32    }
33
34    /// Deserialize an aggregate function from its protobuf representation.
35    ///
36    /// Looks up the aggregate function plugin by ID in the session's registry
37    /// and delegates deserialization to it.
38    ///
39    /// Note: the serialization format is not stable and may change between versions.
40    pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
41        let agg_fn_id: AggregateFnId = ArcRef::new_arc(Arc::from(proto.id.as_str()));
42        let agg_fn = if let Some(plugin) = session.aggregate_fns().registry().find(&agg_fn_id) {
43            plugin.deserialize(proto.metadata(), session)?
44        } else if session.allows_unknown() {
45            new_foreign_aggregate_fn(agg_fn_id.clone(), proto.metadata().to_vec())
46        } else {
47            return Err(vortex_err!("unknown aggregate function id: {}", proto.id));
48        };
49
50        if agg_fn.id() != agg_fn_id {
51            vortex_bail!(
52                "Aggregate function ID mismatch: expected {}, got {}",
53                agg_fn_id,
54                agg_fn.id()
55            );
56        }
57
58        Ok(agg_fn)
59    }
60}
61
#[cfg(test)]
mod tests {
    use prost::Message;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_proto::expr as pb;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Columnar;
    use crate::ExecutionCtx;
    use crate::aggregate_fn::AggregateFnId;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::AggregateFnVTableExt;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::aggregate_fn::session::AggregateFnSessionExt;
    use crate::dtype::DType;
    use crate::scalar::Scalar;

    /// Serializable no-op aggregate that exists only to drive the protobuf round-trip tests.
    ///
    /// It has no options, a unit partial state, and serializes to empty metadata. Only the
    /// serde-related methods are meant to be exercised; `to_scalar` panics if called.
    #[derive(Clone, Debug)]
    struct TestAgg;

    impl AggregateFnVTable for TestAgg {
        type Options = EmptyOptions;
        type Partial = ();

        fn id(&self) -> AggregateFnId {
            AggregateFnId::new_ref("vortex.test.proto")
        }

        fn serialize(&self, _opts: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
            // Nothing to encode, but returning `Some` marks the function as serializable.
            Ok(Some(Vec::new()))
        }

        fn deserialize(
            &self,
            _bytes: &[u8],
            _sess: &VortexSession,
        ) -> VortexResult<Self::Options> {
            Ok(EmptyOptions)
        }

        fn return_dtype(&self, _opts: &Self::Options, input_dtype: &DType) -> Option<DType> {
            Some(input_dtype.clone())
        }

        fn partial_dtype(&self, _opts: &Self::Options, input_dtype: &DType) -> Option<DType> {
            // Same as `return_dtype`: the partial has the input's dtype.
            Some(input_dtype.clone())
        }

        fn empty_partial(
            &self,
            _opts: &Self::Options,
            _input_dtype: &DType,
        ) -> VortexResult<Self::Partial> {
            Ok(())
        }

        fn combine_partials(
            &self,
            _acc: &mut Self::Partial,
            _incoming: Scalar,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn to_scalar(&self, _acc: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }

        fn reset(&self, _acc: &mut Self::Partial) {}

        fn is_saturated(&self, _acc: &Self::Partial) -> bool {
            true
        }

        fn accumulate(
            &self,
            _acc: &mut Self::Partial,
            _columns: &Columnar,
            _exec: &mut ExecutionCtx,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn finalize(&self, partial_array: ArrayRef) -> VortexResult<ArrayRef> {
            Ok(partial_array)
        }
    }

    /// A registered aggregate function survives serialize -> encode -> decode -> deserialize.
    #[test]
    fn aggregate_fn_serde() {
        let session = VortexSession::empty().with::<AggregateFnSession>();
        session.aggregate_fns().register(TestAgg);

        let original = TestAgg.bind(EmptyOptions);

        let wire_bytes = original.serialize_proto().unwrap().encode_to_vec();
        let decoded = pb::AggregateFn::decode(wire_bytes.as_slice()).unwrap();
        let restored = AggregateFnRef::from_proto(&decoded, &session).unwrap();

        assert_eq!(restored, original);
    }

    /// With `allow_unknown`, an unregistered ID becomes a foreign function that re-serializes
    /// to the same ID and metadata bytes.
    #[test]
    fn unknown_aggregate_fn_id_allow_unknown() {
        let session = VortexSession::empty()
            .with::<AggregateFnSession>()
            .allow_unknown();

        let foreign_msg = pb::AggregateFn {
            id: String::from("vortex.test.foreign_aggregate"),
            metadata: Some(vec![7, 8, 9]),
        };

        let foreign_fn = AggregateFnRef::from_proto(&foreign_msg, &session).unwrap();
        assert_eq!(foreign_fn.id().as_ref(), "vortex.test.foreign_aggregate");

        let reencoded = foreign_fn.serialize_proto().unwrap();
        assert_eq!(reencoded.id, foreign_msg.id);
        assert_eq!(reencoded.metadata(), foreign_msg.metadata());
    }
}