1use std::{any::Any, fmt::Debug, sync::Arc};
18
19use arrow_schema::{DataType, FieldRef};
20use datafusion_common::config::ConfigOptions;
21use datafusion_common::{not_impl_err, Result, ScalarValue};
22use datafusion_expr::{
23 ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24 Volatility,
25};
26use sedona_common::sedona_internal_err;
27use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
28
29pub type ScalarKernelRef = Arc<dyn SedonaScalarKernel>;
31
32pub trait IntoScalarKernelRefs {
34 fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef>;
35}
36
37impl IntoScalarKernelRefs for ScalarKernelRef {
38 fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef> {
39 vec![self]
40 }
41}
42
43impl IntoScalarKernelRefs for Vec<ScalarKernelRef> {
44 fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef> {
45 self
46 }
47}
48
49impl<T: SedonaScalarKernel + 'static> IntoScalarKernelRefs for T {
50 fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef> {
51 vec![Arc::new(self)]
52 }
53}
54
55impl<T: SedonaScalarKernel + 'static> IntoScalarKernelRefs for Vec<Arc<T>> {
56 fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef> {
57 self.into_iter()
58 .map(|item| item as ScalarKernelRef)
59 .collect()
60 }
61}
62
63#[derive(Debug, Clone)]
69pub struct SedonaScalarUDF {
70 name: String,
71 signature: Signature,
72 kernels: Vec<ScalarKernelRef>,
73 aliases: Vec<String>,
74}
75
76impl PartialEq for SedonaScalarUDF {
77 fn eq(&self, other: &Self) -> bool {
78 self.name == other.name
79 }
80}
81
82impl Eq for SedonaScalarUDF {}
83
84impl std::hash::Hash for SedonaScalarUDF {
85 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
86 self.name.hash(state);
87 }
88}
89
90pub trait SedonaScalarKernel: Debug + Send + Sync {
97 fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>>;
108
109 fn return_type_from_args_and_scalars(
115 &self,
116 args: &[SedonaType],
117 _scalar_args: &[Option<&ScalarValue>],
118 ) -> Result<Option<SedonaType>> {
119 self.return_type(args)
120 }
121
122 fn invoke_batch(
128 &self,
129 arg_types: &[SedonaType],
130 args: &[ColumnarValue],
131 ) -> Result<ColumnarValue>;
132
133 fn invoke_batch_from_args(
134 &self,
135 arg_types: &[SedonaType],
136 args: &[ColumnarValue],
137 _return_type: &SedonaType,
138 _num_rows: usize,
139 _config_options: Option<&ConfigOptions>,
140 ) -> Result<ColumnarValue> {
141 self.invoke_batch(arg_types, args)
142 }
143}
144
145pub type SedonaScalarKernelImpl =
147 Arc<dyn Fn(&[SedonaType], &[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;
148
149pub struct SimpleSedonaScalarKernel {
151 arg_matcher: ArgMatcher,
152 fun: SedonaScalarKernelImpl,
153}
154
155impl Debug for SimpleSedonaScalarKernel {
156 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
157 f.debug_struct("SimpleSedonaScalarKernel").finish()
158 }
159}
160
161impl SimpleSedonaScalarKernel {
162 pub fn new_ref(arg_matcher: ArgMatcher, fun: SedonaScalarKernelImpl) -> ScalarKernelRef {
163 Arc::new(Self { arg_matcher, fun })
164 }
165}
166
167impl SedonaScalarKernel for SimpleSedonaScalarKernel {
168 fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
169 self.arg_matcher.match_args(args)
170 }
171
172 fn invoke_batch(
173 &self,
174 arg_types: &[SedonaType],
175 args: &[ColumnarValue],
176 ) -> Result<ColumnarValue> {
177 (self.fun)(arg_types, args)
178 }
179}
180
181impl SedonaScalarUDF {
182 pub fn new(
184 name: &str,
185 kernels: Vec<ScalarKernelRef>,
186 volatility: Volatility,
187 ) -> SedonaScalarUDF {
188 let signature = Signature::user_defined(volatility);
189 Self {
190 name: name.to_string(),
191 signature,
192 kernels,
193 aliases: vec![],
194 }
195 }
196
197 pub fn with_aliases(self, aliases: Vec<String>) -> SedonaScalarUDF {
199 Self {
200 name: self.name,
201 signature: self.signature,
202 kernels: self.kernels,
203 aliases,
204 }
205 }
206
207 pub fn from_impl(name: &str, kernels: impl IntoScalarKernelRefs) -> SedonaScalarUDF {
212 Self::new(
213 name,
214 kernels.into_scalar_kernel_refs(),
215 Volatility::Immutable,
216 )
217 }
218
219 pub fn add_kernels(&mut self, kernels: impl IntoScalarKernelRefs) {
224 for kernel in kernels.into_scalar_kernel_refs() {
225 self.kernels.push(kernel);
226 }
227 }
228
229 fn return_type_impl(
230 &self,
231 args: &[SedonaType],
232 scalars: &[Option<&ScalarValue>],
233 ) -> Result<(&dyn SedonaScalarKernel, SedonaType)> {
234 for kernel in self.kernels.iter().rev() {
236 if let Some(return_type) = kernel.return_type_from_args_and_scalars(args, scalars)? {
237 return Ok((kernel.as_ref(), return_type));
238 }
239 }
240
241 let args_display = args
242 .iter()
243 .map(|arg| arg.logical_type_name())
244 .collect::<Vec<_>>()
245 .join(", ");
246
247 not_impl_err!(
248 "{}({args_display}): No kernel matching arguments",
249 self.name
250 )
251 }
252}
253
254impl ScalarUDFImpl for SedonaScalarUDF {
255 fn as_any(&self) -> &dyn Any {
256 self
257 }
258
259 fn name(&self) -> &str {
260 &self.name
261 }
262
263 fn signature(&self) -> &Signature {
264 &self.signature
265 }
266
267 fn documentation(&self) -> Option<&Documentation> {
268 None
269 }
270
271 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
272 sedona_internal_err!("Should not be called (use return_field_from_args())")
273 }
274
275 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
276 let arg_types = args
277 .arg_fields
278 .iter()
279 .map(|field| SedonaType::from_storage_field(field))
280 .collect::<Result<Vec<_>>>()?;
281 let (_, out_type) = self.return_type_impl(&arg_types, args.scalar_arguments)?;
282 Ok(Arc::new(out_type.to_storage_field("", true)?))
283 }
284
285 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
286 Ok(arg_types.to_vec())
287 }
288
289 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
290 let arg_types = args
291 .arg_fields
292 .iter()
293 .map(|field| SedonaType::from_storage_field(field))
294 .collect::<Result<Vec<_>>>()?;
295
296 let arg_scalars = args
297 .args
298 .iter()
299 .map(|arg| {
300 if let ColumnarValue::Scalar(scalar) = arg {
301 Some(scalar)
302 } else {
303 None
304 }
305 })
306 .collect::<Vec<_>>();
307
308 let (kernel, return_type) = self.return_type_impl(&arg_types, &arg_scalars)?;
309 kernel.invoke_batch_from_args(
310 &arg_types,
311 &args.args,
312 &return_type,
313 args.number_rows,
314 Some(&*args.config_options),
315 )
316 }
317
318 fn aliases(&self) -> &[String] {
319 &self.aliases
320 }
321}
322
#[cfg(test)]
mod tests {
    use datafusion_common::{scalar::ScalarValue, DFSchema};
    use datafusion_expr::{lit, ExprSchemable, ScalarUDF};
    use sedona_schema::{
        crs::lnglat,
        datatypes::{Edges, WKB_GEOMETRY},
    };
    use sedona_testing::testers::ScalarUdfTester;

    use super::*;

    #[test]
    fn udf_empty() -> Result<()> {
        let udf = SedonaScalarUDF::new("empty", vec![], Volatility::Immutable);
        assert_eq!(udf.name(), "empty");
        assert_eq!(udf.coerce_types(&[])?, Vec::<DataType>::new());

        let tester = ScalarUdfTester::new(udf.into(), vec![]);
        let expected_message = "empty(): No kernel matching arguments";

        // With no kernels, both planning and execution must fail the same way.
        let planning_err = tester.return_type().unwrap_err();
        assert_eq!(planning_err.message(), expected_message);

        let execution_err = tester.invoke_arrays(vec![]).unwrap_err();
        assert_eq!(execution_err.message(), expected_message);

        Ok(())
    }

    #[test]
    fn simple_udf() {
        let geo_to_null = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_geometry_or_geography()],
                SedonaType::Arrow(DataType::Null),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Null))),
        );
        let bool_to_bool = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Boolean)],
                SedonaType::Arrow(DataType::Boolean),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))),
        );
        let udf = SedonaScalarUDF::new(
            "simple_udf",
            vec![geo_to_null, bool_to_bool],
            Volatility::Immutable,
        );

        // A geometry argument dispatches to the geometry kernel.
        let geo_tester = ScalarUdfTester::new(udf.clone().into(), vec![WKB_GEOMETRY]);
        geo_tester.assert_return_type(DataType::Null);
        assert_eq!(
            geo_tester.invoke_scalar("POINT (0 1)").unwrap(),
            ScalarValue::Null
        );

        // A boolean argument dispatches to the Arrow kernel.
        let bool_tester = ScalarUdfTester::new(
            udf.clone().into(),
            vec![SedonaType::Arrow(DataType::Boolean)],
        );
        bool_tester.assert_return_type(DataType::Boolean);
        assert_eq!(
            bool_tester.invoke_scalar(true).unwrap(),
            ScalarValue::Boolean(None)
        );

        // A kernel added later for the same argument types wins.
        let mut overridden = udf;
        overridden.add_kernels(SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Boolean)],
                SedonaType::Arrow(DataType::Utf8),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))),
        ));
        let override_tester = ScalarUdfTester::new(
            overridden.into(),
            vec![SedonaType::Arrow(DataType::Boolean)],
        );
        override_tester.assert_return_type(DataType::Utf8);
    }

    #[test]
    fn crs_propagation() {
        let geom_lnglat = SedonaType::Wkb(Edges::Planar, lnglat());

        let predicate = SedonaScalarUDF::from_impl(
            "foofy",
            SimpleSedonaScalarKernel::new_ref(
                ArgMatcher::new(
                    vec![ArgMatcher::is_geometry(), ArgMatcher::is_geometry()],
                    SedonaType::Arrow(DataType::Boolean),
                ),
                Arc::new(|_arg_types, _args| unreachable!("Should not be executed")),
            ),
        );

        // Matching CRSs (both unset, or both lnglat) are accepted.
        for arg_types in [
            vec![WKB_GEOMETRY, WKB_GEOMETRY],
            vec![geom_lnglat.clone(), geom_lnglat.clone()],
        ] {
            let tester = ScalarUdfTester::new(predicate.clone().into(), arg_types);
            tester.assert_return_type(DataType::Boolean);
        }

        // Mixing CRSs is rejected while resolving the return type.
        let mismatched = ScalarUdfTester::new(
            predicate.into(),
            vec![WKB_GEOMETRY, geom_lnglat.clone()],
        );
        let err = mismatched.return_type().unwrap_err();
        assert!(err.message().starts_with("Mismatched CRS arguments"));

        // A geometry-returning kernel carries the input CRS onto its output.
        let passthrough = SedonaScalarUDF::from_impl(
            "foofy",
            SimpleSedonaScalarKernel::new_ref(
                ArgMatcher::new(
                    vec![ArgMatcher::is_geometry(), ArgMatcher::is_geometry()],
                    WKB_GEOMETRY,
                ),
                Arc::new(|_arg_types, args| Ok(args[0].clone())),
            ),
        );
        let passthrough_tester = ScalarUdfTester::new(
            passthrough.into(),
            vec![geom_lnglat.clone(), geom_lnglat.clone()],
        );
        passthrough_tester.assert_return_type(geom_lnglat);
    }

    #[test]
    fn return_type_from_scalar_arg() {
        let udf: ScalarUDF = SedonaScalarUDF::from_impl("simple_cast", SimpleCast {}).into();
        let expr = udf.call(vec![lit(10), lit("float32")]);
        let resolved = expr.data_type_and_nullable(&DFSchema::empty()).unwrap();
        assert_eq!(resolved, (DataType::Float32, true));
    }

    /// Test kernel whose output type is named by its second (string) argument.
    #[derive(Debug)]
    struct SimpleCast {}

    impl SimpleCast {
        /// Map a `Utf8` scalar holding "float32"/"float64" to the matching type.
        fn parse_type(val: &ColumnarValue) -> Result<SedonaType> {
            let target = match val {
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(type_name))) => {
                    match type_name.as_str() {
                        "float32" => Some(DataType::Float32),
                        "float64" => Some(DataType::Float64),
                        _ => None,
                    }
                }
                _ => None,
            };

            if let Some(data_type) = target {
                return Ok(SedonaType::Arrow(data_type));
            }

            sedona_internal_err!("unrecognized target value")
        }
    }

    impl SedonaScalarKernel for SimpleCast {
        fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
            sedona_internal_err!("Should not be called")
        }

        fn return_type_from_args_and_scalars(
            &self,
            _args: &[SedonaType],
            scalar_args: &[Option<&ScalarValue>],
        ) -> Result<Option<SedonaType>> {
            let target = scalar_args[1].cloned().expect("arg1 as a scalar in test");
            Self::parse_type(&ColumnarValue::Scalar(target)).map(Some)
        }

        fn invoke_batch(
            &self,
            _arg_types: &[SedonaType],
            args: &[ColumnarValue],
        ) -> Result<ColumnarValue> {
            let out_type = Self::parse_type(&args[1])?;
            args[0].cast_to(out_type.storage_type(), None)
        }
    }
}