Skip to main content

sedona_expr/
scalar_udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::{any::Any, fmt::Debug, sync::Arc};
18
19use arrow_schema::{DataType, FieldRef};
20use datafusion_common::config::ConfigOptions;
21use datafusion_common::{not_impl_err, Result, ScalarValue};
22use datafusion_expr::{
23    ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24    Volatility,
25};
26use sedona_common::sedona_internal_err;
27use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
28
29/// Shorthand for a [SedonaScalarKernel] reference
30pub type ScalarKernelRef = Arc<dyn SedonaScalarKernel>;
31
32/// Helper to resolve an iterable of kernels
33pub trait IntoScalarKernelRefs {
34    fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef>;
35}
36
37impl IntoScalarKernelRefs for ScalarKernelRef {
38    fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef> {
39        vec![self]
40    }
41}
42
43impl IntoScalarKernelRefs for Vec<ScalarKernelRef> {
44    fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef> {
45        self
46    }
47}
48
49impl<T: SedonaScalarKernel + 'static> IntoScalarKernelRefs for T {
50    fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef> {
51        vec![Arc::new(self)]
52    }
53}
54
55impl<T: SedonaScalarKernel + 'static> IntoScalarKernelRefs for Vec<Arc<T>> {
56    fn into_scalar_kernel_refs(self) -> Vec<ScalarKernelRef> {
57        self.into_iter()
58            .map(|item| item as ScalarKernelRef)
59            .collect()
60    }
61}
62
63/// Top-level scalar user-defined function
64///
65/// This struct implements datafusion's ScalarUDF and implements kernel dispatch
66/// and argument wrapping/unwrapping while this is still necessary to support
67/// user-defined types.
68#[derive(Debug, Clone)]
69pub struct SedonaScalarUDF {
70    name: String,
71    signature: Signature,
72    kernels: Vec<ScalarKernelRef>,
73    aliases: Vec<String>,
74}
75
76impl PartialEq for SedonaScalarUDF {
77    fn eq(&self, other: &Self) -> bool {
78        self.name == other.name
79    }
80}
81
82impl Eq for SedonaScalarUDF {}
83
84impl std::hash::Hash for SedonaScalarUDF {
85    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
86        self.name.hash(state);
87    }
88}
89
90/// User-defined function implementation
91///
92/// A `SedonaScalarUdf` is comprised of one or more kernels, to which it dispatches
93/// the first whose return_type returns `Some()`. Whereas a SeondaScalarUdf represents
94/// a logical operation (e.g., ST_Intersects()), a kernel wraps the logic around a specific
95/// implementation.
96pub trait SedonaScalarKernel: Debug + Send + Sync {
97    /// Calculate a return type given input types
98    ///
99    /// Returns Some(physical_type) if this kernel applies to the input types or
100    /// None otherwise. This struct acts as a version of the Signature that can
101    /// better accommodate the types we need to support (and might be able to be
102    /// removed when there is better support for matching user-defined types/
103    /// types with metadata in DataFusion).
104    ///
105    /// The [`ArgMatcher`] contains a set of helper functions to help implement this
106    /// function.
107    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>>;
108
109    /// Calculate a return type given input type and scalar arguments
110    ///
111    /// Most functions should implement [SedonaScalarKernel::return_type]; however, some functions
112    /// (e.g., ST_SetSRID) calculate a return type based on the value of the argument if it is
113    /// a constant. If this is implemented, [SedonaScalarKernel::return_type] will not be called.
114    fn return_type_from_args_and_scalars(
115        &self,
116        args: &[SedonaType],
117        _scalar_args: &[Option<&ScalarValue>],
118    ) -> Result<Option<SedonaType>> {
119        self.return_type(args)
120    }
121
122    /// Compute a batch of results
123    ///
124    /// Computes an output chunk based on the physical types of the input and the
125    /// computed output type. The ColumnarValues passed are the "unwrapped" representation
126    /// of any extension type (e.g., for Wkb the provided ColumnarValue will be Binary).
127    fn invoke_batch(
128        &self,
129        arg_types: &[SedonaType],
130        args: &[ColumnarValue],
131    ) -> Result<ColumnarValue>;
132
133    fn invoke_batch_from_args(
134        &self,
135        arg_types: &[SedonaType],
136        args: &[ColumnarValue],
137        _return_type: &SedonaType,
138        _num_rows: usize,
139        _config_options: Option<&ConfigOptions>,
140    ) -> Result<ColumnarValue> {
141        self.invoke_batch(arg_types, args)
142    }
143}
144
145/// Type definition for a Scalar kernel implementation function
146pub type SedonaScalarKernelImpl =
147    Arc<dyn Fn(&[SedonaType], &[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;
148
149/// Scalar kernel based on a function for testing
150pub struct SimpleSedonaScalarKernel {
151    arg_matcher: ArgMatcher,
152    fun: SedonaScalarKernelImpl,
153}
154
155impl Debug for SimpleSedonaScalarKernel {
156    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
157        f.debug_struct("SimpleSedonaScalarKernel").finish()
158    }
159}
160
161impl SimpleSedonaScalarKernel {
162    pub fn new_ref(arg_matcher: ArgMatcher, fun: SedonaScalarKernelImpl) -> ScalarKernelRef {
163        Arc::new(Self { arg_matcher, fun })
164    }
165}
166
167impl SedonaScalarKernel for SimpleSedonaScalarKernel {
168    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
169        self.arg_matcher.match_args(args)
170    }
171
172    fn invoke_batch(
173        &self,
174        arg_types: &[SedonaType],
175        args: &[ColumnarValue],
176    ) -> Result<ColumnarValue> {
177        (self.fun)(arg_types, args)
178    }
179}
180
181impl SedonaScalarUDF {
182    /// Create a new SedonaScalarUDF
183    pub fn new(
184        name: &str,
185        kernels: Vec<ScalarKernelRef>,
186        volatility: Volatility,
187    ) -> SedonaScalarUDF {
188        let signature = Signature::user_defined(volatility);
189        Self {
190            name: name.to_string(),
191            signature,
192            kernels,
193            aliases: vec![],
194        }
195    }
196
197    /// Add aliases to an existing SedonaScalarUDF
198    pub fn with_aliases(self, aliases: Vec<String>) -> SedonaScalarUDF {
199        Self {
200            name: self.name,
201            signature: self.signature,
202            kernels: self.kernels,
203            aliases,
204        }
205    }
206
207    /// Create a SedonaScalarUDF from a single kernel
208    ///
209    /// This constructor creates a [Volatility::Immutable] function with no documentation
210    /// consisting of only the implementation provided.
211    pub fn from_impl(name: &str, kernels: impl IntoScalarKernelRefs) -> SedonaScalarUDF {
212        Self::new(
213            name,
214            kernels.into_scalar_kernel_refs(),
215            Volatility::Immutable,
216        )
217    }
218
219    /// Add a new kernel to a Scalar UDF
220    ///
221    /// Because kernels are resolved in reverse order, the new kernel will take
222    /// precedence over any previously added kernels that apply to the same types.
223    pub fn add_kernels(&mut self, kernels: impl IntoScalarKernelRefs) {
224        for kernel in kernels.into_scalar_kernel_refs() {
225            self.kernels.push(kernel);
226        }
227    }
228
229    fn return_type_impl(
230        &self,
231        args: &[SedonaType],
232        scalars: &[Option<&ScalarValue>],
233    ) -> Result<(&dyn SedonaScalarKernel, SedonaType)> {
234        // Resolve kernels in reverse so that more recently added ones are resolved first
235        for kernel in self.kernels.iter().rev() {
236            if let Some(return_type) = kernel.return_type_from_args_and_scalars(args, scalars)? {
237                return Ok((kernel.as_ref(), return_type));
238            }
239        }
240
241        let args_display = args
242            .iter()
243            .map(|arg| arg.logical_type_name())
244            .collect::<Vec<_>>()
245            .join(", ");
246
247        not_impl_err!(
248            "{}({args_display}): No kernel matching arguments",
249            self.name
250        )
251    }
252}
253
254impl ScalarUDFImpl for SedonaScalarUDF {
255    fn as_any(&self) -> &dyn Any {
256        self
257    }
258
259    fn name(&self) -> &str {
260        &self.name
261    }
262
263    fn signature(&self) -> &Signature {
264        &self.signature
265    }
266
267    fn documentation(&self) -> Option<&Documentation> {
268        None
269    }
270
271    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
272        sedona_internal_err!("Should not be called (use return_field_from_args())")
273    }
274
275    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
276        let arg_types = args
277            .arg_fields
278            .iter()
279            .map(|field| SedonaType::from_storage_field(field))
280            .collect::<Result<Vec<_>>>()?;
281        let (_, out_type) = self.return_type_impl(&arg_types, args.scalar_arguments)?;
282        Ok(Arc::new(out_type.to_storage_field("", true)?))
283    }
284
285    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
286        Ok(arg_types.to_vec())
287    }
288
289    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
290        let arg_types = args
291            .arg_fields
292            .iter()
293            .map(|field| SedonaType::from_storage_field(field))
294            .collect::<Result<Vec<_>>>()?;
295
296        let arg_scalars = args
297            .args
298            .iter()
299            .map(|arg| {
300                if let ColumnarValue::Scalar(scalar) = arg {
301                    Some(scalar)
302                } else {
303                    None
304                }
305            })
306            .collect::<Vec<_>>();
307
308        let (kernel, return_type) = self.return_type_impl(&arg_types, &arg_scalars)?;
309        kernel.invoke_batch_from_args(
310            &arg_types,
311            &args.args,
312            &return_type,
313            args.number_rows,
314            Some(&*args.config_options),
315        )
316    }
317
318    fn aliases(&self) -> &[String] {
319        &self.aliases
320    }
321}
322
#[cfg(test)]
mod tests {

    use datafusion_common::{scalar::ScalarValue, DFSchema};
    use sedona_testing::testers::ScalarUdfTester;

    use datafusion_expr::{lit, ExprSchemable, ScalarUDF};
    use sedona_schema::{
        crs::lnglat,
        datatypes::{Edges, WKB_GEOMETRY},
    };

    use super::*;

    #[test]
    fn udf_empty() -> Result<()> {
        // A UDF without kernels can be created and inspected...
        let udf = SedonaScalarUDF::new("empty", vec![], Volatility::Immutable);
        assert_eq!(udf.name(), "empty");
        assert_eq!(udf.coerce_types(&[])?, vec![]);

        // ...but neither type resolution nor invocation can find a kernel
        let expected_message = "empty(): No kernel matching arguments";
        let tester = ScalarUdfTester::new(udf.into(), vec![]);

        let type_err = tester.return_type().unwrap_err();
        assert_eq!(type_err.message(), expected_message);

        let invoke_err = tester.invoke_arrays(vec![]).unwrap_err();
        assert_eq!(invoke_err.message(), expected_message);

        Ok(())
    }

    #[test]
    fn simple_udf() {
        // First kernel: accepts any geometry or geography and returns Null
        let kernel_geo = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_geometry_or_geography()],
                SedonaType::Arrow(DataType::Null),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Null))),
        );

        // Second kernel: accepts one specific arrow type (Boolean) and returns Boolean
        let kernel_arrow = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Boolean)],
                SedonaType::Arrow(DataType::Boolean),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))),
        );

        let udf = SedonaScalarUDF::new(
            "simple_udf",
            vec![kernel_geo, kernel_arrow],
            Volatility::Immutable,
        );

        // A geometry argument is handled by the first kernel
        let geo_tester = ScalarUdfTester::new(udf.clone().into(), vec![WKB_GEOMETRY]);
        geo_tester.assert_return_type(DataType::Null);
        assert_eq!(
            geo_tester.invoke_scalar("POINT (0 1)").unwrap(),
            ScalarValue::Null
        );

        // A Boolean argument is handled by the second kernel
        let bool_tester = ScalarUdfTester::new(
            udf.clone().into(),
            vec![SedonaType::Arrow(DataType::Boolean)],
        );
        bool_tester.assert_return_type(DataType::Boolean);
        assert_eq!(
            bool_tester.invoke_scalar(true).unwrap(),
            ScalarValue::Boolean(None)
        );

        // A kernel added later for the same argument type is picked first
        let mut udf = udf;
        udf.add_kernels(SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Boolean)],
                SedonaType::Arrow(DataType::Utf8),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))),
        ));

        // The Boolean argument now resolves to Utf8
        let overridden_tester =
            ScalarUdfTester::new(udf.into(), vec![SedonaType::Arrow(DataType::Boolean)]);
        overridden_tester.assert_return_type(DataType::Utf8);
    }

    #[test]
    fn crs_propagation() {
        let geom_lnglat = SedonaType::Wkb(Edges::Planar, lnglat());

        // Binary predicate on two geometries; only type resolution is exercised
        let predicate_kernel = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_geometry(), ArgMatcher::is_geometry()],
                SedonaType::Arrow(DataType::Boolean),
            ),
            Arc::new(|_arg_types, _args| unreachable!("Should not be executed")),
        );
        let predicate_udf = SedonaScalarUDF::from_impl("foofy", predicate_kernel);

        // Both arguments without a CRS: accepted
        let tester = ScalarUdfTester::new(
            predicate_udf.clone().into(),
            vec![WKB_GEOMETRY, WKB_GEOMETRY],
        );
        tester.assert_return_type(DataType::Boolean);

        // Both arguments with the lnglat CRS: accepted
        let tester = ScalarUdfTester::new(
            predicate_udf.clone().into(),
            vec![geom_lnglat.clone(), geom_lnglat.clone()],
        );
        tester.assert_return_type(DataType::Boolean);

        // Arguments whose CRSes differ: rejected
        let tester = ScalarUdfTester::new(
            predicate_udf.into(),
            vec![WKB_GEOMETRY, geom_lnglat.clone()],
        );
        let mismatch_err = tester.return_type().unwrap_err();
        assert!(mismatch_err
            .message()
            .starts_with("Mismatched CRS arguments"));

        // A geometry output should carry the CRS of the inputs
        let geom_out_kernel = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_geometry(), ArgMatcher::is_geometry()],
                WKB_GEOMETRY,
            ),
            Arc::new(|_arg_types, args| Ok(args[0].clone())),
        );
        let geom_out_udf = SedonaScalarUDF::from_impl("foofy", geom_out_kernel);

        let tester = ScalarUdfTester::new(
            geom_out_udf.into(),
            vec![geom_lnglat.clone(), geom_lnglat.clone()],
        );
        tester.assert_return_type(geom_lnglat);
    }

    #[test]
    fn return_type_from_scalar_arg() {
        // The return type is derived from the value of the second (constant) argument
        let udf: ScalarUDF = SedonaScalarUDF::from_impl("simple_cast", SimpleCast {}).into();
        let expr = udf.call(vec![lit(10), lit("float32")]);
        let empty_schema = DFSchema::empty();
        assert_eq!(
            expr.data_type_and_nullable(&empty_schema).unwrap(),
            (DataType::Float32, true)
        );
    }

    /// Test kernel that casts its first argument to the type named by its second
    /// (string) argument
    #[derive(Debug)]
    struct SimpleCast {}

    impl SimpleCast {
        /// Map a scalar string of "float32" or "float64" to the corresponding type;
        /// any other value is an internal error
        fn parse_type(val: &ColumnarValue) -> Result<SedonaType> {
            let target = match val {
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(type_name))) => {
                    match type_name.as_str() {
                        "float32" => Some(DataType::Float32),
                        "float64" => Some(DataType::Float64),
                        _ => None,
                    }
                }
                _ => None,
            };

            match target {
                Some(data_type) => Ok(SedonaType::Arrow(data_type)),
                None => sedona_internal_err!("unrecognized target value"),
            }
        }
    }

    impl SedonaScalarKernel for SimpleCast {
        fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
            sedona_internal_err!("Should not be called")
        }

        fn return_type_from_args_and_scalars(
            &self,
            _args: &[SedonaType],
            scalar_args: &[Option<&ScalarValue>],
        ) -> Result<Option<SedonaType>> {
            // The target type must be provided as a constant second argument
            let target = scalar_args[1].cloned().expect("arg1 as a scalar in test");
            let out_type = Self::parse_type(&ColumnarValue::Scalar(target))?;

            Ok(Some(out_type))
        }

        fn invoke_batch(
            &self,
            _arg_types: &[SedonaType],
            args: &[ColumnarValue],
        ) -> Result<ColumnarValue> {
            let target_type = Self::parse_type(&args[1])?;
            args[0].cast_to(target_type.storage_type(), None)
        }
    }
}
523}