sedona_expr/
aggregate_udf.rs1use std::{any::Any, fmt::Debug, sync::Arc};
18
19use arrow_schema::{DataType, FieldRef};
20use datafusion_common::{not_impl_err, Result};
21use datafusion_expr::{
22 function::{AccumulatorArgs, StateFieldsArgs},
23 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
24};
25use sedona_common::sedona_internal_err;
26use sedona_schema::datatypes::SedonaType;
27
28use sedona_schema::matchers::ArgMatcher;
29
30pub type SedonaAccumulatorRef = Arc<dyn SedonaAccumulator + Send + Sync>;
31
32#[derive(Debug, Clone)]
37pub struct SedonaAggregateUDF {
38 name: String,
39 signature: Signature,
40 kernels: Vec<SedonaAccumulatorRef>,
41 documentation: Option<Documentation>,
42}
43
44impl PartialEq for SedonaAggregateUDF {
45 fn eq(&self, other: &Self) -> bool {
46 self.name == other.name
47 }
48}
49
50impl Eq for SedonaAggregateUDF {}
51
52impl std::hash::Hash for SedonaAggregateUDF {
53 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
54 self.name.hash(state);
55 }
56}
57
58impl SedonaAggregateUDF {
59 pub fn new(
61 name: &str,
62 kernels: Vec<SedonaAccumulatorRef>,
63 volatility: Volatility,
64 documentation: Option<Documentation>,
65 ) -> Self {
66 let signature = Signature::user_defined(volatility);
67 Self {
68 name: name.to_string(),
69 signature,
70 kernels,
71 documentation,
72 }
73 }
74
75 pub fn new_stub(
83 name: &str,
84 arg_matcher: ArgMatcher,
85 volatility: Volatility,
86 documentation: Option<Documentation>,
87 ) -> Self {
88 let stub_kernel = StubAccumulator::new(name.to_string(), arg_matcher);
89 Self::new(name, vec![Arc::new(stub_kernel)], volatility, documentation)
90 }
91
92 pub fn add_kernel(&mut self, kernel: SedonaAccumulatorRef) {
97 self.kernels.push(kernel);
98 }
99
100 pub fn kernels(&self) -> &[SedonaAccumulatorRef] {
102 &self.kernels
103 }
104
105 fn dispatch_impl(&self, args: &[SedonaType]) -> Result<(&dyn SedonaAccumulator, SedonaType)> {
106 for kernel in self.kernels.iter().rev() {
108 if let Some(return_type) = kernel.return_type(args)? {
109 return Ok((kernel.as_ref(), return_type));
110 }
111 }
112
113 not_impl_err!("{}({:?}): No kernel matching arguments", self.name, args)
114 }
115}
116
117impl AggregateUDFImpl for SedonaAggregateUDF {
118 fn as_any(&self) -> &dyn Any {
119 self
120 }
121
122 fn name(&self) -> &str {
123 &self.name
124 }
125
126 fn signature(&self) -> &Signature {
127 &self.signature
128 }
129
130 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
131 Ok(arg_types.into())
132 }
133
134 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
135 let arg_types = args
136 .input_fields
137 .iter()
138 .map(|field| SedonaType::from_storage_field(field))
139 .collect::<Result<Vec<_>>>()?;
140 let (accumulator, _) = self.dispatch_impl(&arg_types)?;
141 accumulator.state_fields(&arg_types)
142 }
143
144 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
145 let arg_types = arg_fields
146 .iter()
147 .map(|field| SedonaType::from_storage_field(field))
148 .collect::<Result<Vec<_>>>()?;
149 let (_, out_type) = self.dispatch_impl(&arg_types)?;
150 Ok(Arc::new(out_type.to_storage_field("", true)?))
151 }
152
153 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
154 sedona_internal_err!("return_type() should not be called (use return_field())")
155 }
156
157 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
158 let arg_fields = acc_args
159 .exprs
160 .iter()
161 .map(|expr| expr.return_field(acc_args.schema))
162 .collect::<Result<Vec<_>>>()?;
163 let arg_types = arg_fields
164 .iter()
165 .map(|field| SedonaType::from_storage_field(field))
166 .collect::<Result<Vec<_>>>()?;
167 let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
168 accumulator.accumulator(&arg_types, &output_type)
169 }
170
171 fn documentation(&self) -> Option<&Documentation> {
172 self.documentation.as_ref()
173 }
174}
175
176pub trait SedonaAccumulator: Debug {
177 fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>>;
179
180 fn accumulator(
188 &self,
189 args: &[SedonaType],
190 output_type: &SedonaType,
191 ) -> Result<Box<dyn Accumulator>>;
192
193 fn state_fields(&self, args: &[SedonaType]) -> Result<Vec<FieldRef>>;
195}
196
197#[derive(Debug)]
198struct StubAccumulator {
199 name: String,
200 matcher: ArgMatcher,
201}
202
203impl StubAccumulator {
204 fn new(name: String, matcher: ArgMatcher) -> Self {
205 Self { name, matcher }
206 }
207}
208
209impl SedonaAccumulator for StubAccumulator {
210 fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
211 self.matcher.match_args(args)
212 }
213
214 fn accumulator(
215 &self,
216 args: &[SedonaType],
217 _output_type: &SedonaType,
218 ) -> Result<Box<dyn Accumulator>> {
219 not_impl_err!(
220 "Implementation for {}({args:?}) was not registered",
221 self.name
222 )
223 }
224
225 fn state_fields(&self, _args: &[SedonaType]) -> Result<Vec<FieldRef>> {
226 Ok(vec![])
227 }
228}
229
#[cfg(test)]
mod test {
    use sedona_testing::testers::AggregateUdfTester;

    use crate::aggregate_udf::SedonaAggregateUDF;

    use super::*;

    /// A UDF without kernels must fail dispatch with a descriptive error.
    #[test]
    fn udaf_empty() -> Result<()> {
        const NO_KERNEL: &str = "empty([]): No kernel matching arguments";

        // UDF with no implementations registered
        let empty_udf = SedonaAggregateUDF::new("empty", vec![], Volatility::Immutable, None);
        assert_eq!(empty_udf.name(), "empty");

        let first_err = empty_udf.return_field(&[]).unwrap_err();
        assert_eq!(first_err.message(), NO_KERNEL);
        assert!(empty_udf.kernels().is_empty());
        assert_eq!(empty_udf.coerce_types(&[])?, vec![]);

        // Asking again gives the same error
        let second_err = empty_udf.return_field(&[]).unwrap_err();
        assert_eq!(second_err.message(), NO_KERNEL);

        Ok(())
    }

    /// A stub resolves its return type but refuses to aggregate.
    #[test]
    fn stub() {
        let boolean = SedonaType::Arrow(DataType::Boolean);
        let stub_udf = SedonaAggregateUDF::new_stub(
            "stubby",
            ArgMatcher::new(vec![], boolean.clone()),
            Volatility::Immutable,
            None,
        );

        // Matching (empty) arguments: the return type resolves through the matcher...
        let matching = AggregateUdfTester::new(stub_udf.clone().into(), vec![]);
        assert_eq!(matching.return_type().unwrap(), boolean);

        // ...but there is no implementation behind it.
        let agg_err = matching.aggregate(&vec![]).unwrap_err();
        assert_eq!(
            agg_err.message(),
            "Implementation for stubby([]) was not registered"
        );

        // Non-matching arguments: dispatch itself fails.
        let mismatched = AggregateUdfTester::new(
            stub_udf.clone().into(),
            vec![SedonaType::Arrow(DataType::Binary)],
        );
        let dispatch_err = mismatched.return_type().unwrap_err();
        assert_eq!(
            dispatch_err.message(),
            "stubby([Arrow(Binary)]): No kernel matching arguments"
        );
    }
}