vortex_array/aggregate_fn/
proto.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_error::vortex_err;
7use vortex_proto::expr as pb;
8use vortex_session::VortexSession;
9
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::AggregateFnRef;
12use crate::aggregate_fn::new_foreign_aggregate_fn;
13use crate::aggregate_fn::session::AggregateFnSessionExt;
14
15impl AggregateFnRef {
16 pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
20 let metadata = self
21 .options()
22 .serialize()?
23 .ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;
24
25 Ok(pb::AggregateFn {
26 id: self.id().to_string(),
27 metadata: Some(metadata),
28 })
29 }
30
31 pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
38 #[expect(clippy::disallowed_methods, reason = "interning a dynamic id")]
39 let agg_fn_id: AggregateFnId = AggregateFnId::new(proto.id.as_str());
40 let agg_fn = if let Some(plugin) = session.aggregate_fns().find_plugin(&agg_fn_id) {
41 plugin.deserialize(proto.metadata(), session)?
42 } else if session.allows_unknown() {
43 new_foreign_aggregate_fn(agg_fn_id, proto.metadata().to_vec())
44 } else {
45 return Err(vortex_err!("unknown aggregate function id: {}", proto.id));
46 };
47
48 if agg_fn.id() != agg_fn_id {
49 vortex_bail!(
50 "Aggregate function ID mismatch: expected {}, got {}",
51 agg_fn_id,
52 agg_fn.id()
53 );
54 }
55
56 Ok(agg_fn)
57 }
58}
59
#[cfg(test)]
mod tests {
    use prost::Message;
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_proto::expr as pb;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Columnar;
    use crate::ExecutionCtx;
    use crate::aggregate_fn::AggregateFnId;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::AggregateFnVTableExt;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::NumericalAggregateOpts;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::aggregate_fn::session::AggregateFnSessionExt;
    use crate::dtype::DType;
    use crate::scalar::Scalar;

    /// Minimal aggregate function used only to exercise proto serialization.
    ///
    /// Its options are [`EmptyOptions`], which serialize to an empty metadata payload.
    /// `to_scalar` and `finalize_scalar` panic because these tests never execute the
    /// aggregate.
    #[derive(Clone, Debug)]
    struct TestAgg;

    impl AggregateFnVTable for TestAgg {
        type Options = EmptyOptions;
        type Partial = ();

        #[expect(clippy::disallowed_methods, reason = "test-only id")]
        fn id(&self) -> AggregateFnId {
            AggregateFnId::new("vortex.test.proto")
        }

        fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
            Ok(Some(vec![]))
        }

        fn deserialize(
            &self,
            _metadata: &[u8],
            _session: &VortexSession,
        ) -> VortexResult<Self::Options> {
            Ok(EmptyOptions)
        }

        fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            Some(input_dtype.clone())
        }

        fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            self.return_dtype(options, input_dtype)
        }

        fn empty_partial(
            &self,
            _options: &Self::Options,
            _input_dtype: &DType,
        ) -> VortexResult<Self::Partial> {
            Ok(())
        }

        fn combine_partials(
            &self,
            _partial: &mut Self::Partial,
            _other: Scalar,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn to_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }

        fn reset(&self, _partial: &mut Self::Partial) {}

        fn is_saturated(&self, _partial: &Self::Partial) -> bool {
            true
        }

        fn accumulate(
            &self,
            _state: &mut Self::Partial,
            _batch: &Columnar,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
            Ok(partials)
        }

        fn finalize_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }
    }

    /// Serializes `agg_fn` to protobuf, encodes and decodes the message bytes, and
    /// deserializes the decoded message against `session`.
    fn round_trip(
        agg_fn: &AggregateFnRef,
        session: &VortexSession,
    ) -> VortexResult<AggregateFnRef> {
        let proto = agg_fn.serialize_proto()?;
        let buf = proto.encode_to_vec();
        let decoded = pb::AggregateFn::decode(buf.as_slice())?;
        AggregateFnRef::from_proto(&decoded, session)
    }

    #[test]
    fn aggregate_fn_serde() -> VortexResult<()> {
        let session = crate::array_session();
        session.aggregate_fns().register(TestAgg);

        let agg_fn = TestAgg.bind(EmptyOptions);
        let deserialized = round_trip(&agg_fn, &session)?;

        assert_eq!(deserialized, agg_fn);
        Ok(())
    }

    #[rstest]
    #[case(NumericalAggregateOpts::skip_nans())]
    #[case(NumericalAggregateOpts::include_nans())]
    fn numeric_aggregate_options_round_trip(
        #[case] options: NumericalAggregateOpts,
    ) -> VortexResult<()> {
        let session = crate::array_session();
        let agg_fn = Sum.bind(options);

        let round_tripped = round_trip(&agg_fn, &session)?;

        assert_eq!(round_tripped, agg_fn);
        Ok(())
    }

    #[test]
    fn unknown_aggregate_fn_id_allow_unknown() -> VortexResult<()> {
        let session = VortexSession::empty()
            .with::<AggregateFnSession>()
            .allow_unknown();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let agg_fn = AggregateFnRef::from_proto(&proto, &session)?;
        assert_eq!(agg_fn.id().as_ref(), "vortex.test.foreign_aggregate");

        // A foreign function must re-serialize to the same id and metadata it was read from.
        let roundtrip = agg_fn.serialize_proto()?;
        assert_eq!(roundtrip.id, proto.id);
        assert_eq!(roundtrip.metadata(), proto.metadata());
        Ok(())
    }

    #[test]
    fn unknown_aggregate_fn_id_rejected_by_default() {
        // Same session as above but without `allow_unknown()`.
        let session = VortexSession::empty().with::<AggregateFnSession>();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let err = AggregateFnRef::from_proto(&proto, &session)
            .expect_err("unknown aggregate function ids must be rejected unless allowed");
        assert!(
            err.to_string()
                .contains("unknown aggregate function id: vortex.test.foreign_aggregate"),
            "unexpected error: {err}"
        );
    }
}