// sedona_expr/aggregate_udf.rs
use std::{any::Any, fmt::Debug, sync::Arc};
18
19use arrow_schema::{DataType, FieldRef};
20use datafusion_common::{not_impl_err, Result};
21use datafusion_expr::{
22 function::{AccumulatorArgs, StateFieldsArgs},
23 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, Volatility,
24};
25use sedona_common::sedona_internal_err;
26use sedona_schema::datatypes::SedonaType;
27
28pub type SedonaAccumulatorRef = Arc<dyn SedonaAccumulator>;
30
31pub trait IntoSedonaAccumulatorRefs {
33 fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef>;
34}
35
36impl IntoSedonaAccumulatorRefs for SedonaAccumulatorRef {
37 fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef> {
38 vec![self]
39 }
40}
41
42impl IntoSedonaAccumulatorRefs for Vec<SedonaAccumulatorRef> {
43 fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef> {
44 self
45 }
46}
47
48impl<T: SedonaAccumulator + 'static> IntoSedonaAccumulatorRefs for T {
49 fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef> {
50 vec![Arc::new(self)]
51 }
52}
53
54impl<T: SedonaAccumulator + 'static> IntoSedonaAccumulatorRefs for Vec<Arc<T>> {
55 fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef> {
56 self.into_iter()
57 .map(|item| item as SedonaAccumulatorRef)
58 .collect()
59 }
60}
61
62#[derive(Debug, Clone)]
67pub struct SedonaAggregateUDF {
68 name: String,
69 signature: Signature,
70 kernels: Vec<SedonaAccumulatorRef>,
71}
72
73impl PartialEq for SedonaAggregateUDF {
74 fn eq(&self, other: &Self) -> bool {
75 self.name == other.name
76 }
77}
78
79impl Eq for SedonaAggregateUDF {}
80
81impl std::hash::Hash for SedonaAggregateUDF {
82 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
83 self.name.hash(state);
84 }
85}
86
87impl SedonaAggregateUDF {
88 pub fn new(
90 name: &str,
91 kernels: impl IntoSedonaAccumulatorRefs,
92 volatility: Volatility,
93 ) -> Self {
94 let signature = Signature::user_defined(volatility);
95 Self {
96 name: name.to_string(),
97 signature,
98 kernels: kernels.into_sedona_accumulator_refs(),
99 }
100 }
101
102 pub fn from_impl(name: &str, kernels: impl IntoSedonaAccumulatorRefs) -> Self {
104 Self::new(name, kernels, Volatility::Immutable)
105 }
106
107 pub fn add_kernel(&mut self, kernels: impl IntoSedonaAccumulatorRefs) {
112 for kernel in kernels.into_sedona_accumulator_refs() {
113 self.kernels.push(kernel);
114 }
115 }
116
117 pub fn kernels(&self) -> &[SedonaAccumulatorRef] {
119 &self.kernels
120 }
121
122 fn accumulator_arg_types(args: &AccumulatorArgs) -> Result<Vec<SedonaType>> {
123 let arg_fields = args
124 .exprs
125 .iter()
126 .map(|expr| expr.return_field(args.schema))
127 .collect::<Result<Vec<_>>>()?;
128 arg_fields
129 .iter()
130 .map(|field| SedonaType::from_storage_field(field))
131 .collect()
132 }
133
134 fn dispatch_impl(&self, args: &[SedonaType]) -> Result<(&dyn SedonaAccumulator, SedonaType)> {
135 for kernel in self.kernels.iter().rev() {
137 if let Some(return_type) = kernel.return_type(args)? {
138 return Ok((kernel.as_ref(), return_type));
139 }
140 }
141
142 let args_display = args
143 .iter()
144 .map(|arg| arg.logical_type_name())
145 .collect::<Vec<_>>()
146 .join(", ");
147
148 not_impl_err!(
149 "{}({args_display}): No kernel matching arguments",
150 self.name
151 )
152 }
153}
154
155impl AggregateUDFImpl for SedonaAggregateUDF {
156 fn as_any(&self) -> &dyn Any {
157 self
158 }
159
160 fn name(&self) -> &str {
161 &self.name
162 }
163
164 fn signature(&self) -> &Signature {
165 &self.signature
166 }
167
168 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
169 Ok(arg_types.into())
170 }
171
172 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
173 let arg_types = args
174 .input_fields
175 .iter()
176 .map(|field| SedonaType::from_storage_field(field))
177 .collect::<Result<Vec<_>>>()?;
178 let (accumulator, _) = self.dispatch_impl(&arg_types)?;
179 accumulator.state_fields(&arg_types)
180 }
181
182 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
183 let arg_types = arg_fields
184 .iter()
185 .map(|field| SedonaType::from_storage_field(field))
186 .collect::<Result<Vec<_>>>()?;
187 let (_, out_type) = self.dispatch_impl(&arg_types)?;
188 Ok(Arc::new(out_type.to_storage_field("", true)?))
189 }
190
191 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
192 sedona_internal_err!("return_type() should not be called (use return_field())")
193 }
194
195 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
196 if let Ok(arg_types) = Self::accumulator_arg_types(&args) {
197 if let Ok((accumulator, _)) = self.dispatch_impl(&arg_types) {
198 return accumulator.groups_accumulator_supported(&arg_types);
199 }
200 }
201
202 false
203 }
204
205 fn create_groups_accumulator(
206 &self,
207 args: AccumulatorArgs,
208 ) -> Result<Box<dyn GroupsAccumulator>> {
209 let arg_types = Self::accumulator_arg_types(&args)?;
210 let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
211 accumulator.groups_accumulator(&arg_types, &output_type)
212 }
213
214 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
215 let arg_types = Self::accumulator_arg_types(&acc_args)?;
216 let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
217 accumulator.accumulator(&arg_types, &output_type)
218 }
219
220 fn documentation(&self) -> Option<&Documentation> {
221 None
222 }
223}
224
225pub trait SedonaAccumulator: Debug + Send + Sync {
226 fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>>;
228
229 fn accumulator(
237 &self,
238 args: &[SedonaType],
239 output_type: &SedonaType,
240 ) -> Result<Box<dyn Accumulator>>;
241
242 fn groups_accumulator_supported(&self, _args: &[SedonaType]) -> bool {
244 false
245 }
246
247 fn groups_accumulator(
253 &self,
254 _args: &[SedonaType],
255 _output_type: &SedonaType,
256 ) -> Result<Box<dyn GroupsAccumulator>> {
257 sedona_internal_err!("groups_accumulator not supported for {self:?}")
258 }
259
260 fn state_fields(&self, args: &[SedonaType]) -> Result<Vec<FieldRef>>;
262}
263
#[cfg(test)]
mod test {
    use crate::aggregate_udf::SedonaAggregateUDF;

    use super::*;

    /// A UDF with no kernels keeps its name, passes types through
    /// `coerce_types`, and reports a "no kernel" error from `return_field`.
    #[test]
    fn udaf_empty() -> Result<()> {
        let no_kernels: Vec<SedonaAccumulatorRef> = Vec::new();
        let udf = SedonaAggregateUDF::new("empty", no_kernels, Volatility::Immutable);

        assert_eq!(udf.name(), "empty");

        let first_failure = udf.return_field(&[]).unwrap_err();
        assert_eq!(
            first_failure.message(),
            "empty(): No kernel matching arguments"
        );

        assert!(udf.kernels().is_empty());
        assert_eq!(udf.coerce_types(&[])?, vec![]);

        // Dispatch is repeated to confirm the failure is stable across calls.
        let second_failure = udf.return_field(&[]).unwrap_err();
        assert_eq!(
            second_failure.message(),
            "empty(): No kernel matching arguments"
        );

        Ok(())
    }
}
289}