sedona_expr/
function_set.rs1use crate::{
18 aggregate_udf::{SedonaAccumulatorRef, SedonaAggregateUDF},
19 scalar_udf::{ScalarKernelRef, SedonaScalarUDF},
20};
21use datafusion_common::error::Result;
22use datafusion_expr::{AggregateUDFImpl, ScalarUDFImpl};
23use sedona_common::sedona_internal_err;
24use std::collections::HashMap;
25
26pub struct FunctionSet {
33 scalar_udfs: HashMap<String, SedonaScalarUDF>,
34 aggregate_udfs: HashMap<String, SedonaAggregateUDF>,
35}
36
37impl FunctionSet {
38 pub fn new() -> Self {
40 Self {
41 scalar_udfs: HashMap::new(),
42 aggregate_udfs: HashMap::new(),
43 }
44 }
45
46 pub fn scalar_udfs(&self) -> impl Iterator<Item = &SedonaScalarUDF> + '_ {
48 self.scalar_udfs.values()
49 }
50
51 pub fn scalar_udf(&self, name: &str) -> Option<&SedonaScalarUDF> {
53 self.scalar_udfs.get(name)
54 }
55
56 pub fn scalar_udf_mut(&mut self, name: &str) -> Option<&mut SedonaScalarUDF> {
58 self.scalar_udfs.get_mut(name)
59 }
60
61 pub fn insert_scalar_udf(&mut self, udf: SedonaScalarUDF) -> Option<SedonaScalarUDF> {
63 self.scalar_udfs.insert(udf.name().to_string(), udf)
64 }
65
66 pub fn aggregate_udfs(&self) -> impl Iterator<Item = &SedonaAggregateUDF> + '_ {
68 self.aggregate_udfs.values()
69 }
70
71 pub fn aggregate_udf(&self, name: &str) -> Option<&SedonaAggregateUDF> {
73 self.aggregate_udfs.get(name)
74 }
75
76 pub fn aggregate_udf_mut(&mut self, name: &str) -> Option<&mut SedonaAggregateUDF> {
78 self.aggregate_udfs.get_mut(name)
79 }
80
81 pub fn insert_aggregate_udf(&mut self, udf: SedonaAggregateUDF) -> Option<SedonaAggregateUDF> {
83 self.aggregate_udfs.insert(udf.name().to_string(), udf)
84 }
85
86 pub fn merge(&mut self, other: FunctionSet) {
88 for (k, v) in other.scalar_udfs.into_iter() {
89 self.scalar_udfs.insert(k, v);
90 }
91
92 for (k, v) in other.aggregate_udfs.into_iter() {
93 self.aggregate_udfs.insert(k, v);
94 }
95 }
96
97 pub fn add_scalar_udf_kernel(
103 &mut self,
104 name: &str,
105 kernel: ScalarKernelRef,
106 ) -> Result<&SedonaScalarUDF> {
107 if let Some(function) = self.scalar_udf_mut(name) {
108 function.add_kernel(kernel);
109 } else {
110 let function = SedonaScalarUDF::from_kernel(name, kernel);
111 self.insert_scalar_udf(function);
112 }
113
114 Ok(self.scalar_udf(name).unwrap())
115 }
116
117 pub fn add_aggregate_udf_kernel(
122 &mut self,
123 name: &str,
124 kernel: SedonaAccumulatorRef,
125 ) -> Result<&SedonaAggregateUDF> {
126 if let Some(function) = self.aggregate_udf_mut(name) {
127 function.add_kernel(kernel);
128 Ok(self.aggregate_udf(name).unwrap())
129 } else {
130 sedona_internal_err!("Can't register aggregate kernel for function '{}'", name)
131 }
132 }
133}
134
135impl Default for FunctionSet {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use arrow_schema::{DataType, FieldRef};
    use datafusion_common::{not_impl_err, scalar::ScalarValue};

    use datafusion_expr::{Accumulator, ColumnarValue, Volatility};
    use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};

    use crate::{aggregate_udf::SedonaAccumulator, scalar_udf::SimpleSedonaScalarKernel};

    use super::*;

    /// Scalar functions: insertion, lookup, kernel registration (including
    /// creating a function that does not exist yet), and merging two sets.
    #[test]
    fn function_set() {
        let mut functions = FunctionSet::new();
        assert_eq!(functions.scalar_udfs().count(), 0);
        assert!(functions.scalar_udf("simple_udf").is_none());
        assert!(functions.scalar_udf_mut("simple_udf").is_none());

        let kernel = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Boolean)],
                SedonaType::Arrow(DataType::Boolean),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))),
        );

        let udf = SedonaScalarUDF::new(
            "simple_udf",
            vec![kernel.clone()],
            Volatility::Immutable,
            None,
        );

        functions.insert_scalar_udf(udf);
        assert_eq!(functions.scalar_udfs().count(), 1);
        assert!(functions.scalar_udf("simple_udf").is_some());
        assert!(functions.scalar_udf_mut("simple_udf").is_some());
        assert_eq!(
            functions
                .add_scalar_udf_kernel("simple_udf", kernel.clone())
                .unwrap()
                .name(),
            "simple_udf"
        );
        let inserted_udf = functions
            .add_scalar_udf_kernel("function that does not yet exist", kernel.clone())
            .unwrap();
        assert_eq!(inserted_udf.name(), "function that does not yet exist");

        let kernel2 = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Utf8)],
                SedonaType::Arrow(DataType::Utf8),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))),
        );

        let udf2 = SedonaScalarUDF::new("simple_udf2", vec![kernel2], Volatility::Immutable, None);
        let mut functions2 = FunctionSet::new();
        functions2.insert_scalar_udf(udf2);
        functions.merge(functions2);
        assert_eq!(
            functions
                .scalar_udfs()
                .map(|s| s.name())
                .collect::<HashSet<_>>(),
            HashSet::from([
                "simple_udf",
                "simple_udf2",
                "function that does not yet exist",
            ])
        );
    }

    /// Accumulator stub: only its identity matters to these tests, so every
    /// method reports "not implemented".
    #[derive(Debug)]
    struct TestAccumulator {}

    impl SedonaAccumulator for TestAccumulator {
        fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
            not_impl_err!("")
        }

        fn accumulator(
            &self,
            _args: &[SedonaType],
            _output_type: &SedonaType,
        ) -> Result<Box<dyn Accumulator>> {
            not_impl_err!("")
        }

        fn state_fields(&self, _args: &[SedonaType]) -> Result<Vec<FieldRef>> {
            not_impl_err!("")
        }
    }

    /// Aggregate functions: insertion, lookup, kernel registration (which
    /// must fail for an unregistered name), and merging two sets.
    #[test]
    fn function_set_with_aggregates() {
        let mut functions = FunctionSet::new();
        // This test is about aggregates, so check the aggregate map starts empty.
        assert_eq!(functions.aggregate_udfs().count(), 0);
        assert!(functions.aggregate_udf("simple_udaf").is_none());
        assert!(functions.aggregate_udf_mut("simple_udaf").is_none());

        let udaf = SedonaAggregateUDF::new("simple_udaf", vec![], Volatility::Immutable, None);
        let kernel = Arc::new(TestAccumulator {});

        functions.insert_aggregate_udf(udaf);
        assert_eq!(functions.aggregate_udfs().count(), 1);
        assert!(functions.aggregate_udf("simple_udaf").is_some());
        assert!(functions.aggregate_udf_mut("simple_udaf").is_some());
        assert_eq!(
            functions
                .add_aggregate_udf_kernel("simple_udaf", kernel.clone())
                .unwrap()
                .name(),
            "simple_udaf"
        );
        let err = functions
            .add_aggregate_udf_kernel("function that does not exist", kernel.clone())
            .unwrap_err();
        assert!(err.message().lines().next().unwrap().contains(
            "Can't register aggregate kernel for function 'function that does not exist'."
        ));

        let udaf2 = SedonaAggregateUDF::new(
            "simple_udaf2",
            vec![kernel.clone()],
            Volatility::Immutable,
            None,
        );
        let mut functions2 = FunctionSet::new();
        functions2.insert_aggregate_udf(udaf2);
        functions.merge(functions2);
        assert_eq!(
            functions
                .aggregate_udfs()
                .map(|s| s.name())
                .collect::<HashSet<_>>(),
            HashSet::from(["simple_udaf", "simple_udaf2"])
        );
    }
}