vortex_array/aggregate_fn/
proto.rs1use std::sync::Arc;
5
6use arcref::ArcRef;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10use vortex_proto::expr as pb;
11use vortex_session::VortexSession;
12
13use crate::aggregate_fn::AggregateFnId;
14use crate::aggregate_fn::AggregateFnRef;
15use crate::aggregate_fn::new_foreign_aggregate_fn;
16use crate::aggregate_fn::session::AggregateFnSessionExt;
17
18impl AggregateFnRef {
19 pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
23 let metadata = self
24 .options()
25 .serialize()?
26 .ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;
27
28 Ok(pb::AggregateFn {
29 id: self.id().to_string(),
30 metadata: Some(metadata),
31 })
32 }
33
34 pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
41 let agg_fn_id: AggregateFnId = ArcRef::new_arc(Arc::from(proto.id.as_str()));
42 let agg_fn = if let Some(plugin) = session.aggregate_fns().registry().find(&agg_fn_id) {
43 plugin.deserialize(proto.metadata(), session)?
44 } else if session.allows_unknown() {
45 new_foreign_aggregate_fn(agg_fn_id.clone(), proto.metadata().to_vec())
46 } else {
47 return Err(vortex_err!("unknown aggregate function id: {}", proto.id));
48 };
49
50 if agg_fn.id() != agg_fn_id {
51 vortex_bail!(
52 "Aggregate function ID mismatch: expected {}, got {}",
53 agg_fn_id,
54 agg_fn.id()
55 );
56 }
57
58 Ok(agg_fn)
59 }
60}
61
#[cfg(test)]
mod tests {
    use prost::Message;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_proto::expr as pb;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Columnar;
    use crate::ExecutionCtx;
    use crate::aggregate_fn::AggregateFnId;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::AggregateFnVTableExt;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::aggregate_fn::session::AggregateFnSessionExt;
    use crate::dtype::DType;
    use crate::scalar::Scalar;

    /// Minimal aggregate function used only to exercise proto (de)serialization.
    ///
    /// Its options are [`EmptyOptions`], which serialize to an empty byte vector. The
    /// aggregation methods are inert stubs, and `to_scalar` panics if it is ever reached.
    #[derive(Clone, Debug)]
    struct TestAgg;

    impl AggregateFnVTable for TestAgg {
        type Options = EmptyOptions;
        type Partial = ();

        fn id(&self) -> AggregateFnId {
            AggregateFnId::new_ref("vortex.test.proto")
        }

        fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
            Ok(Some(vec![]))
        }

        fn deserialize(
            &self,
            _metadata: &[u8],
            _session: &VortexSession,
        ) -> VortexResult<Self::Options> {
            Ok(EmptyOptions)
        }

        fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            Some(input_dtype.clone())
        }

        fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
            self.return_dtype(options, input_dtype)
        }

        fn empty_partial(
            &self,
            _options: &Self::Options,
            _input_dtype: &DType,
        ) -> VortexResult<Self::Partial> {
            Ok(())
        }

        fn combine_partials(
            &self,
            _partial: &mut Self::Partial,
            _other: Scalar,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn to_scalar(&self, _partial: &Self::Partial) -> VortexResult<Scalar> {
            vortex_panic!("TestAgg is for serde tests only");
        }

        fn reset(&self, _partial: &mut Self::Partial) {}

        fn is_saturated(&self, _partial: &Self::Partial) -> bool {
            true
        }

        fn accumulate(
            &self,
            _state: &mut Self::Partial,
            _batch: &Columnar,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<()> {
            Ok(())
        }

        fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
            Ok(partials)
        }
    }

    /// A registered aggregate function survives a full protobuf encode/decode round trip.
    #[test]
    fn aggregate_fn_serde() {
        let session = VortexSession::empty().with::<AggregateFnSession>();
        session.aggregate_fns().register(TestAgg);

        let agg_fn = TestAgg.bind(EmptyOptions);

        let serialized = agg_fn.serialize_proto().unwrap();
        let buf = serialized.encode_to_vec();
        let deserialized_proto = pb::AggregateFn::decode(buf.as_slice()).unwrap();
        let deserialized = AggregateFnRef::from_proto(&deserialized_proto, &session).unwrap();

        assert_eq!(deserialized, agg_fn);
    }

    /// With `allow_unknown`, an unregistered ID yields a foreign aggregate function that
    /// keeps the ID and raw metadata bytes when serialized again.
    #[test]
    fn unknown_aggregate_fn_id_allow_unknown() {
        let session = VortexSession::empty()
            .with::<AggregateFnSession>()
            .allow_unknown();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let agg_fn = AggregateFnRef::from_proto(&proto, &session).unwrap();
        assert_eq!(agg_fn.id().as_ref(), "vortex.test.foreign_aggregate");

        let roundtrip = agg_fn.serialize_proto().unwrap();
        assert_eq!(roundtrip.id, proto.id);
        assert_eq!(roundtrip.metadata(), proto.metadata());
    }

    /// Without `allow_unknown`, an unregistered ID is rejected rather than turned into a
    /// foreign aggregate function.
    #[test]
    fn unknown_aggregate_fn_id_rejected_by_default() {
        let session = VortexSession::empty().with::<AggregateFnSession>();

        let proto = pb::AggregateFn {
            id: "vortex.test.foreign_aggregate".to_string(),
            metadata: Some(vec![7, 8, 9]),
        };

        let err = AggregateFnRef::from_proto(&proto, &session).unwrap_err();
        assert!(
            err.to_string().contains("unknown aggregate function id"),
            "unexpected error: {err}"
        );
    }
}