Skip to main content

vortex_array/aggregate_fn/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_error::vortex_err;
7use vortex_proto::expr as pb;
8use vortex_session::VortexSession;
9
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::AggregateFnRef;
12use crate::aggregate_fn::new_foreign_aggregate_fn;
13use crate::aggregate_fn::session::AggregateFnSessionExt;
14
15impl AggregateFnRef {
16    /// Serialize this aggregate function to its protobuf representation.
17    ///
18    /// Note: the serialization format is not stable and may change between versions.
19    pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
20        let metadata = self
21            .options()
22            .serialize()?
23            .ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;
24
25        Ok(pb::AggregateFn {
26            id: self.id().to_string(),
27            metadata: Some(metadata),
28        })
29    }
30
31    /// Deserialize an aggregate function from its protobuf representation.
32    ///
33    /// Looks up the aggregate function plugin by ID in the session's registry
34    /// and delegates deserialization to it.
35    ///
36    /// Note: the serialization format is not stable and may change between versions.
37    pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
38        let agg_fn_id: AggregateFnId = AggregateFnId::new(proto.id.as_str());
39        let agg_fn = if let Some(plugin) = session.aggregate_fns().registry().find(&agg_fn_id) {
40            plugin.deserialize(proto.metadata(), session)?
41        } else if session.allows_unknown() {
42            new_foreign_aggregate_fn(agg_fn_id, proto.metadata().to_vec())
43        } else {
44            return Err(vortex_err!("unknown aggregate function id: {}", proto.id));
45        };
46
47        if agg_fn.id() != agg_fn_id {
48            vortex_bail!(
49                "Aggregate function ID mismatch: expected {}, got {}",
50                agg_fn_id,
51                agg_fn.id()
52            );
53        }
54
55        Ok(agg_fn)
56    }
57}
58
#[cfg(test)]
mod tests {
    use prost::Message;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_proto::expr as pb;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Columnar;
    use crate::ExecutionCtx;
    use crate::aggregate_fn::AggregateFnId;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::AggregateFnVTableExt;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::aggregate_fn::session::AggregateFnSessionExt;
    use crate::dtype::DType;
    use crate::scalar::Scalar;

    /// A minimal serializable aggregate function used solely to exercise the serde round-trip.
    ///
    /// Only the serde-related methods (`id`, `serialize`, `deserialize`) matter here; the
    /// scalar-producing methods panic because these tests never execute the aggregate.
    #[derive(Clone, Debug)]
    struct TestAgg;

    impl AggregateFnVTable for TestAgg {
        type Options = EmptyOptions;
        type Partial = ();

        fn id(&self) -> AggregateFnId {
            AggregateFnId::new("vortex.test.proto")
        }

        fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
            Ok(Some(vec![]))
        }

        fn deserialize(
            &self,
            _metadata: &[u8],
            _session: &VortexSession,
        ) -> VortexResult<Self::Options> {
            Ok(EmptyOptions)
        }

        fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            Some(input_dtype.clone())
        }

        fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            self.return_dtype(options, input_dtype)
        }

        fn empty_partial(
            &self,
            _options: &Self::Options,
            _input_dtype: &DType,
        ) -> VortexResult<Self::Partial> {
            Ok(())
        }

        fn combine_partials(
            &self,
            _partial: &mut Self::Partial,
            _other: Scalar,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn to_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }

        fn reset(&self, _partial: &mut Self::Partial) {}

        fn is_saturated(&self, _partial: &Self::Partial) -> bool {
            true
        }

        fn accumulate(
            &self,
            _state: &mut Self::Partial,
            _batch: &Columnar,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
            Ok(partials)
        }

        fn finalize_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }
    }

    /// A registered aggregate function survives serialize -> encode -> decode -> deserialize.
    #[test]
    fn aggregate_fn_serde() {
        let session = VortexSession::empty().with::<AggregateFnSession>();
        session.aggregate_fns().register(TestAgg);

        let agg_fn = TestAgg.bind(EmptyOptions);

        let serialized = agg_fn.serialize_proto().unwrap();
        let buf = serialized.encode_to_vec();
        let deserialized_proto = pb::AggregateFn::decode(buf.as_slice()).unwrap();
        let deserialized = AggregateFnRef::from_proto(&deserialized_proto, &session).unwrap();

        assert_eq!(deserialized, agg_fn);
    }

    /// An unregistered ID deserializes into a foreign aggregate function when the session
    /// allows unknown functions, and re-serializes with the original ID and metadata bytes.
    #[test]
    fn unknown_aggregate_fn_id_allow_unknown() {
        let session = VortexSession::empty()
            .with::<AggregateFnSession>()
            .allow_unknown();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let agg_fn = AggregateFnRef::from_proto(&proto, &session).unwrap();
        assert_eq!(agg_fn.id().as_ref(), "vortex.test.foreign_aggregate");

        let roundtrip = agg_fn.serialize_proto().unwrap();
        assert_eq!(roundtrip.id, proto.id);
        assert_eq!(roundtrip.metadata(), proto.metadata());
    }

    /// An unregistered ID is rejected when the session has not opted into unknown functions.
    #[test]
    fn unknown_aggregate_fn_id_rejected() {
        // Assumes sessions do not allow unknown functions unless `allow_unknown()` is called.
        let session = VortexSession::empty().with::<AggregateFnSession>();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let err = AggregateFnRef::from_proto(&proto, &session).unwrap_err();
        assert!(
            err.to_string().contains("vortex.test.foreign_aggregate"),
            "unexpected error: {err}"
        );
    }
}