1use std::{iter::zip, sync::Arc};
18
19use arrow_array::{ArrayRef, RecordBatch};
20use arrow_schema::{FieldRef, Schema};
21use datafusion_common::{config::ConfigOptions, Result, ScalarValue};
22use datafusion_expr::{
23 function::{AccumulatorArgs, StateFieldsArgs},
24 Accumulator, AggregateUDF, ColumnarValue, Expr, Literal, ReturnFieldArgs, ScalarFunctionArgs,
25 ScalarUDF,
26};
27use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
28use sedona_common::sedona_internal_err;
29use sedona_schema::datatypes::SedonaType;
30
31use crate::{
32 compare::assert_scalar_equal,
33 create::{create_array, create_scalar},
34};
35
36pub struct AggregateUdfTester {
50 udf: AggregateUDF,
51 arg_types: Vec<SedonaType>,
52}
53
54impl AggregateUdfTester {
55 pub fn new(udf: AggregateUDF, arg_types: Vec<SedonaType>) -> Self {
57 Self { udf, arg_types }
58 }
59
60 pub fn return_type(&self) -> Result<SedonaType> {
62 let arg_fields = self
63 .arg_types
64 .iter()
65 .map(|arg_type| arg_type.to_storage_field("", true).map(Arc::new))
66 .collect::<Result<Vec<_>>>()?;
67
68 let out_field = self.udf.return_field(&arg_fields)?;
69 SedonaType::from_storage_field(&out_field)
70 }
71
72 pub fn aggregate_wkt(&self, batches: Vec<Vec<Option<&str>>>) -> Result<ScalarValue> {
74 let batches_array = batches
75 .into_iter()
76 .map(|batch| create_array(&batch, &self.arg_types[0]))
77 .collect::<Vec<_>>();
78 self.aggregate(&batches_array)
79 }
80
81 pub fn aggregate(&self, batches: &Vec<ArrayRef>) -> Result<ScalarValue> {
88 let state_schema = Arc::new(Schema::new(self.state_fields()?));
89 let mut state_accumulator = self.new_accumulator()?;
90
91 for batch in batches {
92 let mut batch_accumulator = self.new_accumulator()?;
93 batch_accumulator.update_batch(std::slice::from_ref(batch))?;
94 let state_batch_of_one = RecordBatch::try_new(
95 state_schema.clone(),
96 batch_accumulator
97 .state()?
98 .into_iter()
99 .map(|v| v.to_array())
100 .collect::<Result<Vec<_>>>()?,
101 )?;
102 state_accumulator.merge_batch(state_batch_of_one.columns())?;
103 }
104
105 state_accumulator.evaluate()
106 }
107
108 fn new_accumulator(&self) -> Result<Box<dyn Accumulator>> {
109 let mock_schema = Schema::new(self.arg_fields());
110 let exprs = (0..self.arg_types.len())
111 .map(|i| -> Arc<dyn PhysicalExpr> { Arc::new(Column::new("col", i)) })
112 .collect::<Vec<_>>();
113 let accumulator_args = AccumulatorArgs {
114 return_field: self.udf.return_field(mock_schema.fields())?,
115 schema: &mock_schema,
116 ignore_nulls: true,
117 order_bys: &[],
118 is_reversed: false,
119 name: "",
120 is_distinct: false,
121 exprs: &exprs,
122 };
123
124 self.udf.accumulator(accumulator_args)
125 }
126
127 fn state_fields(&self) -> Result<Vec<FieldRef>> {
128 let state_field_args = StateFieldsArgs {
129 name: "",
130 input_fields: &self.arg_fields(),
131 return_field: self.udf.return_field(&self.arg_fields())?,
132 ordering_fields: &[],
133 is_distinct: false,
134 };
135 self.udf.state_fields(state_field_args)
136 }
137
138 fn arg_fields(&self) -> Vec<FieldRef> {
139 self.arg_types
140 .iter()
141 .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
142 .collect::<Result<Vec<_>>>()
143 .unwrap()
144 }
145}
146
147pub struct ScalarUdfTester {
159 udf: ScalarUDF,
160 arg_types: Vec<SedonaType>,
161}
162
163impl ScalarUdfTester {
164 pub fn new(udf: ScalarUDF, arg_types: Vec<SedonaType>) -> Self {
166 Self { udf, arg_types }
167 }
168
169 pub fn assert_return_type(&self, data_type: impl TryInto<SedonaType>) {
175 let expected = match data_type.try_into() {
176 Ok(t) => t,
177 Err(_) => panic!("Failed to convert to SedonaType"),
178 };
179 assert_eq!(self.return_type().unwrap(), expected)
180 }
181
182 pub fn assert_scalar_result_equals(&self, actual: impl Literal, expected: impl Literal) {
187 self.assert_scalar_result_equals_inner(actual, expected, None);
188 }
189
190 pub fn assert_scalar_result_equals_with_return_type(
194 &self,
195 actual: impl Literal,
196 expected: impl Literal,
197 return_type: SedonaType,
198 ) {
199 self.assert_scalar_result_equals_inner(actual, expected, Some(return_type));
200 }
201
202 fn assert_scalar_result_equals_inner(
203 &self,
204 actual: impl Literal,
205 expected: impl Literal,
206 return_type: Option<SedonaType>,
207 ) {
208 let return_type = return_type.unwrap_or_else(|| self.return_type().unwrap());
209 let actual = Self::scalar_lit(actual, &return_type).unwrap();
210 let expected = Self::scalar_lit(expected, &return_type).unwrap();
211 assert_scalar_equal(&actual, &expected);
212 }
213
214 pub fn return_type(&self) -> Result<SedonaType> {
216 let scalar_arguments = vec![None; self.arg_types.len()];
217 self.return_type_with_scalars_inner(&scalar_arguments)
218 }
219
220 pub fn return_type_with_scalar(&self, arg0: Option<impl Literal>) -> Result<SedonaType> {
224 let scalar_arguments = vec![arg0
225 .map(|x| Self::scalar_lit(x, &self.arg_types[0]))
226 .transpose()?];
227 self.return_type_with_scalars_inner(&scalar_arguments)
228 }
229
230 pub fn return_type_with_scalar_scalar(
234 &self,
235 arg0: Option<impl Literal>,
236 arg1: Option<impl Literal>,
237 ) -> Result<SedonaType> {
238 let scalar_arguments = vec![
239 arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
240 .transpose()?,
241 arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
242 .transpose()?,
243 ];
244 self.return_type_with_scalars_inner(&scalar_arguments)
245 }
246
247 pub fn return_type_with_scalar_scalar_scalar(
251 &self,
252 arg0: Option<impl Literal>,
253 arg1: Option<impl Literal>,
254 arg2: Option<impl Literal>,
255 ) -> Result<SedonaType> {
256 let scalar_arguments = vec![
257 arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0]))
258 .transpose()?,
259 arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1]))
260 .transpose()?,
261 arg2.map(|x| Self::scalar_lit(x, &self.arg_types[2]))
262 .transpose()?,
263 ];
264 self.return_type_with_scalars_inner(&scalar_arguments)
265 }
266
267 fn return_type_with_scalars_inner(
268 &self,
269 scalar_arguments: &[Option<ScalarValue>],
270 ) -> Result<SedonaType> {
271 let arg_fields = self
272 .arg_types
273 .iter()
274 .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new))
275 .collect::<Result<Vec<_>>>()?;
276
277 let scalar_arguments_ref: Vec<Option<&ScalarValue>> =
278 scalar_arguments.iter().map(|x| x.as_ref()).collect();
279 let args = ReturnFieldArgs {
280 arg_fields: &arg_fields,
281 scalar_arguments: &scalar_arguments_ref,
282 };
283 let return_field = self.udf.return_field_from_args(args)?;
284 SedonaType::from_storage_field(&return_field)
285 }
286
287 pub fn invoke_scalar(&self, arg: impl Literal) -> Result<ScalarValue> {
289 let scalar_arg = Self::scalar_lit(arg, &self.arg_types[0])?;
290
291 let return_type = self
293 .return_type_with_scalars_inner(&[Some(scalar_arg.clone())])
294 .ok();
295
296 let args = vec![ColumnarValue::Scalar(scalar_arg)];
297 if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
298 Ok(scalar)
299 } else {
300 sedona_internal_err!("Expected scalar result from scalar invoke")
301 }
302 }
303
304 pub fn invoke_wkb_scalar(&self, wkt_value: Option<&str>) -> Result<ScalarValue> {
306 self.invoke_scalar(create_scalar(wkt_value, &self.arg_types[0]))
307 }
308
309 pub fn invoke_scalar_scalar<T0: Literal, T1: Literal>(
311 &self,
312 arg0: T0,
313 arg1: T1,
314 ) -> Result<ScalarValue> {
315 let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
316 let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
317
318 let return_type = self
320 .return_type_with_scalars_inner(&[Some(scalar_arg0.clone()), Some(scalar_arg1.clone())])
321 .ok();
322
323 let args = vec![
324 ColumnarValue::Scalar(scalar_arg0),
325 ColumnarValue::Scalar(scalar_arg1),
326 ];
327 if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
328 Ok(scalar)
329 } else {
330 sedona_internal_err!("Expected scalar result from binary scalar invoke")
331 }
332 }
333
334 pub fn invoke_scalar_scalar_scalar<T0: Literal, T1: Literal, T2: Literal>(
336 &self,
337 arg0: T0,
338 arg1: T1,
339 arg2: T2,
340 ) -> Result<ScalarValue> {
341 let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?;
342 let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?;
343 let scalar_arg2 = Self::scalar_lit(arg2, &self.arg_types[2])?;
344
345 let return_type = self
347 .return_type_with_scalars_inner(&[
348 Some(scalar_arg0.clone()),
349 Some(scalar_arg1.clone()),
350 Some(scalar_arg2.clone()),
351 ])
352 .ok();
353
354 let args = vec![
355 ColumnarValue::Scalar(scalar_arg0),
356 ColumnarValue::Scalar(scalar_arg1),
357 ColumnarValue::Scalar(scalar_arg2),
358 ];
359 if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? {
360 Ok(scalar)
361 } else {
362 sedona_internal_err!("Expected scalar result from binary scalar invoke")
363 }
364 }
365
366 pub fn invoke_wkb_array(&self, wkb_values: Vec<Option<&str>>) -> Result<ArrayRef> {
368 self.invoke_array(create_array(&wkb_values, &self.arg_types[0]))
369 }
370
371 pub fn invoke_wkb_array_scalar(
373 &self,
374 wkb_values: Vec<Option<&str>>,
375 arg: impl Literal,
376 ) -> Result<ArrayRef> {
377 let wkb_array = create_array(&wkb_values, &self.arg_types[0]);
378 self.invoke_arrays_scalar(vec![wkb_array], arg)
379 }
380
381 pub fn invoke_array(&self, array: ArrayRef) -> Result<ArrayRef> {
383 self.invoke_arrays(vec![array])
384 }
385
386 pub fn invoke_array_scalar(&self, array: ArrayRef, arg: impl Literal) -> Result<ArrayRef> {
388 self.invoke_arrays_scalar(vec![array], arg)
389 }
390
391 pub fn invoke_array_scalar_scalar(
393 &self,
394 array: ArrayRef,
395 arg0: impl Literal,
396 arg1: impl Literal,
397 ) -> Result<ArrayRef> {
398 self.invoke_arrays_scalar_scalar(vec![array], arg0, arg1)
399 }
400
401 pub fn invoke_scalar_array(&self, arg: impl Literal, array: ArrayRef) -> Result<ArrayRef> {
403 self.invoke_scalar_arrays(arg, vec![array])
404 }
405
406 pub fn invoke_array_array(&self, array0: ArrayRef, array1: ArrayRef) -> Result<ArrayRef> {
408 self.invoke_arrays(vec![array0, array1])
409 }
410
411 pub fn invoke_array_array_scalar(
413 &self,
414 array0: ArrayRef,
415 array1: ArrayRef,
416 arg: impl Literal,
417 ) -> Result<ArrayRef> {
418 self.invoke_arrays_scalar(vec![array0, array1], arg)
419 }
420
421 fn invoke_scalar_arrays(&self, arg: impl Literal, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
422 let mut args = zip(arrays, &self.arg_types)
423 .map(|(array, sedona_type)| {
424 ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
425 })
426 .collect::<Result<Vec<_>>>()?;
427 let index = args.len();
428 args.insert(0, Self::scalar_arg(arg, &self.arg_types[index])?);
429
430 if let ColumnarValue::Array(array) = self.invoke(args)? {
431 Ok(array)
432 } else {
433 sedona_internal_err!("Expected array result from scalar/array invoke")
434 }
435 }
436
437 fn invoke_arrays_scalar(&self, arrays: Vec<ArrayRef>, arg: impl Literal) -> Result<ArrayRef> {
438 let mut args = zip(arrays, &self.arg_types)
439 .map(|(array, sedona_type)| {
440 ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
441 })
442 .collect::<Result<Vec<_>>>()?;
443 let index = args.len();
444 args.push(Self::scalar_arg(arg, &self.arg_types[index])?);
445
446 if let ColumnarValue::Array(array) = self.invoke(args)? {
447 Ok(array)
448 } else {
449 sedona_internal_err!("Expected array result from array/scalar invoke")
450 }
451 }
452
453 fn invoke_arrays_scalar_scalar(
454 &self,
455 arrays: Vec<ArrayRef>,
456 arg0: impl Literal,
457 arg1: impl Literal,
458 ) -> Result<ArrayRef> {
459 let mut args = zip(arrays, &self.arg_types)
460 .map(|(array, sedona_type)| {
461 ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
462 })
463 .collect::<Result<Vec<_>>>()?;
464 let index = args.len();
465 args.push(Self::scalar_arg(arg0, &self.arg_types[index])?);
466 args.push(Self::scalar_arg(arg1, &self.arg_types[index + 1])?);
467
468 if let ColumnarValue::Array(array) = self.invoke(args)? {
469 Ok(array)
470 } else {
471 sedona_internal_err!("Expected array result from array/scalar invoke")
472 }
473 }
474
475 pub fn invoke_arrays(&self, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
477 let args = zip(arrays, &self.arg_types)
478 .map(|(array, sedona_type)| {
479 ColumnarValue::Array(array).cast_to(sedona_type.storage_type(), None)
480 })
481 .collect::<Result<_>>()?;
482
483 if let ColumnarValue::Array(array) = self.invoke(args)? {
484 Ok(array)
485 } else {
486 sedona_internal_err!("Expected array result from array invoke")
487 }
488 }
489
490 pub fn invoke(&self, args: Vec<ColumnarValue>) -> Result<ColumnarValue> {
491 self.invoke_with_return_type(args, None)
492 }
493 pub fn invoke_with_return_type(
494 &self,
495 args: Vec<ColumnarValue>,
496 return_type: Option<SedonaType>,
497 ) -> Result<ColumnarValue> {
498 assert_eq!(args.len(), self.arg_types.len(), "Unexpected arg length");
499
500 let mut number_rows = 1;
501 for arg in &args {
502 match arg {
503 ColumnarValue::Array(array) => {
504 number_rows = array.len();
505 break;
506 }
507 _ => continue,
508 }
509 }
510
511 let return_type = match return_type {
512 Some(return_type) => return_type,
513 None => self.return_type()?,
514 };
515
516 let args = ScalarFunctionArgs {
517 args,
518 arg_fields: self.arg_fields(),
519 number_rows,
520 return_field: return_type.to_storage_field("", true)?.into(),
521 config_options: Arc::new(ConfigOptions::default()),
524 };
525
526 self.udf.invoke_with_args(args)
527 }
528
529 fn scalar_arg(arg: impl Literal, sedona_type: &SedonaType) -> Result<ColumnarValue> {
530 Ok(ColumnarValue::Scalar(Self::scalar_lit(arg, sedona_type)?))
531 }
532
533 fn scalar_lit(arg: impl Literal, sedona_type: &SedonaType) -> Result<ScalarValue> {
534 if let Expr::Literal(scalar, _) = arg.lit() {
535 if matches!(
536 sedona_type,
537 SedonaType::Wkb(_, _) | SedonaType::WkbView(_, _)
538 ) {
539 if let ScalarValue::Utf8(expected_wkt) = scalar {
540 Ok(create_scalar(expected_wkt.as_deref(), sedona_type))
541 } else if &scalar.data_type() == sedona_type.storage_type() {
542 Ok(scalar)
543 } else if scalar.is_null() {
544 Ok(create_scalar(None, sedona_type))
545 } else {
546 sedona_internal_err!("Can't interpret scalar {scalar} as type {sedona_type}")
547 }
548 } else {
549 scalar.cast_to(sedona_type.storage_type())
550 }
551 } else {
552 sedona_internal_err!("Can't use test scalar invoke where .lit() returns non-literal")
553 }
554 }
555
556 fn arg_fields(&self) -> Vec<FieldRef> {
557 self.arg_types
558 .iter()
559 .map(|data_type| data_type.to_storage_field("", false).map(Arc::new))
560 .collect::<Result<Vec<_>>>()
561 .unwrap()
562 }
563}