sedona_expr/
aggregate_udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::{any::Any, fmt::Debug, sync::Arc};
18
19use arrow_schema::{DataType, FieldRef};
20use datafusion_common::{not_impl_err, Result};
21use datafusion_expr::{
22    function::{AccumulatorArgs, StateFieldsArgs},
23    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
24};
25use sedona_common::sedona_internal_err;
26use sedona_schema::datatypes::SedonaType;
27
28use sedona_schema::matchers::ArgMatcher;
29
/// Shared, thread-safe handle to a [`SedonaAccumulator`] kernel
///
/// This is the element type of the kernel list held by [`SedonaAggregateUDF`].
pub type SedonaAccumulatorRef = Arc<dyn SedonaAccumulator + Send + Sync>;
31
/// Top-level aggregate user-defined function
///
/// This struct implements DataFusion's [`AggregateUDFImpl`] and implements kernel dispatch
/// such that implementations can be registered flexibly.
///
/// Equality and hashing (see the `PartialEq` and `Hash` impls in this file) consider
/// only the function name; the registered kernels are not compared.
#[derive(Debug, Clone)]
pub struct SedonaAggregateUDF {
    /// Function name reported through `AggregateUDFImpl::name()`
    name: String,
    /// Signature built with `Signature::user_defined()` from the volatility given to `new()`
    signature: Signature,
    /// Registered kernels; dispatch walks this list from the last element to the first
    kernels: Vec<SedonaAccumulatorRef>,
    /// Optional documentation exposed through `AggregateUDFImpl::documentation()`
    documentation: Option<Documentation>,
}
43
44impl PartialEq for SedonaAggregateUDF {
45    fn eq(&self, other: &Self) -> bool {
46        self.name == other.name
47    }
48}
49
// Marker impl: the PartialEq implementation compares `String` names only, so
// equality is reflexive, symmetric and transitive.
impl Eq for SedonaAggregateUDF {}
51
52impl std::hash::Hash for SedonaAggregateUDF {
53    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
54        self.name.hash(state);
55    }
56}
57
58impl SedonaAggregateUDF {
59    /// Create a new SedonaAggregateUDF
60    pub fn new(
61        name: &str,
62        kernels: Vec<SedonaAccumulatorRef>,
63        volatility: Volatility,
64        documentation: Option<Documentation>,
65    ) -> Self {
66        let signature = Signature::user_defined(volatility);
67        Self {
68            name: name.to_string(),
69            signature,
70            kernels,
71            documentation,
72        }
73    }
74
75    /// Create a new stub aggregate function
76    ///
77    /// Creates a new aggregate function that calculates a return type but fails when
78    /// invoked with arguments. This is useful to create stub functions when it is
79    /// expected that the actual functionality will be registered from one or more
80    /// independent crates (e.g., ST_Union_Agg(), which may be implemented in
81    /// sedona-geo or sedona-geography).
82    pub fn new_stub(
83        name: &str,
84        arg_matcher: ArgMatcher,
85        volatility: Volatility,
86        documentation: Option<Documentation>,
87    ) -> Self {
88        let stub_kernel = StubAccumulator::new(name.to_string(), arg_matcher);
89        Self::new(name, vec![Arc::new(stub_kernel)], volatility, documentation)
90    }
91
92    /// Add a new kernel to an Aggregate UDF
93    ///
94    /// Because kernels are resolved in reverse order, the new kernel will take
95    /// precedence over any previously added kernels that apply to the same types.
96    pub fn add_kernel(&mut self, kernel: SedonaAccumulatorRef) {
97        self.kernels.push(kernel);
98    }
99
100    // List the current kernels
101    pub fn kernels(&self) -> &[SedonaAccumulatorRef] {
102        &self.kernels
103    }
104
105    fn dispatch_impl(&self, args: &[SedonaType]) -> Result<(&dyn SedonaAccumulator, SedonaType)> {
106        // Resolve kernels in reverse so that more recently added ones are resolved first
107        for kernel in self.kernels.iter().rev() {
108            if let Some(return_type) = kernel.return_type(args)? {
109                return Ok((kernel.as_ref(), return_type));
110            }
111        }
112
113        not_impl_err!("{}({:?}): No kernel matching arguments", self.name, args)
114    }
115}
116
117impl AggregateUDFImpl for SedonaAggregateUDF {
118    fn as_any(&self) -> &dyn Any {
119        self
120    }
121
122    fn name(&self) -> &str {
123        &self.name
124    }
125
126    fn signature(&self) -> &Signature {
127        &self.signature
128    }
129
130    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
131        Ok(arg_types.into())
132    }
133
134    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
135        let arg_types = args
136            .input_fields
137            .iter()
138            .map(|field| SedonaType::from_storage_field(field))
139            .collect::<Result<Vec<_>>>()?;
140        let (accumulator, _) = self.dispatch_impl(&arg_types)?;
141        accumulator.state_fields(&arg_types)
142    }
143
144    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
145        let arg_types = arg_fields
146            .iter()
147            .map(|field| SedonaType::from_storage_field(field))
148            .collect::<Result<Vec<_>>>()?;
149        let (_, out_type) = self.dispatch_impl(&arg_types)?;
150        Ok(Arc::new(out_type.to_storage_field("", true)?))
151    }
152
153    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
154        sedona_internal_err!("return_type() should not be called (use return_field())")
155    }
156
157    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
158        let arg_fields = acc_args
159            .exprs
160            .iter()
161            .map(|expr| expr.return_field(acc_args.schema))
162            .collect::<Result<Vec<_>>>()?;
163        let arg_types = arg_fields
164            .iter()
165            .map(|field| SedonaType::from_storage_field(field))
166            .collect::<Result<Vec<_>>>()?;
167        let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
168        accumulator.accumulator(&arg_types, &output_type)
169    }
170
171    fn documentation(&self) -> Option<&Documentation> {
172        self.documentation.as_ref()
173    }
174}
175
/// A single aggregate kernel that can be registered with a [`SedonaAggregateUDF`]
///
/// During dispatch the UDF asks each kernel, newest first, for its return type.
/// The first kernel that answers with `Some(..)` is used for the remaining calls.
pub trait SedonaAccumulator: Debug {
    /// Given input data types, calculate an output data type
    ///
    /// Return `Ok(None)` when this kernel does not apply to `args` so that
    /// dispatch can continue with the next kernel. Returning an `Err` aborts
    /// dispatch and is propagated to the caller.
    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>>;

    /// Given input data types and previously-calculated output data type,
    /// resolve an [Accumulator]
    ///
    /// The Accumulator provides the underlying DataFusion implementation.
    /// The SedonaAccumulator does not perform any wrapping or unwrapping on the
    /// accumulator arguments or return values (in anticipation of wrapping/unwrapping
    /// being reverted in the near future).
    fn accumulator(
        &self,
        args: &[SedonaType],
        output_type: &SedonaType,
    ) -> Result<Box<dyn Accumulator>>;

    /// The fields representing the underlying serialized state of the Accumulator
    fn state_fields(&self, args: &[SedonaType]) -> Result<Vec<FieldRef>>;
}
196
/// Placeholder kernel used by `SedonaAggregateUDF::new_stub()`
///
/// It resolves a return type through its matcher but refuses to create an accumulator.
#[derive(Debug)]
struct StubAccumulator {
    /// Function name, used only in the "was not registered" error message
    name: String,
    /// Matcher that decides whether the arguments are accepted and yields the return type
    matcher: ArgMatcher,
}
202
203impl StubAccumulator {
204    fn new(name: String, matcher: ArgMatcher) -> Self {
205        Self { name, matcher }
206    }
207}
208
209impl SedonaAccumulator for StubAccumulator {
210    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
211        self.matcher.match_args(args)
212    }
213
214    fn accumulator(
215        &self,
216        args: &[SedonaType],
217        _output_type: &SedonaType,
218    ) -> Result<Box<dyn Accumulator>> {
219        not_impl_err!(
220            "Implementation for {}({args:?}) was not registered",
221            self.name
222        )
223    }
224
225    fn state_fields(&self, _args: &[SedonaType]) -> Result<Vec<FieldRef>> {
226        Ok(vec![])
227    }
228}
229
#[cfg(test)]
mod test {
    use sedona_testing::testers::AggregateUdfTester;

    use crate::aggregate_udf::SedonaAggregateUDF;

    use super::*;

    /// A UDF without kernels keeps its name but can never resolve a call
    #[test]
    fn udaf_empty() -> Result<()> {
        let no_kernel_msg = "empty([]): No kernel matching arguments";

        // UDF with no implementations
        let udf = SedonaAggregateUDF::new("empty", Vec::new(), Volatility::Immutable, None);
        assert_eq!(udf.name(), "empty");
        assert!(udf.kernels().is_empty());
        assert_eq!(udf.coerce_types(&[])?, vec![]);

        // Dispatch fails with the same message every time it is attempted
        for _ in 0..2 {
            let err = udf.return_field(&[]).unwrap_err();
            assert_eq!(err.message(), no_kernel_msg);
        }

        Ok(())
    }

    /// A stub resolves its return type but refuses to aggregate
    #[test]
    fn stub() {
        let stub_udf = SedonaAggregateUDF::new_stub(
            "stubby",
            ArgMatcher::new(vec![], SedonaType::Arrow(DataType::Boolean)),
            Volatility::Immutable,
            None,
        );

        // The stub was registered for zero arguments: calling it that way
        // yields the declared return type...
        let zero_arg_tester = AggregateUdfTester::new(stub_udf.clone().into(), vec![]);
        assert_eq!(
            zero_arg_tester.return_type().unwrap(),
            SedonaType::Arrow(DataType::Boolean)
        );

        // ...but actually aggregating produces the stub error message.
        let aggregate_err = zero_arg_tester.aggregate(&vec![]).unwrap_err();
        assert_eq!(
            aggregate_err.message(),
            "Implementation for stubby([]) was not registered"
        );

        // Any other argument list fails already at return type resolution
        let binary_arg_tester = AggregateUdfTester::new(
            stub_udf.clone().into(),
            vec![SedonaType::Arrow(DataType::Binary)],
        );
        let dispatch_err = binary_arg_tester.return_type().unwrap_err();
        assert_eq!(
            dispatch_err.message(),
            "stubby([Arrow(Binary)]): No kernel matching arguments"
        );
    }
}