vortex_array/aggregate_fn/
proto.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_error::vortex_err;
7use vortex_proto::expr as pb;
8use vortex_session::VortexSession;
9
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::AggregateFnRef;
12use crate::aggregate_fn::new_foreign_aggregate_fn;
13use crate::aggregate_fn::session::AggregateFnSessionExt;
14
15impl AggregateFnRef {
16 pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
20 let metadata = self
21 .options()
22 .serialize()?
23 .ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;
24
25 Ok(pb::AggregateFn {
26 id: self.id().to_string(),
27 metadata: Some(metadata),
28 })
29 }
30
31 pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
38 let agg_fn_id: AggregateFnId = AggregateFnId::new(proto.id.as_str());
39 let agg_fn = if let Some(plugin) = session.aggregate_fns().registry().find(&agg_fn_id) {
40 plugin.deserialize(proto.metadata(), session)?
41 } else if session.allows_unknown() {
42 new_foreign_aggregate_fn(agg_fn_id, proto.metadata().to_vec())
43 } else {
44 return Err(vortex_err!("unknown aggregate function id: {}", proto.id));
45 };
46
47 if agg_fn.id() != agg_fn_id {
48 vortex_bail!(
49 "Aggregate function ID mismatch: expected {}, got {}",
50 agg_fn_id,
51 agg_fn.id()
52 );
53 }
54
55 Ok(agg_fn)
56 }
57}
58
#[cfg(test)]
mod tests {
    use prost::Message;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_proto::expr as pb;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Columnar;
    use crate::ExecutionCtx;
    use crate::aggregate_fn::AggregateFnId;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::AggregateFnVTableExt;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::aggregate_fn::session::AggregateFnSessionExt;
    use crate::dtype::DType;
    use crate::scalar::Scalar;

    /// Minimal aggregate function used only to exercise proto (de)serialization.
    ///
    /// Its options serialize to empty bytes. The compute-related methods are no-op stubs, and
    /// the scalar-producing methods panic, because these tests never execute the aggregate.
    #[derive(Clone, Debug)]
    struct TestAgg;

    impl AggregateFnVTable for TestAgg {
        type Options = EmptyOptions;
        type Partial = ();

        fn id(&self) -> AggregateFnId {
            AggregateFnId::new("vortex.test.proto")
        }

        fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
            Ok(Some(vec![]))
        }

        fn deserialize(
            &self,
            _metadata: &[u8],
            _session: &VortexSession,
        ) -> VortexResult<Self::Options> {
            Ok(EmptyOptions)
        }

        fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            Some(input_dtype.clone())
        }

        fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            self.return_dtype(options, input_dtype)
        }

        fn empty_partial(
            &self,
            _options: &Self::Options,
            _input_dtype: &DType,
        ) -> VortexResult<Self::Partial> {
            Ok(())
        }

        fn combine_partials(
            &self,
            _partial: &mut Self::Partial,
            _other: Scalar,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn to_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }

        fn reset(&self, _partial: &mut Self::Partial) {}

        fn is_saturated(&self, _partial: &Self::Partial) -> bool {
            true
        }

        fn accumulate(
            &self,
            _state: &mut Self::Partial,
            _batch: &Columnar,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
            Ok(partials)
        }

        fn finalize_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }
    }

    /// A registered aggregate function survives a full round trip: serialize to proto, encode,
    /// decode, and deserialize through the session registry.
    #[test]
    fn aggregate_fn_serde() {
        let session = VortexSession::empty().with::<AggregateFnSession>();
        session.aggregate_fns().register(TestAgg);

        let agg_fn = TestAgg.bind(EmptyOptions);

        let serialized = agg_fn.serialize_proto().unwrap();
        let buf = serialized.encode_to_vec();
        let deserialized_proto = pb::AggregateFn::decode(buf.as_slice()).unwrap();
        let deserialized = AggregateFnRef::from_proto(&deserialized_proto, &session).unwrap();

        assert_eq!(deserialized, agg_fn);
    }

    /// With `allow_unknown`, an unregistered ID deserializes to a foreign aggregate function.
    /// That function serializes back to the same ID and metadata bytes.
    #[test]
    fn unknown_aggregate_fn_id_allow_unknown() {
        let session = VortexSession::empty()
            .with::<AggregateFnSession>()
            .allow_unknown();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let agg_fn = AggregateFnRef::from_proto(&proto, &session).unwrap();
        assert_eq!(agg_fn.id().as_ref(), "vortex.test.foreign_aggregate");

        let roundtrip = agg_fn.serialize_proto().unwrap();
        assert_eq!(roundtrip.id, proto.id);
        assert_eq!(roundtrip.metadata(), proto.metadata());
    }

    /// Without `allow_unknown`, deserializing an unregistered ID fails instead of producing a
    /// foreign aggregate function.
    #[test]
    fn unknown_aggregate_fn_id_rejected() {
        let session = VortexSession::empty().with::<AggregateFnSession>();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let err = AggregateFnRef::from_proto(&proto, &session).unwrap_err();
        assert!(
            err.to_string()
                .contains("unknown aggregate function id: vortex.test.foreign_aggregate"),
            "unexpected error: {err}"
        );
    }
}