Skip to main content

sedona_expr/
aggregate_udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::{any::Any, fmt::Debug, sync::Arc};
18
19use arrow_schema::{DataType, FieldRef};
20use datafusion_common::{not_impl_err, Result};
21use datafusion_expr::{
22    function::{AccumulatorArgs, StateFieldsArgs},
23    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, Volatility,
24};
25use sedona_common::sedona_internal_err;
26use sedona_schema::datatypes::SedonaType;
27
28/// Shorthand for a [SedonaAccumulator] reference
29pub type SedonaAccumulatorRef = Arc<dyn SedonaAccumulator>;
30
31/// Helper to resolve an iterable of accumulators
32pub trait IntoSedonaAccumulatorRefs {
33    fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef>;
34}
35
36impl IntoSedonaAccumulatorRefs for SedonaAccumulatorRef {
37    fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef> {
38        vec![self]
39    }
40}
41
42impl IntoSedonaAccumulatorRefs for Vec<SedonaAccumulatorRef> {
43    fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef> {
44        self
45    }
46}
47
48impl<T: SedonaAccumulator + 'static> IntoSedonaAccumulatorRefs for T {
49    fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef> {
50        vec![Arc::new(self)]
51    }
52}
53
54impl<T: SedonaAccumulator + 'static> IntoSedonaAccumulatorRefs for Vec<Arc<T>> {
55    fn into_sedona_accumulator_refs(self) -> Vec<SedonaAccumulatorRef> {
56        self.into_iter()
57            .map(|item| item as SedonaAccumulatorRef)
58            .collect()
59    }
60}
61
62/// Top-level aggregate user-defined function
63///
64/// This struct implements datafusion's AggregateUDFImpl and implements kernel dispatch
65/// such that implementations can be registered flexibly.
66#[derive(Debug, Clone)]
67pub struct SedonaAggregateUDF {
68    name: String,
69    signature: Signature,
70    kernels: Vec<SedonaAccumulatorRef>,
71}
72
73impl PartialEq for SedonaAggregateUDF {
74    fn eq(&self, other: &Self) -> bool {
75        self.name == other.name
76    }
77}
78
79impl Eq for SedonaAggregateUDF {}
80
81impl std::hash::Hash for SedonaAggregateUDF {
82    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
83        self.name.hash(state);
84    }
85}
86
87impl SedonaAggregateUDF {
88    /// Create a new SedonaAggregateUDF
89    pub fn new(
90        name: &str,
91        kernels: impl IntoSedonaAccumulatorRefs,
92        volatility: Volatility,
93    ) -> Self {
94        let signature = Signature::user_defined(volatility);
95        Self {
96            name: name.to_string(),
97            signature,
98            kernels: kernels.into_sedona_accumulator_refs(),
99        }
100    }
101
102    /// Create a new immutable SedonaAggregateUDF
103    pub fn from_impl(name: &str, kernels: impl IntoSedonaAccumulatorRefs) -> Self {
104        Self::new(name, kernels, Volatility::Immutable)
105    }
106
107    /// Add a new kernel to an Aggregate UDF
108    ///
109    /// Because kernels are resolved in reverse order, the new kernel will take
110    /// precedence over any previously added kernels that apply to the same types.
111    pub fn add_kernel(&mut self, kernels: impl IntoSedonaAccumulatorRefs) {
112        for kernel in kernels.into_sedona_accumulator_refs() {
113            self.kernels.push(kernel);
114        }
115    }
116
117    // List the current kernels
118    pub fn kernels(&self) -> &[SedonaAccumulatorRef] {
119        &self.kernels
120    }
121
122    fn accumulator_arg_types(args: &AccumulatorArgs) -> Result<Vec<SedonaType>> {
123        let arg_fields = args
124            .exprs
125            .iter()
126            .map(|expr| expr.return_field(args.schema))
127            .collect::<Result<Vec<_>>>()?;
128        arg_fields
129            .iter()
130            .map(|field| SedonaType::from_storage_field(field))
131            .collect()
132    }
133
134    fn dispatch_impl(&self, args: &[SedonaType]) -> Result<(&dyn SedonaAccumulator, SedonaType)> {
135        // Resolve kernels in reverse so that more recently added ones are resolved first
136        for kernel in self.kernels.iter().rev() {
137            if let Some(return_type) = kernel.return_type(args)? {
138                return Ok((kernel.as_ref(), return_type));
139            }
140        }
141
142        let args_display = args
143            .iter()
144            .map(|arg| arg.logical_type_name())
145            .collect::<Vec<_>>()
146            .join(", ");
147
148        not_impl_err!(
149            "{}({args_display}): No kernel matching arguments",
150            self.name
151        )
152    }
153}
154
155impl AggregateUDFImpl for SedonaAggregateUDF {
156    fn as_any(&self) -> &dyn Any {
157        self
158    }
159
160    fn name(&self) -> &str {
161        &self.name
162    }
163
164    fn signature(&self) -> &Signature {
165        &self.signature
166    }
167
168    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
169        Ok(arg_types.into())
170    }
171
172    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
173        let arg_types = args
174            .input_fields
175            .iter()
176            .map(|field| SedonaType::from_storage_field(field))
177            .collect::<Result<Vec<_>>>()?;
178        let (accumulator, _) = self.dispatch_impl(&arg_types)?;
179        accumulator.state_fields(&arg_types)
180    }
181
182    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
183        let arg_types = arg_fields
184            .iter()
185            .map(|field| SedonaType::from_storage_field(field))
186            .collect::<Result<Vec<_>>>()?;
187        let (_, out_type) = self.dispatch_impl(&arg_types)?;
188        Ok(Arc::new(out_type.to_storage_field("", true)?))
189    }
190
191    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
192        sedona_internal_err!("return_type() should not be called (use return_field())")
193    }
194
195    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
196        if let Ok(arg_types) = Self::accumulator_arg_types(&args) {
197            if let Ok((accumulator, _)) = self.dispatch_impl(&arg_types) {
198                return accumulator.groups_accumulator_supported(&arg_types);
199            }
200        }
201
202        false
203    }
204
205    fn create_groups_accumulator(
206        &self,
207        args: AccumulatorArgs,
208    ) -> Result<Box<dyn GroupsAccumulator>> {
209        let arg_types = Self::accumulator_arg_types(&args)?;
210        let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
211        accumulator.groups_accumulator(&arg_types, &output_type)
212    }
213
214    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
215        let arg_types = Self::accumulator_arg_types(&acc_args)?;
216        let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
217        accumulator.accumulator(&arg_types, &output_type)
218    }
219
220    fn documentation(&self) -> Option<&Documentation> {
221        None
222    }
223}
224
225pub trait SedonaAccumulator: Debug + Send + Sync {
226    /// Given input data types, calculate an output data type
227    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>>;
228
229    /// Given input data types and previously-calculated output data type,
230    /// resolve an [Accumulator]
231    ///
232    /// The Accumulator provides the underlying DataFusion implementation.
233    /// The SedonaAccumulator does not perform any wrapping or unwrapping on the
234    /// accumulator arguments or return values (in anticipation of wrapping/unwrapping
235    /// being reverted in the near future).
236    fn accumulator(
237        &self,
238        args: &[SedonaType],
239        output_type: &SedonaType,
240    ) -> Result<Box<dyn Accumulator>>;
241
242    /// Given input data types, check if this implementation supports GroupsAccumulator
243    fn groups_accumulator_supported(&self, _args: &[SedonaType]) -> bool {
244        false
245    }
246
247    /// Given input data types, resolve a [GroupsAccumulator]
248    ///
249    /// A GroupsAccumulator is an important optimization for aggregating many small groups,
250    /// particularly when such an aggregation is cheap. See the DataFusion documentation
251    /// for details.
252    fn groups_accumulator(
253        &self,
254        _args: &[SedonaType],
255        _output_type: &SedonaType,
256    ) -> Result<Box<dyn GroupsAccumulator>> {
257        sedona_internal_err!("groups_accumulator not supported for {self:?}")
258    }
259
260    /// The fields representing the underlying serialized state of the Accumulator
261    fn state_fields(&self, args: &[SedonaType]) -> Result<Vec<FieldRef>>;
262}
263
#[cfg(test)]
mod test {
    use super::*;

    /// A UDF with no kernels reports its name, exposes no kernels, passes
    /// types through coercion, and fails dispatch with a descriptive message.
    #[test]
    fn udaf_empty() -> Result<()> {
        const NO_KERNEL: &str = "empty(): No kernel matching arguments";

        // UDF with no implementations
        let no_kernels: Vec<SedonaAccumulatorRef> = Vec::new();
        let udf = SedonaAggregateUDF::new("empty", no_kernels, Volatility::Immutable);
        assert_eq!(udf.name(), "empty");

        let first_failure = udf.return_field(&[]).unwrap_err();
        assert_eq!(first_failure.message(), NO_KERNEL);
        assert!(udf.kernels().is_empty());
        assert_eq!(udf.coerce_types(&[])?, vec![]);

        // Dispatch fails the same way when attempted a second time
        let second_failure = udf.return_field(&[]).unwrap_err();
        assert_eq!(second_failure.message(), NO_KERNEL);

        Ok(())
    }
}