Skip to main content

vortex_array/aggregate_fn/
proto.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_error::vortex_err;
7use vortex_proto::expr as pb;
8use vortex_session::VortexSession;
9
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::AggregateFnRef;
12use crate::aggregate_fn::new_foreign_aggregate_fn;
13use crate::aggregate_fn::session::AggregateFnSessionExt;
14
15impl AggregateFnRef {
16    /// Serialize this aggregate function to its protobuf representation.
17    ///
18    /// Note: the serialization format is not stable and may change between versions.
19    pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
20        let metadata = self
21            .options()
22            .serialize()?
23            .ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;
24
25        Ok(pb::AggregateFn {
26            id: self.id().to_string(),
27            metadata: Some(metadata),
28        })
29    }
30
31    /// Deserialize an aggregate function from its protobuf representation.
32    ///
33    /// Looks up the aggregate function plugin by ID in the session's registry
34    /// and delegates deserialization to it.
35    ///
36    /// Note: the serialization format is not stable and may change between versions.
37    pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
38        #[expect(clippy::disallowed_methods, reason = "interning a dynamic id")]
39        let agg_fn_id: AggregateFnId = AggregateFnId::new(proto.id.as_str());
40        let agg_fn = if let Some(plugin) = session.aggregate_fns().find_plugin(&agg_fn_id) {
41            plugin.deserialize(proto.metadata(), session)?
42        } else if session.allows_unknown() {
43            new_foreign_aggregate_fn(agg_fn_id, proto.metadata().to_vec())
44        } else {
45            return Err(vortex_err!("unknown aggregate function id: {}", proto.id));
46        };
47
48        if agg_fn.id() != agg_fn_id {
49            vortex_bail!(
50                "Aggregate function ID mismatch: expected {}, got {}",
51                agg_fn_id,
52                agg_fn.id()
53            );
54        }
55
56        Ok(agg_fn)
57    }
58}
59
#[cfg(test)]
mod tests {
    use prost::Message;
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_proto::expr as pb;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Columnar;
    use crate::ExecutionCtx;
    use crate::aggregate_fn::AggregateFnId;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::AggregateFnVTableExt;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::NumericalAggregateOpts;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::aggregate_fn::session::AggregateFnSessionExt;
    use crate::dtype::DType;
    use crate::scalar::Scalar;

    /// A minimal serializable aggregate function used solely to exercise the serde round-trip.
    ///
    /// Its computational methods are stubs: the ones producing a scalar panic, the rest are
    /// no-ops. Only `id`, `serialize` and `deserialize` matter to the tests below.
    #[derive(Clone, Debug)]
    struct TestAgg;

    impl AggregateFnVTable for TestAgg {
        type Options = EmptyOptions;
        type Partial = ();

        #[expect(clippy::disallowed_methods, reason = "test-only id")]
        fn id(&self) -> AggregateFnId {
            AggregateFnId::new("vortex.test.proto")
        }

        /// Always serializable: the options carry no data, so the payload is empty.
        fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
            Ok(Some(vec![]))
        }

        /// Ignores the metadata bytes and returns the (only) options value.
        fn deserialize(
            &self,
            _metadata: &[u8],
            _session: &VortexSession,
        ) -> VortexResult<Self::Options> {
            Ok(EmptyOptions)
        }

        fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            Some(input_dtype.clone())
        }

        fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            self.return_dtype(options, input_dtype)
        }

        fn empty_partial(
            &self,
            _options: &Self::Options,
            _input_dtype: &DType,
        ) -> VortexResult<Self::Partial> {
            Ok(())
        }

        fn combine_partials(
            &self,
            _partial: &mut Self::Partial,
            _other: Scalar,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn to_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }

        fn reset(&self, _partial: &mut Self::Partial) {}

        fn is_saturated(&self, _partial: &Self::Partial) -> bool {
            true
        }

        fn accumulate(
            &self,
            _state: &mut Self::Partial,
            _batch: &Columnar,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
            Ok(partials)
        }

        fn finalize_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }
    }

    /// A registered aggregate function must survive an encode/decode round-trip through the
    /// protobuf wire format and compare equal to the original.
    #[test]
    fn aggregate_fn_serde() -> VortexResult<()> {
        let session = crate::array_session();
        session.aggregate_fns().register(TestAgg);

        let agg_fn = TestAgg.bind(EmptyOptions);

        let serialized = agg_fn.serialize_proto()?;
        let buf = serialized.encode_to_vec();
        let deserialized_proto = pb::AggregateFn::decode(buf.as_slice())?;
        let deserialized = AggregateFnRef::from_proto(&deserialized_proto, &session)?;

        assert_eq!(deserialized, agg_fn);
        Ok(())
    }

    /// The `skip_nans` option must survive a protobuf serialize/deserialize round-trip for the
    /// numeric aggregates, including the non-default NaN-including configuration.
    #[rstest]
    #[case(NumericalAggregateOpts::skip_nans())]
    #[case(NumericalAggregateOpts::include_nans())]
    fn numeric_aggregate_options_round_trip(
        #[case] options: NumericalAggregateOpts,
    ) -> VortexResult<()> {
        // NOTE(review): `Sum` is not registered here, so this presumably relies on
        // `array_session()` already knowing the built-in aggregates — confirm.
        let session = crate::array_session();
        let agg_fn = Sum.bind(options);

        let proto = agg_fn.serialize_proto()?;
        let buf = proto.encode_to_vec();
        let decoded = pb::AggregateFn::decode(buf.as_slice())?;
        let round_tripped = AggregateFnRef::from_proto(&decoded, &session)?;

        assert_eq!(round_tripped, agg_fn);
        Ok(())
    }

    /// With `allow_unknown` enabled, an ID that has no registered plugin is still accepted,
    /// and re-serializing the result reproduces the original ID and metadata bytes.
    #[test]
    fn unknown_aggregate_fn_id_allow_unknown() -> VortexResult<()> {
        let session = VortexSession::empty()
            .with::<AggregateFnSession>()
            .allow_unknown();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let agg_fn = AggregateFnRef::from_proto(&proto, &session)?;
        assert_eq!(agg_fn.id().as_ref(), "vortex.test.foreign_aggregate");

        let roundtrip = agg_fn.serialize_proto()?;
        assert_eq!(roundtrip.id, proto.id);
        assert_eq!(roundtrip.metadata(), proto.metadata());
        Ok(())
    }
}