sedona_testing/
testers.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::{iter::zip, sync::Arc};
18
19use arrow_array::{ArrayRef, RecordBatch};
20use arrow_schema::{FieldRef, Schema};
21use datafusion_common::{config::ConfigOptions, Result, ScalarValue};
22use datafusion_expr::{
23    function::{AccumulatorArgs, StateFieldsArgs},
24    Accumulator, AggregateUDF, ColumnarValue, Expr, Literal, ReturnFieldArgs, ScalarFunctionArgs,
25    ScalarUDF,
26};
27use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
28use sedona_common::sedona_internal_err;
29use sedona_schema::datatypes::SedonaType;
30
31use crate::{
32    compare::assert_scalar_equal,
33    create::{create_array, create_scalar},
34};
35
/// Low-level tester for aggregate functions
///
/// This struct provides a means by which to run a simple check of an
/// aggregate implementation by approximating one method DataFusion
/// might use to perform the aggregation. Whereas DataFusion may arrange
/// various calls to accumulate, state, and update_batch to optimize
/// for different cases, this tester always computes its result by merging
/// states that were in turn created from accumulating one batch each.
///
/// This is not a replacement for testing at a higher level using
/// DataFusion's actual aggregate implementation but provides
/// a useful mechanism to ensure all the pieces of an accumulator
/// are plugged in.
pub struct AggregateUdfTester {
    /// The aggregate UDF under test
    udf: AggregateUDF,
    /// Argument types used to build the input fields handed to the UDF
    arg_types: Vec<SedonaType>,
}
53
54impl AggregateUdfTester {
55    /// Create a new tester
56    pub fn new(udf: AggregateUDF, arg_types: Vec<SedonaType>) -> Self {
57        Self { udf, arg_types }
58    }
59
60    /// Compute the return type
61    pub fn return_type(&self) -> Result<SedonaType> {
62        let arg_fields = self
63            .arg_types
64            .iter()
65            .map(|arg_type| arg_type.to_storage_field("", true).map(Arc::new))
66            .collect::<Result<Vec<_>>>()?;
67
68        let out_field = self.udf.return_field(&arg_fields)?;
69        SedonaType::from_storage_field(&out_field)
70    }
71
72    /// Perform a simple aggregation using WKT as geometry input
73    pub fn aggregate_wkt(&self, batches: Vec<Vec<Option<&str>>>) -> Result<ScalarValue> {
74        let batches_array = batches
75            .into_iter()
76            .map(|batch| create_array(&batch, &self.arg_types[0]))
77            .collect::<Vec<_>>();
78        self.aggregate(&batches_array)
79    }
80
81    /// Perform a simple aggregation
82    ///
83    /// Each batch in batches is accumulated with its own accumulator
84    /// and serialized into its own state, after which the states are accumulated
85    /// in batches of one. This has the effect of testing all the pieces of
86    /// an aggregator in a somewhat configurable/predictable way.
87    pub fn aggregate(&self, batches: &Vec<ArrayRef>) -> Result<ScalarValue> {
88        let state_schema = Arc::new(Schema::new(self.state_fields()?));
89        let mut state_accumulator = self.new_accumulator()?;
90
91        for batch in batches {
92            let mut batch_accumulator = self.new_accumulator()?;
93            batch_accumulator.update_batch(std::slice::from_ref(batch))?;
94            let state_batch_of_one = RecordBatch::try_new(
95                state_schema.clone(),
96                batch_accumulator
97                    .state()?
98                    .into_iter()
99                    .map(|v| v.to_array())
100                    .collect::<Result<Vec<_>>>()?,
101            )?;
102            state_accumulator.merge_batch(state_batch_of_one.columns())?;
103        }
104
105        state_accumulator.evaluate()
106    }
107
108    fn new_accumulator(&self) -> Result<Box<dyn Accumulator>> {
109        let mock_schema = Schema::new(self.arg_fields());
110        let exprs = (0..self.arg_types.len())
111            .map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col", i)) })
112            .collect::<Vec<_>>();
113        let accumulator_args = AccumulatorArgs {
114            return_field: self.udf.return_field(mock_schema.fields())?,
115            schema: &mock_schema,
116            ignore_nulls: true,
117            order_bys: &[],
118            is_reversed: false,
119            name: "",
120            is_distinct: false,
121            exprs: &exprs,
122        };
123
124        self.udf.accumulator(accumulator_args)
125    }
126
127    fn state_fields(&self) -> Result<Vec<FieldRef>> {
128        let state_field_args = StateFieldsArgs {
129            name: "",
130            input_fields: &self.arg_fields(),
131            return_field: self.udf.return_field(&self.arg_fields())?,
132            ordering_fields: &[],
133            is_distinct: false,
134        };
135        self.udf.state_fields(state_field_args)
136    }
137
138    fn arg_fields(&self) -> Vec<FieldRef> {
139        self.arg_types
140            .iter()
141            .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
142            .collect::<Result<Vec<_>>>()
143            .unwrap()
144    }
145}
146
/// Low-level tester for scalar functions
///
/// This struct provides a means by which to run a simple check of a
/// scalar UDF implementation by simulating how DataFusion might call it.
///
/// This is not a replacement for testing at a higher level using DataFusion's
/// actual implementation but provides a useful mechanism to ensure all the
/// pieces of a scalar UDF are plugged in.
///
/// Note that arguments are always cast to the types passed to [Self::new]:
/// to test different combinations of argument types, use a new tester.
pub struct ScalarUdfTester {
    /// The scalar UDF under test
    udf: ScalarUDF,
    /// Argument types that inputs are cast to before invoking the UDF
    arg_types: Vec<SedonaType>,
}
162
163impl ScalarUdfTester {
164    /// Create a new tester
165    pub fn new(udf: ScalarUDF, arg_types: Vec<SedonaType>) -> Self {
166        Self { udf, arg_types }
167    }
168
169    /// Assert the return type of the function for the argument types used
170    /// to construct this tester
171    ///
172    /// Both [SedonaType] or [DataType] objects can be used as the expected
173    /// data type.
174    pub fn assert_return_type(&self, data_type: impl TryInto<SedonaType>) {
175        let expected = match data_type.try_into() {
176            Ok(t) => t,
177            Err(_) => panic!("Failed to convert to SedonaType"),
178        };
179        assert_eq!(self.return_type().unwrap(), expected)
180    }
181
182    /// Assert the result of invoking this function
183    ///
184    /// Both actual and expected are interpreted according to the calculated
185    /// return type (notably, WKT is interpreted as geometry or geography output).
186    pub fn assert_scalar_result_equals(&self, actual: impl Literal, expected: impl Literal) {
187        self.assert_scalar_result_equals_inner(actual, expected, None);
188    }
189
190    /// Assert the result of invoking this function with the return type specified
191    ///
192    /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`.
193    pub fn assert_scalar_result_equals_with_return_type(
194        &self,
195        actual: impl Literal,
196        expected: impl Literal,
197        return_type: SedonaType,
198    ) {
199        self.assert_scalar_result_equals_inner(actual, expected, Some(return_type));
200    }
201
202    fn assert_scalar_result_equals_inner(
203        &self,
204        actual: impl Literal,
205        expected: impl Literal,
206        return_type: Option<SedonaType>,
207    ) {
208        let return_type = return_type.unwrap_or_else(|| self.return_type().unwrap());
209        let actual = Self::scalar_lit(actual, &return_type).unwrap();
210        let expected = Self::scalar_lit(expected, &return_type).unwrap();
211        assert_scalar_equal(&actual, &expected);
212    }
213
214    /// Compute the return type
215    pub fn return_type(&self) -> Result<SedonaType> {
216        let scalar_arguments = vec![None; self.arg_types.len()];
217        self.return_type_with_scalars_inner(&scalar_arguments)
218    }
219
220    /// Compute the return type from one scalar argument
221    ///
222    /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`.
223    pub fn return_type_with_scalar(&self, arg0: Option<impl Literal>) -> Result<SedonaType> {
224        let scalar_arguments = vec![arg0
225            .map(|x| Self::scalar_lit(x, &self.arg_types[0]))
226            .transpose()?];
227        self.return_type_with_scalars_inner(&scalar_arguments)
228    }
229
230    /// Compute the return type from two scalar arguments
231    ///
232    /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`.
233    pub fn return_type_with_scalar_scalar(
234        &self,
235        arg0: Option<impl Literal>,
236        arg1: Option<impl Literal>,
237    ) -> Result<SedonaType> {
238        let scalar_arguments = vec![
239            arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
240                .transpose()?,
241            arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
242                .transpose()?,
243        ];
244        self.return_type_with_scalars_inner(&scalar_arguments)
245    }
246
247    /// Compute the return type from three scalar arguments
248    ///
249    /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`.
250    pub fn return_type_with_scalar_scalar_scalar(
251        &self,
252        arg0: Option<impl Literal>,
253        arg1: Option<impl Literal>,
254        arg2: Option<impl Literal>,
255    ) -> Result<SedonaType> {
256        let scalar_arguments = vec![
257            arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
258                .transpose()?,
259            arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
260                .transpose()?,
261            arg2.map(|x| Self::scalar_lit(x, &self.arg_types[2]))
262                .transpose()?,
263        ];
264        self.return_type_with_scalars_inner(&scalar_arguments)
265    }
266
267    fn return_type_with_scalars_inner(
268        &self,
269        scalar_arguments: &[Option<ScalarValue>],
270    ) -> Result<SedonaType> {
271        let arg_fields = self
272            .arg_types
273            .iter()
274            .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
275            .collect::<Result<Vec<_>>>()?;
276
277        let scalar_arguments_ref: Vec<Option<&ScalarValue>> =
278            scalar_arguments.iter().map(|x| x.as_ref()).collect();
279        let args = ReturnFieldArgs {
280            arg_fields: &arg_fields,
281            scalar_arguments: &scalar_arguments_ref,
282        };
283        let return_field = self.udf.return_field_from_args(args)?;
284        SedonaType::from_storage_field(&return_field)
285    }
286
287    /// Invoke this function with a scalar
288    pub fn invoke_scalar(&self, arg: impl Literal) -> Result<ScalarValue> {
289        let scalar_arg = Self::scalar_lit(arg, &self.arg_types[0])?;
290
291        // Some UDF calculate the return type from the input scalar arguments, so try it first.
292        let return_type = self
293            .return_type_with_scalars_inner(&[Some(scalar_arg.clone())])
294            .ok();
295
296        let args = vec![ColumnarValue::Scalar(scalar_arg)];
297        if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
298            Ok(scalar)
299        } else {
300            sedona_internal_err!("Expected scalar result from scalar invoke")
301        }
302    }
303
304    /// Invoke this function with a geometry scalar
305    pub fn invoke_wkb_scalar(&self, wkt_value: Option<&str>) -> Result<ScalarValue> {
306        self.invoke_scalar(create_scalar(wkt_value, &self.arg_types[0]))
307    }
308
309    /// Invoke this function with two scalars
310    pub fn invoke_scalar_scalar<T0: Literal, T1: Literal>(
311        &self,
312        arg0: T0,
313        arg1: T1,
314    ) -> Result<ScalarValue> {
315        let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
316        let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
317
318        // Some UDF calculate the return type from the input scalar arguments, so try it first.
319        let return_type = self
320            .return_type_with_scalars_inner(&[Some(scalar_arg0.clone()), Some(scalar_arg1.clone())])
321            .ok();
322
323        let args = vec![
324            ColumnarValue::Scalar(scalar_arg0),
325            ColumnarValue::Scalar(scalar_arg1),
326        ];
327        if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
328            Ok(scalar)
329        } else {
330            sedona_internal_err!("Expected scalar result from binary scalar invoke")
331        }
332    }
333
334    /// Invoke this function with three scalars
335    pub fn invoke_scalar_scalar_scalar<T0: Literal, T1: Literal, T2: Literal>(
336        &self,
337        arg0: T0,
338        arg1: T1,
339        arg2: T2,
340    ) -> Result<ScalarValue> {
341        let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
342        let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
343        let scalar_arg2 = Self::scalar_lit(arg2, &self.arg_types[2])?;
344
345        // Some UDF calculate the return type from the input scalar arguments, so try it first.
346        let return_type = self
347            .return_type_with_scalars_inner(&[
348                Some(scalar_arg0.clone()),
349                Some(scalar_arg1.clone()),
350                Some(scalar_arg2.clone()),
351            ])
352            .ok();
353
354        let args = vec![
355            ColumnarValue::Scalar(scalar_arg0),
356            ColumnarValue::Scalar(scalar_arg1),
357            ColumnarValue::Scalar(scalar_arg2),
358        ];
359        if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
360            Ok(scalar)
361        } else {
362            sedona_internal_err!("Expected scalar result from binary scalar invoke")
363        }
364    }
365
366    /// Invoke this function with a geometry array
367    pub fn invoke_wkb_array(&self, wkb_values: Vec<Option<&str>>) -> Result<ArrayRef> {
368        self.invoke_array(create_array(&wkb_values, &self.arg_types[0]))
369    }
370
371    /// Invoke this function with a geometry array and a scalar
372    pub fn invoke_wkb_array_scalar(
373        &self,
374        wkb_values: Vec<Option<&str>>,
375        arg: impl Literal,
376    ) -> Result<ArrayRef> {
377        let wkb_array = create_array(&wkb_values, &self.arg_types[0]);
378        self.invoke_arrays_scalar(vec![wkb_array], arg)
379    }
380
381    /// Invoke this function with an array
382    pub fn invoke_array(&self, array: ArrayRef) -> Result<ArrayRef> {
383        self.invoke_arrays(vec![array])
384    }
385
386    /// Invoke a binary function with an array and a scalar
387    pub fn invoke_array_scalar(&self, array: ArrayRef, arg: impl Literal) -> Result<ArrayRef> {
388        self.invoke_arrays_scalar(vec![array], arg)
389    }
390
391    /// Invoke a binary function with an array, and two scalars
392    pub fn invoke_array_scalar_scalar(
393        &self,
394        array: ArrayRef,
395        arg0: impl Literal,
396        arg1: impl Literal,
397    ) -> Result<ArrayRef> {
398        self.invoke_arrays_scalar_scalar(vec![array], arg0, arg1)
399    }
400
401    /// Invoke a binary function with a scalar and an array
402    pub fn invoke_scalar_array(&self, arg: impl Literal, array: ArrayRef) -> Result<ArrayRef> {
403        self.invoke_scalar_arrays(arg, vec![array])
404    }
405
406    /// Invoke a binary function with two arrays
407    pub fn invoke_array_array(&self, array0: ArrayRef, array1: ArrayRef) -> Result<ArrayRef> {
408        self.invoke_arrays(vec![array0, array1])
409    }
410
411    /// Invoke a binary function with two arrays and a scalar
412    pub fn invoke_array_array_scalar(
413        &self,
414        array0: ArrayRef,
415        array1: ArrayRef,
416        arg: impl Literal,
417    ) -> Result<ArrayRef> {
418        self.invoke_arrays_scalar(vec![array0, array1], arg)
419    }
420
421    fn invoke_scalar_arrays(&self, arg: impl Literal, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
422        let mut args = zip(arrays, &self.arg_types)
423            .map(|(array, sedona_type)| {
424                ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
425            })
426            .collect::<Result<Vec<_>>>()?;
427        let index = args.len();
428        args.insert(0, Self::scalar_arg(arg, &self.arg_types[index])?);
429
430        if let ColumnarValue::Array(array) = self.invoke(args)? {
431            Ok(array)
432        } else {
433            sedona_internal_err!("Expected array result from scalar/array invoke")
434        }
435    }
436
437    fn invoke_arrays_scalar(&self, arrays: Vec<ArrayRef>, arg: impl Literal) -> Result<ArrayRef> {
438        let mut args = zip(arrays, &self.arg_types)
439            .map(|(array, sedona_type)| {
440                ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
441            })
442            .collect::<Result<Vec<_>>>()?;
443        let index = args.len();
444        args.push(Self::scalar_arg(arg, &self.arg_types[index])?);
445
446        if let ColumnarValue::Array(array) = self.invoke(args)? {
447            Ok(array)
448        } else {
449            sedona_internal_err!("Expected array result from array/scalar invoke")
450        }
451    }
452
453    fn invoke_arrays_scalar_scalar(
454        &self,
455        arrays: Vec<ArrayRef>,
456        arg0: impl Literal,
457        arg1: impl Literal,
458    ) -> Result<ArrayRef> {
459        let mut args = zip(arrays, &self.arg_types)
460            .map(|(array, sedona_type)| {
461                ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
462            })
463            .collect::<Result<Vec<_>>>()?;
464        let index = args.len();
465        args.push(Self::scalar_arg(arg0, &self.arg_types[index])?);
466        args.push(Self::scalar_arg(arg1, &self.arg_types[index + 1])?);
467
468        if let ColumnarValue::Array(array) = self.invoke(args)? {
469            Ok(array)
470        } else {
471            sedona_internal_err!("Expected array result from array/scalar invoke")
472        }
473    }
474
475    // Invoke a function with a set of arrays
476    pub fn invoke_arrays(&self, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
477        let args = zip(arrays, &self.arg_types)
478            .map(|(array, sedona_type)| {
479                ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
480            })
481            .collect::<Result<_>>()?;
482
483        if let ColumnarValue::Array(array) = self.invoke(args)? {
484            Ok(array)
485        } else {
486            sedona_internal_err!("Expected array result from array invoke")
487        }
488    }
489
490    pub fn invoke(&self, args: Vec<ColumnarValue>) -> Result<ColumnarValue> {
491        self.invoke_with_return_type(args, None)
492    }
493    pub fn invoke_with_return_type(
494        &self,
495        args: Vec<ColumnarValue>,
496        return_type: Option<SedonaType>,
497    ) -> Result<ColumnarValue> {
498        assert_eq!(args.len(), self.arg_types.len(), "Unexpected arg length");
499
500        let mut number_rows = 1;
501        for arg in &args {
502            match arg {
503                ColumnarValue::Array(array) => {
504                    number_rows = array.len();
505                    break;
506                }
507                _ => continue,
508            }
509        }
510
511        let return_type = match return_type {
512            Some(return_type) => return_type,
513            None => self.return_type()?,
514        };
515
516        let args = ScalarFunctionArgs {
517            args,
518            arg_fields: self.arg_fields(),
519            number_rows,
520            return_field: return_type.to_storage_field("", true)?.into(),
521            // TODO: Consider piping actual ConfigOptions for more realistic testing
522            // See: https://github.com/apache/sedona-db/issues/248
523            config_options: Arc::new(ConfigOptions::default()),
524        };
525
526        self.udf.invoke_with_args(args)
527    }
528
529    fn scalar_arg(arg: impl Literal, sedona_type: &SedonaType) -> Result<ColumnarValue> {
530        Ok(ColumnarValue::Scalar(Self::scalar_lit(arg, sedona_type)?))
531    }
532
533    fn scalar_lit(arg: impl Literal, sedona_type: &SedonaType) -> Result<ScalarValue> {
534        if let Expr::Literal(scalar, _) = arg.lit() {
535            if matches!(
536                sedona_type,
537                SedonaType::Wkb(_, _) | SedonaType::WkbView(_, _)
538            ) {
539                if let ScalarValue::Utf8(expected_wkt) = scalar {
540                    Ok(create_scalar(expected_wkt.as_deref(), sedona_type))
541                } else if &scalar.data_type() == sedona_type.storage_type() {
542                    Ok(scalar)
543                } else if scalar.is_null() {
544                    Ok(create_scalar(None, sedona_type))
545                } else {
546                    sedona_internal_err!("Can't interpret scalar {scalar} as type {sedona_type}")
547                }
548            } else {
549                scalar.cast_to(sedona_type.storage_type())
550            }
551        } else {
552            sedona_internal_err!("Can't use test scalar invoke where .lit() returns non-literal")
553        }
554    }
555
556    fn arg_fields(&self) -> Vec<FieldRef> {
557        self.arg_types
558            .iter()
559            .map(|data_type| data_type.to_storage_field("", false).map(Arc::new))
560            .collect::<Result<Vec<_>>>()
561            .unwrap()
562    }
563}