sedona_expr/
function_set.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use crate::{
18    aggregate_udf::{SedonaAccumulatorRef, SedonaAggregateUDF},
19    scalar_udf::{ScalarKernelRef, SedonaScalarUDF},
20};
21use datafusion_common::error::Result;
22use datafusion_expr::{AggregateUDFImpl, ScalarUDFImpl};
23use sedona_common::sedona_internal_err;
24use std::collections::HashMap;
25
26/// Helper for managing groups of functions
27///
28/// Sedona coordinates the assembly of a large number of spatial functions with potentially
29/// different sets of dependencies (e.g., geography vs. geometry), multiple implementations,
30/// and/or implementations that live in different crates. This structure helps coordinate
31/// these implementations.
32pub struct FunctionSet {
33    scalar_udfs: HashMap<String, SedonaScalarUDF>,
34    aggregate_udfs: HashMap<String, SedonaAggregateUDF>,
35}
36
37impl FunctionSet {
38    /// Create a new, empty FunctionSet
39    pub fn new() -> Self {
40        Self {
41            scalar_udfs: HashMap::new(),
42            aggregate_udfs: HashMap::new(),
43        }
44    }
45
46    /// Iterate over references to all [SedonaScalarUDF]s
47    pub fn scalar_udfs(&self) -> impl Iterator<Item = &SedonaScalarUDF> + '_ {
48        self.scalar_udfs.values()
49    }
50
51    /// Return a reference to the function corresponding to the name
52    pub fn scalar_udf(&self, name: &str) -> Option<&SedonaScalarUDF> {
53        self.scalar_udfs.get(name)
54    }
55
56    /// Return a mutable reference to the function corresponding to the name
57    pub fn scalar_udf_mut(&mut self, name: &str) -> Option<&mut SedonaScalarUDF> {
58        self.scalar_udfs.get_mut(name)
59    }
60
61    /// Insert a new ScalarUDF and return the UDF that had previously been added, if any
62    pub fn insert_scalar_udf(&mut self, udf: SedonaScalarUDF) -> Option<SedonaScalarUDF> {
63        self.scalar_udfs.insert(udf.name().to_string(), udf)
64    }
65
66    /// Iterate over references to all [SedonaAggregateUDF]s
67    pub fn aggregate_udfs(&self) -> impl Iterator<Item = &SedonaAggregateUDF> + '_ {
68        self.aggregate_udfs.values()
69    }
70
71    /// Return a reference to the aggregate function corresponding to the name
72    pub fn aggregate_udf(&self, name: &str) -> Option<&SedonaAggregateUDF> {
73        self.aggregate_udfs.get(name)
74    }
75
76    /// Return a mutable reference to the aggregate function corresponding to the name
77    pub fn aggregate_udf_mut(&mut self, name: &str) -> Option<&mut SedonaAggregateUDF> {
78        self.aggregate_udfs.get_mut(name)
79    }
80
81    /// Insert a new AggregateUDF and return the UDF that had previously been added, if any
82    pub fn insert_aggregate_udf(&mut self, udf: SedonaAggregateUDF) -> Option<SedonaAggregateUDF> {
83        self.aggregate_udfs.insert(udf.name().to_string(), udf)
84    }
85
86    /// Consume another function set and merge its contents into this one
87    pub fn merge(&mut self, other: FunctionSet) {
88        for (k, v) in other.scalar_udfs.into_iter() {
89            self.scalar_udfs.insert(k, v);
90        }
91
92        for (k, v) in other.aggregate_udfs.into_iter() {
93            self.aggregate_udfs.insert(k, v);
94        }
95    }
96
97    /// Add a kernel to a function in this set
98    ///
99    /// This adds a scalar UDF with immutable output and no documentation if a
100    /// function of that name does not exist in this set. A reference to the
101    /// matching function is returned.
102    pub fn add_scalar_udf_kernel(
103        &mut self,
104        name: &str,
105        kernel: ScalarKernelRef,
106    ) -> Result<&SedonaScalarUDF> {
107        if let Some(function) = self.scalar_udf_mut(name) {
108            function.add_kernel(kernel);
109        } else {
110            let function = SedonaScalarUDF::from_kernel(name, kernel);
111            self.insert_scalar_udf(function);
112        }
113
114        Ok(self.scalar_udf(name).unwrap())
115    }
116
117    /// Add an aggregate kernel to a function in this set
118    ///
119    /// This errors if a function of that name does not exist in this set. A reference
120    /// to the matching function is returned.
121    pub fn add_aggregate_udf_kernel(
122        &mut self,
123        name: &str,
124        kernel: SedonaAccumulatorRef,
125    ) -> Result<&SedonaAggregateUDF> {
126        if let Some(function) = self.aggregate_udf_mut(name) {
127            function.add_kernel(kernel);
128            Ok(self.aggregate_udf(name).unwrap())
129        } else {
130            sedona_internal_err!("Can't register aggregate kernel for function '{}'", name)
131        }
132    }
133}
134
135impl Default for FunctionSet {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use arrow_schema::{DataType, FieldRef};
    use datafusion_common::{not_impl_err, scalar::ScalarValue};

    use datafusion_expr::{Accumulator, ColumnarValue, Volatility};
    use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};

    use crate::{aggregate_udf::SedonaAccumulator, scalar_udf::SimpleSedonaScalarKernel};

    use super::*;

    /// Exercises lookup, insertion, kernel registration and merging of scalar functions
    #[test]
    fn function_set() {
        let mut functions = FunctionSet::new();
        assert_eq!(functions.scalar_udfs().count(), 0);
        assert!(functions.scalar_udf("simple_udf").is_none());
        assert!(functions.scalar_udf_mut("simple_udf").is_none());

        let kernel = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Boolean)],
                SedonaType::Arrow(DataType::Boolean),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))),
        );

        let udf = SedonaScalarUDF::new(
            "simple_udf",
            vec![kernel.clone()],
            Volatility::Immutable,
            None,
        );

        functions.insert_scalar_udf(udf);
        assert_eq!(functions.scalar_udfs().count(), 1);
        assert!(functions.scalar_udf("simple_udf").is_some());
        assert!(functions.scalar_udf_mut("simple_udf").is_some());

        // Adding a kernel to an existing function returns that function
        assert_eq!(
            functions
                .add_scalar_udf_kernel("simple_udf", kernel.clone())
                .unwrap()
                .name(),
            "simple_udf"
        );

        // Adding a kernel for an unknown name creates the function on the fly
        let inserted_udf = functions
            .add_scalar_udf_kernel("function that does not yet exist", kernel.clone())
            .unwrap();
        assert_eq!(inserted_udf.name(), "function that does not yet exist");

        let kernel2 = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Utf8)],
                SedonaType::Arrow(DataType::Utf8),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))),
        );

        let udf2 = SedonaScalarUDF::new("simple_udf2", vec![kernel2], Volatility::Immutable, None);
        let mut functions2 = FunctionSet::new();
        functions2.insert_scalar_udf(udf2);
        functions.merge(functions2);
        assert_eq!(
            functions
                .scalar_udfs()
                .map(|s| s.name())
                .collect::<HashSet<_>>(),
            vec![
                "simple_udf",
                "simple_udf2",
                "function that does not yet exist"
            ]
            .into_iter()
            .collect::<HashSet<_>>()
        );
    }

    /// Accumulator kernel stub: only its identity matters to these tests, so
    /// every method reports "not implemented"
    #[derive(Debug)]
    struct TestAccumulator {}

    impl SedonaAccumulator for TestAccumulator {
        fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
            not_impl_err!("")
        }

        fn accumulator(
            &self,
            _args: &[SedonaType],
            _output_type: &SedonaType,
        ) -> Result<Box<dyn Accumulator>> {
            not_impl_err!("")
        }

        fn state_fields(&self, _args: &[SedonaType]) -> Result<Vec<FieldRef>> {
            not_impl_err!("")
        }
    }

    /// Exercises lookup, insertion, kernel registration and merging of aggregate functions
    #[test]
    fn function_set_with_aggregates() {
        let mut functions = FunctionSet::new();
        // A new set must be empty for both kinds of function
        assert_eq!(functions.scalar_udfs().count(), 0);
        assert_eq!(functions.aggregate_udfs().count(), 0);
        assert!(functions.aggregate_udf("simple_udaf").is_none());
        assert!(functions.aggregate_udf_mut("simple_udaf").is_none());

        let udaf = SedonaAggregateUDF::new("simple_udaf", vec![], Volatility::Immutable, None);
        let kernel = Arc::new(TestAccumulator {});

        functions.insert_aggregate_udf(udaf);
        assert_eq!(functions.aggregate_udfs().count(), 1);
        assert!(functions.aggregate_udf("simple_udaf").is_some());
        assert!(functions.aggregate_udf_mut("simple_udaf").is_some());
        assert_eq!(
            functions
                .add_aggregate_udf_kernel("simple_udaf", kernel.clone())
                .unwrap()
                .name(),
            "simple_udaf"
        );

        // Unlike the scalar case, an unknown aggregate name is an error
        let err = functions
            .add_aggregate_udf_kernel("function that does not exist", kernel.clone())
            .unwrap_err();
        assert!(err.message().lines().next().unwrap().contains(
            "Can't register aggregate kernel for function 'function that does not exist'."
        ));

        let udaf2 = SedonaAggregateUDF::new(
            "simple_udaf2",
            vec![kernel.clone()],
            Volatility::Immutable,
            None,
        );
        let mut functions2 = FunctionSet::new();
        functions2.insert_aggregate_udf(udaf2);
        functions.merge(functions2);
        assert_eq!(
            functions
                .aggregate_udfs()
                .map(|s| s.name())
                .collect::<HashSet<_>>(),
            vec!["simple_udaf", "simple_udaf2"]
                .into_iter()
                .collect::<HashSet<_>>()
        );
    }
}