sedona_testing/
benchmark_util.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::{fmt::Debug, sync::Arc, vec};
18
19use arrow_array::{ArrayRef, Float64Array};
20use arrow_schema::DataType;
21
22use datafusion_common::{Result, ScalarValue};
23use datafusion_expr::{AggregateUDF, ScalarUDF};
24use geo_types::Rect;
25use rand::{distributions::Uniform, rngs::StdRng, Rng, SeedableRng};
26
27use sedona_common::sedona_internal_err;
28use sedona_geometry::types::GeometryTypeId;
29use sedona_schema::datatypes::{SedonaType, RASTER, WKB_GEOMETRY};
30use sedona_schema::raster::BandDataType;
31
32use crate::{
33    datagen::RandomPartitionedDataBuilder,
34    rasters::generate_tiled_rasters,
35    testers::{AggregateUdfTester, ScalarUdfTester},
36};
37
/// Rows per batch for regular-size benchmarks (matches DataFusion's default batch size)
pub const ROWS_PER_BATCH: usize = 8_192;

/// Rows per batch for tiny-size benchmarks
pub const ROWS_PER_BATCH_TINY: usize = 1_024;

/// Number of batches for small-size benchmarks
///
/// Picked so that most benchmarks fit comfortably within criterion's default
/// settings (a 5 second measurement target and 100 samples).
pub const NUM_BATCHES_SMALL: usize = 16;

/// Number of batches for tiny-size benchmarks
///
/// A single batch: enough to check that a benchmark actually runs.
pub const NUM_BATCHES_TINY: usize = 1;
54
#[cfg(feature = "criterion")]
pub mod benchmark {
    use super::*;
    use criterion::Criterion;
    use sedona_expr::function_set::FunctionSet;

    /// Benchmark a [ScalarUDF] using [Criterion]
    ///
    /// When built with the criterion feature, provides utilities for running a
    /// basic benchmark on a [ScalarUDF] given [BenchmarkArgs]. The data size comes
    /// from [Config::default]:
    ///
    /// - Release builds use [NUM_BATCHES_SMALL] batches of [ROWS_PER_BATCH] rows
    ///   (16 x 8192 == 131,072 rows). This size was chosen so that most
    ///   benchmarks run nicely with criterion defaults (target 5s, 100 samples).
    /// - Debug builds use [NUM_BATCHES_TINY] batch of [ROWS_PER_BATCH_TINY] rows,
    ///   so that benchmarks can be smoke-tested quickly.
    ///
    /// # Panics
    ///
    /// Panics if `name` is not in `functions`, if the input data cannot be
    /// built, or if invoking the function fails.
    pub fn scalar(
        c: &mut Criterion,
        functions: &FunctionSet,
        lib: &str,
        name: &str,
        config: impl Into<BenchmarkArgs>,
    ) {
        let not_found_err = format!("{name} was not found in function set");
        let udf: ScalarUDF = functions
            .scalar_udf(name)
            .expect(&not_found_err)
            .clone()
            .into();
        let data = build_default_data(config);
        c.bench_function(&data.make_label(lib, name), |b| {
            b.iter(|| data.invoke_scalar(&udf).unwrap())
        });
    }

    /// Benchmark a [AggregateUDF] using [Criterion]
    ///
    /// When built with the criterion feature, provides utilities for running a
    /// basic benchmark on a [AggregateUDF] given [BenchmarkArgs]. This shares
    /// the default data size ([Config::default]) with [scalar]. Because
    /// aggregate functions can be invoked with varying combinations of
    /// accumulation and merging of states, they should also be benchmarked
    /// at a higher level. This benchmark primarily checks the accumulator.
    ///
    /// # Panics
    ///
    /// Panics if `name` is not in `functions`, if the input data cannot be
    /// built, or if invoking the function fails (including when `config` is
    /// not an array-only argument layout).
    pub fn aggregate(
        c: &mut Criterion,
        functions: &FunctionSet,
        lib: &str,
        name: &str,
        config: impl Into<BenchmarkArgs>,
    ) {
        let not_found_err = format!("{name} was not found in function set");
        let udf: AggregateUDF = functions
            .aggregate_udf(name)
            .expect(&not_found_err)
            .clone()
            .into();
        let data = build_default_data(config);
        c.bench_function(&data.make_label(lib, name), |b| {
            b.iter(|| data.invoke_aggregate(&udf).unwrap())
        });
    }

    /// Build [BenchmarkData] for `config` using the [Config::default] data size
    ///
    /// Callers build the data before entering the measured closure, so data
    /// generation time is not part of the benchmark.
    fn build_default_data(config: impl Into<BenchmarkArgs>) -> BenchmarkData {
        let size = Config::default();
        config
            .into()
            .build_data(size.num_batches(), size.rows_per_batch())
            .unwrap()
    }

    /// Size of the data generated for [scalar] and [aggregate]
    ///
    /// [Config::default] is [Config::Tiny] in debug builds and [Config::Small]
    /// otherwise.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum Config {
        /// [NUM_BATCHES_TINY] batch of [ROWS_PER_BATCH_TINY] rows
        Tiny,
        /// [NUM_BATCHES_SMALL] batches of [ROWS_PER_BATCH] rows
        Small,
    }

    impl Default for Config {
        fn default() -> Self {
            if cfg!(debug_assertions) {
                Self::Tiny
            } else {
                Self::Small
            }
        }
    }

    impl Config {
        /// Number of batches to generate for this size
        fn num_batches(&self) -> usize {
            match self {
                Config::Tiny => NUM_BATCHES_TINY,
                Config::Small => NUM_BATCHES_SMALL,
            }
        }

        /// Number of rows in each generated batch for this size
        fn rows_per_batch(&self) -> usize {
            match self {
                Config::Tiny => ROWS_PER_BATCH_TINY,
                Config::Small => ROWS_PER_BATCH,
            }
        }
    }
}
157
158/// Specification for benchmark arguments
159///
160/// This provides a concise definition of function input based on a
161/// combination of scalar/array arguments each specified by a [BenchmarkArgSpec].
162#[derive(Debug, Clone)]
163pub enum BenchmarkArgs {
164    /// Invoke a unary function with array input
165    Array(BenchmarkArgSpec),
166    /// Invoke a binary function with scalar and array input
167    ScalarArray(BenchmarkArgSpec, BenchmarkArgSpec),
168    /// Invoke a binary function with array and scalar input
169    ArrayScalar(BenchmarkArgSpec, BenchmarkArgSpec),
170    /// Invoke a binary function with two arrays
171    ArrayArray(BenchmarkArgSpec, BenchmarkArgSpec),
172    /// Invoke a function with an array and two scalar inputs
173    ArrayScalarScalar(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
174    /// Invoke a ternary function with two arrays and a scalar
175    ArrayArrayScalar(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
176    /// Invoke a ternary function with three arrays
177    ArrayArrayArray(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
178    /// Invoke a quaternary function with four arrays
179    ArrayArrayArrayArray(
180        BenchmarkArgSpec,
181        BenchmarkArgSpec,
182        BenchmarkArgSpec,
183        BenchmarkArgSpec,
184    ),
185}
186
187impl From<BenchmarkArgSpec> for BenchmarkArgs {
188    fn from(value: BenchmarkArgSpec) -> Self {
189        BenchmarkArgs::Array(value)
190    }
191}
192
193impl BenchmarkArgs {
194    /// Calculate the [SedonaType]s of the input arguments
195    fn sedona_types(&self) -> Vec<SedonaType> {
196        self.specs().iter().map(|col| col.sedona_type()).collect()
197    }
198
199    /// Build [BenchmarkData] with the specified number of batches
200    pub fn build_data(&self, num_batches: usize, rows_per_batch: usize) -> Result<BenchmarkData> {
201        let array_configs = match self {
202            BenchmarkArgs::Array(_)
203            | BenchmarkArgs::ArrayArray(_, _)
204            | BenchmarkArgs::ArrayArrayScalar(_, _, _)
205            | BenchmarkArgs::ArrayArrayArray(_, _, _)
206            | BenchmarkArgs::ArrayArrayArrayArray(_, _, _, _) => self.specs(),
207            BenchmarkArgs::ScalarArray(_, col)
208            | BenchmarkArgs::ArrayScalar(col, _)
209            | BenchmarkArgs::ArrayScalarScalar(col, _, _) => {
210                vec![col.clone()]
211            }
212        };
213        let scalar_configs = match self {
214            BenchmarkArgs::ScalarArray(col, _)
215            | BenchmarkArgs::ArrayScalar(_, col)
216            | BenchmarkArgs::ArrayArrayScalar(_, _, col) => {
217                vec![col.clone()]
218            }
219            BenchmarkArgs::ArrayScalarScalar(_, col0, col1) => {
220                vec![col0.clone(), col1.clone()]
221            }
222            _ => vec![],
223        };
224
225        let arrays = array_configs
226            .iter()
227            .enumerate()
228            .map(|(i, col)| col.build_arrays(i, num_batches, rows_per_batch))
229            .collect::<Result<Vec<_>>>()?;
230
231        let scalars = scalar_configs
232            .iter()
233            .enumerate()
234            .map(|(i, col)| col.build_scalar(i))
235            .collect::<Result<Vec<_>>>()?;
236
237        Ok(BenchmarkData {
238            config: self.clone(),
239            num_batches,
240            arrays,
241            scalars,
242        })
243    }
244
245    fn specs(&self) -> Vec<BenchmarkArgSpec> {
246        match self {
247            BenchmarkArgs::Array(col) => vec![col.clone()],
248            BenchmarkArgs::ScalarArray(col0, col1)
249            | BenchmarkArgs::ArrayScalar(col0, col1)
250            | BenchmarkArgs::ArrayArray(col0, col1) => {
251                vec![col0.clone(), col1.clone()]
252            }
253            BenchmarkArgs::ArrayScalarScalar(col0, col1, col2)
254            | BenchmarkArgs::ArrayArrayScalar(col0, col1, col2)
255            | BenchmarkArgs::ArrayArrayArray(col0, col1, col2) => {
256                vec![col0.clone(), col1.clone(), col2.clone()]
257            }
258            BenchmarkArgs::ArrayArrayArrayArray(col0, col1, col2, col3) => {
259                vec![col0.clone(), col1.clone(), col2.clone(), col3.clone()]
260            }
261        }
262    }
263}
264
265/// Specification of a single argument to a function
266///
267/// Geometries are generated using the [RandomPartitionedDataBuilder], which offers
268/// more specific options for generating random geometries.
269#[derive(Clone)]
270pub enum BenchmarkArgSpec {
271    /// Randomly generated point input
272    Point,
273    /// Randomly generated linestring input with a specified number of vertices
274    LineString(usize),
275    /// Randomly generated polygon input with a specified number of vertices
276    Polygon(usize),
277    /// Randomly generated linestring input with a specified number of vertices
278    MultiPoint(usize),
279    /// Randomly generated floating point input with a given range of values
280    Float64(f64, f64),
281    /// A transformation of any of the above based on a [ScalarUDF] accepting
282    /// a single argument
283    Transformed(Box<BenchmarkArgSpec>, ScalarUDF),
284    /// A string that will be a constant
285    String(String),
286    /// Randomly generated raster input with a specified width, height
287    Raster(usize, usize),
288}
289
290// Custom implementation of Debug because otherwise the output of Transformed()
291// is excessively verbose
292impl Debug for BenchmarkArgSpec {
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        match self {
295            Self::Point => write!(f, "Point"),
296            Self::LineString(arg0) => f.debug_tuple("LineString").field(arg0).finish(),
297            Self::Polygon(arg0) => f.debug_tuple("Polygon").field(arg0).finish(),
298            Self::MultiPoint(arg0) => f.debug_tuple("MultiPoint").field(arg0).finish(),
299            Self::Float64(arg0, arg1) => f.debug_tuple("Float64").field(arg0).field(arg1).finish(),
300            Self::Transformed(inner, t) => write!(f, "{}({:?})", t.name(), inner),
301            Self::String(s) => write!(f, "String({s})"),
302            Self::Raster(w, h) => f.debug_tuple("Raster").field(w).field(h).finish(),
303        }
304    }
305}
306
307impl BenchmarkArgSpec {
308    /// The [SedonaType] of this argument
309    pub fn sedona_type(&self) -> SedonaType {
310        match self {
311            BenchmarkArgSpec::Point
312            | BenchmarkArgSpec::Polygon(_)
313            | BenchmarkArgSpec::LineString(_)
314            | BenchmarkArgSpec::MultiPoint(_) => WKB_GEOMETRY,
315            BenchmarkArgSpec::Float64(_, _) => SedonaType::Arrow(DataType::Float64),
316            BenchmarkArgSpec::Transformed(inner, t) => {
317                let tester = ScalarUdfTester::new(t.clone(), vec![inner.sedona_type()]);
318                tester.return_type().unwrap()
319            }
320            BenchmarkArgSpec::String(_) => SedonaType::Arrow(DataType::Utf8),
321            BenchmarkArgSpec::Raster(_, _) => RASTER,
322        }
323    }
324
325    /// Build a [ScalarValue] for this argument
326    ///
327    /// This currently builds the same non-null scalar for each unique value
328    /// of i (the argument number).
329    pub fn build_scalar(&self, i: usize) -> Result<ScalarValue> {
330        let array = self.build_arrays(i, 1, 1)?;
331        ScalarValue::try_from_array(&array[0], 0)
332    }
333
334    /// Build a column of num_batches arrays
335    ///
336    /// This currently builds the same column for each unique value of i (the argument
337    /// number). The batch size is currently fixed to 8192 (the DataFusion default).
338    pub fn build_arrays(
339        &self,
340        i: usize,
341        num_batches: usize,
342        rows_per_batch: usize,
343    ) -> Result<Vec<ArrayRef>> {
344        match self {
345            BenchmarkArgSpec::Point => {
346                self.build_geometry(i, GeometryTypeId::Point, num_batches, 1, 1, rows_per_batch)
347            }
348            BenchmarkArgSpec::LineString(vertex_count) => self.build_geometry(
349                i,
350                GeometryTypeId::LineString,
351                num_batches,
352                *vertex_count,
353                1,
354                rows_per_batch,
355            ),
356            BenchmarkArgSpec::Polygon(vertex_count) => self.build_geometry(
357                i,
358                GeometryTypeId::Polygon,
359                num_batches,
360                *vertex_count,
361                1,
362                rows_per_batch,
363            ),
364            BenchmarkArgSpec::MultiPoint(part_count) => self.build_geometry(
365                i,
366                GeometryTypeId::MultiPoint,
367                num_batches,
368                1,
369                *part_count,
370                rows_per_batch,
371            ),
372            BenchmarkArgSpec::Float64(lo, hi) => {
373                let mut rng = self.rng(i);
374                let dist = Uniform::new(lo, hi);
375                (0..num_batches)
376                    .map(|_| -> Result<ArrayRef> {
377                        let float64_array: Float64Array =
378                            (0..rows_per_batch).map(|_| rng.sample(dist)).collect();
379                        Ok(Arc::new(float64_array))
380                    })
381                    .collect()
382            }
383            BenchmarkArgSpec::Transformed(inner, t) => {
384                let inner_type = inner.sedona_type();
385                let inner_arrays = inner.build_arrays(i, num_batches, rows_per_batch)?;
386                let tester = ScalarUdfTester::new(t.clone(), vec![inner_type]);
387                inner_arrays
388                    .into_iter()
389                    .map(|array| tester.invoke_array(array))
390                    .collect::<Result<Vec<_>>>()
391            }
392            BenchmarkArgSpec::String(s) => {
393                let string_array = (0..num_batches)
394                    .map(|_| {
395                        let array = arrow_array::StringArray::from_iter_values(
396                            std::iter::repeat_n(s, rows_per_batch),
397                        );
398                        Ok(Arc::new(array) as ArrayRef)
399                    })
400                    .collect::<Result<Vec<_>>>()?;
401                Ok(string_array)
402            }
403            BenchmarkArgSpec::Raster(width, height) => {
404                let mut arrays = vec![];
405                for _ in 0..num_batches {
406                    let tile_size = (*width, *height);
407                    let tile_count = (rows_per_batch, 1);
408                    let raster = generate_tiled_rasters(
409                        tile_size,
410                        tile_count,
411                        BandDataType::UInt8,
412                        Some(43),
413                    )?;
414                    arrays.push(Arc::new(raster) as ArrayRef);
415                }
416                Ok(arrays)
417            }
418        }
419    }
420
421    fn build_geometry(
422        &self,
423        i: usize,
424        geom_type: GeometryTypeId,
425        num_batches: usize,
426        vertex_count: usize,
427        num_parts_count: usize,
428        rows_per_batch: usize,
429    ) -> Result<Vec<ArrayRef>> {
430        let builder = RandomPartitionedDataBuilder::new()
431            .num_partitions(1)
432            .rows_per_batch(rows_per_batch)
433            .batches_per_partition(num_batches)
434            // Use a random geometry range that is also not unrealistic for geography
435            .bounds(Rect::new((-10.0, -10.0), (10.0, 10.0)))
436            .size_range((0.1, 2.0))
437            .vertices_per_linestring_range((vertex_count, vertex_count))
438            .num_parts_range((num_parts_count, num_parts_count))
439            .geometry_type(geom_type)
440            // Currently just use WKB_GEOMETRY (we can generate a view type with
441            // Transformed)
442            .sedona_type(WKB_GEOMETRY);
443
444        builder
445            .partition_reader(self.rng(i), 0)
446            .map(|batch| -> Result<ArrayRef> { Ok(batch?.column(2).clone()) })
447            .collect()
448    }
449
450    fn rng(&self, i: usize) -> impl Rng {
451        StdRng::seed_from_u64(42 + i as u64)
452    }
453}
454
455/// Fully resolved data ready for running a benchmark
456///
457/// This struct contains the fully built data (such that benchmarks do not
458/// measure the time required to build the data) and has methods for invoking
459/// functions on it.
460pub struct BenchmarkData {
461    config: BenchmarkArgs,
462    num_batches: usize,
463    arrays: Vec<Vec<ArrayRef>>,
464    scalars: Vec<ScalarValue>,
465}
466
467impl BenchmarkData {
468    /// Create a label based on the library, function name, and configuration
469    pub fn make_label(&self, lib: &str, name: &str) -> String {
470        format!("{lib}-{name}-{:?}", self.config)
471    }
472
473    /// Invoke a scalar function on this data
474    pub fn invoke_scalar(&self, udf: &ScalarUDF) -> Result<()> {
475        let tester = ScalarUdfTester::new(udf.clone(), self.config.sedona_types().clone());
476
477        match self.config {
478            BenchmarkArgs::Array(_) => {
479                for i in 0..self.num_batches {
480                    tester.invoke_array(self.arrays[0][i].clone())?;
481                }
482            }
483            BenchmarkArgs::ScalarArray(_, _) => {
484                let scalar = &self.scalars[0];
485                for i in 0..self.num_batches {
486                    tester.invoke_scalar_array(scalar.clone(), self.arrays[0][i].clone())?;
487                }
488            }
489            BenchmarkArgs::ArrayScalar(_, _) => {
490                let scalar = &self.scalars[0];
491                for i in 0..self.num_batches {
492                    tester.invoke_array_scalar(self.arrays[0][i].clone(), scalar.clone())?;
493                }
494            }
495            BenchmarkArgs::ArrayArray(_, _) => {
496                for i in 0..self.num_batches {
497                    tester
498                        .invoke_array_array(self.arrays[0][i].clone(), self.arrays[1][i].clone())?;
499                }
500            }
501            BenchmarkArgs::ArrayScalarScalar(_, _, _) => {
502                let scalar0 = &self.scalars[0];
503                let scalar1 = &self.scalars[1];
504                for i in 0..self.num_batches {
505                    tester.invoke_array_scalar_scalar(
506                        self.arrays[0][i].clone(),
507                        scalar0.clone(),
508                        scalar1.clone(),
509                    )?;
510                }
511            }
512            BenchmarkArgs::ArrayArrayScalar(_, _, _) => {
513                for i in 0..self.num_batches {
514                    tester.invoke_array_array_scalar(
515                        self.arrays[0][i].clone(),
516                        self.arrays[1][i].clone(),
517                        self.scalars[0].clone(),
518                    )?;
519                }
520            }
521            BenchmarkArgs::ArrayArrayArray(_, _, _) => {
522                for i in 0..self.num_batches {
523                    tester.invoke_arrays(vec![
524                        self.arrays[0][i].clone(),
525                        self.arrays[1][i].clone(),
526                        self.arrays[2][i].clone(),
527                    ])?;
528                }
529            }
530            BenchmarkArgs::ArrayArrayArrayArray(_, _, _, _) => {
531                for i in 0..self.num_batches {
532                    tester.invoke_arrays(vec![
533                        self.arrays[0][i].clone(),
534                        self.arrays[1][i].clone(),
535                        self.arrays[2][i].clone(),
536                        self.arrays[3][i].clone(),
537                    ])?;
538                }
539            }
540        }
541
542        Ok(())
543    }
544
545    /// Invoke an aggregate function on this data
546    pub fn invoke_aggregate(&self, udf: &AggregateUDF) -> Result<ScalarValue> {
547        if !matches!(self.config, BenchmarkArgs::Array(_)) {
548            return sedona_internal_err!(
549                "invoke_aggregate() not implemented for {:?}",
550                self.config
551            );
552        }
553
554        let tester = AggregateUdfTester::new(udf.clone(), self.config.sedona_types().clone());
555        tester.aggregate(&self.arrays[0])
556    }
557}
558
559#[cfg(test)]
560mod test {
561    use arrow_array::{Array, StructArray};
562    use datafusion_common::cast::as_binary_array;
563    use datafusion_expr::{ColumnarValue, SimpleScalarUDF};
564    use geo_traits::Dimensions;
565    use rstest::rstest;
566    use sedona_geometry::{analyze::analyze_geometry, types::GeometryTypeAndDimensions};
567
568    use super::*;
569
570    #[test]
571    fn arg_spec_scalar() {
572        let spec = BenchmarkArgSpec::Point;
573        assert_eq!(spec.sedona_type(), WKB_GEOMETRY);
574
575        let scalar = spec.build_scalar(0).unwrap();
576
577        // Make sure this is deterministic
578        assert_eq!(spec.build_scalar(0).unwrap(), scalar);
579
580        // Make sure we generate different scalars for different columns
581        assert_ne!(spec.build_scalar(1).unwrap(), scalar);
582
583        if let ScalarValue::Binary(Some(wkb_bytes)) = scalar {
584            let wkb = wkb::reader::read_wkb(&wkb_bytes).unwrap();
585            let analysis = analyze_geometry(&wkb).unwrap();
586            assert_eq!(analysis.point_count, 1);
587            assert_eq!(
588                analysis.geometry_type,
589                GeometryTypeAndDimensions::new(GeometryTypeId::Point, Dimensions::Xy)
590            )
591        } else {
592            unreachable!("Unexpected scalar output {scalar}")
593        }
594    }
595
596    #[rstest]
597    fn arg_spec_geometry(
598        #[values(
599            (BenchmarkArgSpec::Point, GeometryTypeId::Point, 1),
600            (BenchmarkArgSpec::LineString(10), GeometryTypeId::LineString, 10),
601            (BenchmarkArgSpec::Polygon(10), GeometryTypeId::Polygon, 11),
602            (BenchmarkArgSpec::MultiPoint(10), GeometryTypeId::MultiPoint, 10),
603        )]
604        config: (BenchmarkArgSpec, GeometryTypeId, i64),
605    ) {
606        let (spec, geometry_type, point_count) = config;
607        assert_eq!(spec.sedona_type(), WKB_GEOMETRY);
608
609        let arrays = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
610        assert_eq!(arrays.len(), 2);
611
612        // Make sure this is deterministic
613        assert_eq!(spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap(), arrays);
614
615        // Make sure we generate different arrays for different argument numbers
616        assert_ne!(spec.build_arrays(1, 2, ROWS_PER_BATCH).unwrap(), arrays);
617
618        for array in arrays {
619            assert_eq!(array.data_type(), WKB_GEOMETRY.storage_type());
620            assert_eq!(array.len(), ROWS_PER_BATCH);
621
622            let binary_array = as_binary_array(&array).unwrap();
623            assert_eq!(binary_array.null_count(), 0);
624
625            for wkb_bytes in binary_array {
626                let wkb = wkb::reader::read_wkb(wkb_bytes.unwrap()).unwrap();
627                let analysis = analyze_geometry(&wkb).unwrap();
628                assert_eq!(analysis.point_count, point_count);
629                assert_eq!(
630                    analysis.geometry_type,
631                    GeometryTypeAndDimensions::new(geometry_type, Dimensions::Xy)
632                )
633            }
634        }
635    }
636
637    #[test]
638    fn arg_spec_float() {
639        let spec = BenchmarkArgSpec::Float64(1.0, 2.0);
640        assert_eq!(spec.sedona_type(), SedonaType::Arrow(DataType::Float64));
641
642        let arrays = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
643        assert_eq!(arrays.len(), 2);
644
645        // Make sure this is deterministic
646        assert_eq!(spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap(), arrays);
647
648        // Make sure we generate different arrays for different argument numbers
649        assert_ne!(spec.build_arrays(1, 2, ROWS_PER_BATCH).unwrap(), arrays);
650
651        for array in arrays {
652            assert_eq!(array.data_type(), &DataType::Float64);
653            assert_eq!(array.len(), ROWS_PER_BATCH);
654            assert_eq!(array.null_count(), 0);
655        }
656    }
657
658    #[test]
659    fn arg_spec_transformed() {
660        let udf = SimpleScalarUDF::new(
661            "float32",
662            vec![DataType::Float64],
663            DataType::Float32,
664            datafusion_expr::Volatility::Immutable,
665            Arc::new(|args| -> Result<ColumnarValue> { args[0].cast_to(&DataType::Float32, None) }),
666        );
667
668        let spec =
669            BenchmarkArgSpec::Transformed(BenchmarkArgSpec::Float64(1.0, 2.0).into(), udf.into());
670        assert_eq!(spec.sedona_type(), SedonaType::Arrow(DataType::Float32));
671
672        assert_eq!(format!("{spec:?}"), "float32(Float64(1.0, 2.0))");
673        let arrays = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
674        assert_eq!(arrays.len(), 2);
675
676        // Make sure this is deterministic
677        assert_eq!(spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap(), arrays);
678
679        // Make sure we generate different arrays for different argument numbers
680        assert_ne!(spec.build_arrays(1, 2, ROWS_PER_BATCH).unwrap(), arrays);
681
682        for array in arrays {
683            assert_eq!(array.data_type(), &DataType::Float32);
684            assert_eq!(array.len(), ROWS_PER_BATCH);
685            assert_eq!(array.null_count(), 0);
686        }
687    }
688
689    #[test]
690    fn args_array() {
691        let spec = BenchmarkArgs::Array(BenchmarkArgSpec::Point);
692        assert_eq!(spec.sedona_types(), [WKB_GEOMETRY]);
693
694        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
695        assert_eq!(data.num_batches, 2);
696        assert_eq!(data.arrays.len(), 1);
697        assert_eq!(data.scalars.len(), 0);
698
699        assert_eq!(data.arrays[0].len(), 2);
700        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
701    }
702
703    #[test]
704    fn args_array_scalar() {
705        let spec = BenchmarkArgs::ArrayScalar(
706            BenchmarkArgSpec::Point,
707            BenchmarkArgSpec::Float64(1.0, 2.0),
708        );
709        assert_eq!(
710            spec.sedona_types(),
711            [WKB_GEOMETRY, SedonaType::Arrow(DataType::Float64)]
712        );
713
714        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
715        assert_eq!(data.num_batches, 2);
716
717        assert_eq!(data.arrays.len(), 1);
718        assert_eq!(data.arrays[0].len(), 2);
719        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
720
721        assert_eq!(data.scalars.len(), 1);
722        assert_eq!(data.scalars[0].data_type(), DataType::Float64);
723    }
724
725    #[test]
726    fn args_scalar_array() {
727        let spec = BenchmarkArgs::ScalarArray(
728            BenchmarkArgSpec::Point,
729            BenchmarkArgSpec::Float64(1.0, 2.0),
730        );
731        assert_eq!(
732            spec.sedona_types(),
733            [WKB_GEOMETRY, SedonaType::Arrow(DataType::Float64)]
734        );
735
736        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
737        assert_eq!(data.num_batches, 2);
738
739        assert_eq!(data.scalars.len(), 1);
740        assert_eq!(WKB_GEOMETRY.storage_type(), &data.scalars[0].data_type());
741
742        assert_eq!(data.arrays.len(), 1);
743        assert_eq!(data.arrays[0].len(), 2);
744        assert_eq!(data.arrays[0][0].data_type(), &DataType::Float64);
745    }
746
747    #[test]
748    fn args_array_array() {
749        let spec =
750            BenchmarkArgs::ArrayArray(BenchmarkArgSpec::Point, BenchmarkArgSpec::Float64(1.0, 2.0));
751        assert_eq!(
752            spec.sedona_types(),
753            [WKB_GEOMETRY, SedonaType::Arrow(DataType::Float64)]
754        );
755
756        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
757        assert_eq!(data.num_batches, 2);
758        assert_eq!(data.arrays.len(), 2);
759        assert_eq!(data.scalars.len(), 0);
760
761        assert_eq!(data.arrays[0].len(), 2);
762        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
763
764        assert_eq!(data.arrays[1].len(), 2);
765        assert_eq!(data.arrays[1][0].data_type(), &DataType::Float64);
766    }
767
768    #[test]
769    fn args_array_scalar_scalar() {
770        let spec = BenchmarkArgs::ArrayScalarScalar(
771            BenchmarkArgSpec::Point,
772            BenchmarkArgSpec::Float64(1.0, 2.0),
773            BenchmarkArgSpec::String("test".to_string()),
774        );
775        assert_eq!(
776            spec.sedona_types(),
777            [
778                WKB_GEOMETRY,
779                SedonaType::Arrow(DataType::Float64),
780                SedonaType::Arrow(DataType::Utf8)
781            ]
782        );
783
784        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
785        assert_eq!(data.num_batches, 2);
786        assert_eq!(data.arrays.len(), 1);
787        assert_eq!(data.scalars.len(), 2);
788        assert_eq!(data.arrays[0].len(), 2);
789        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
790        assert_eq!(data.scalars[0].data_type(), DataType::Float64);
791        assert_eq!(data.scalars[1].data_type(), DataType::Utf8);
792    }
793
794    #[test]
795    fn args_array_array_scalar() {
796        let spec = BenchmarkArgs::ArrayArrayScalar(
797            BenchmarkArgSpec::Point,
798            BenchmarkArgSpec::Point,
799            BenchmarkArgSpec::Float64(1.0, 2.0),
800        );
801        assert_eq!(
802            spec.sedona_types(),
803            [
804                WKB_GEOMETRY,
805                WKB_GEOMETRY,
806                SedonaType::Arrow(DataType::Float64)
807            ]
808        );
809
810        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
811        assert_eq!(data.num_batches, 2);
812        assert_eq!(data.arrays.len(), 3);
813        assert_eq!(data.scalars.len(), 1);
814        assert_eq!(data.arrays[0].len(), 2);
815        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
816        assert_eq!(data.arrays[1].len(), 2);
817        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[1][0].data_type());
818
819        assert_eq!(data.scalars[0].data_type(), DataType::Float64);
820    }
821
822    #[test]
823    fn args_array_array_array() {
824        let spec = BenchmarkArgs::ArrayArrayArray(
825            BenchmarkArgSpec::Point,
826            BenchmarkArgSpec::Point,
827            BenchmarkArgSpec::Float64(1.0, 2.0),
828        );
829        assert_eq!(
830            spec.sedona_types(),
831            [
832                WKB_GEOMETRY,
833                WKB_GEOMETRY,
834                SedonaType::Arrow(DataType::Float64)
835            ]
836        );
837
838        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
839        assert_eq!(data.num_batches, 2);
840        assert_eq!(data.arrays.len(), 3);
841        assert_eq!(data.scalars.len(), 0);
842        assert_eq!(data.arrays[0].len(), 2);
843        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[0][0].data_type());
844        assert_eq!(data.arrays[1].len(), 2);
845        assert_eq!(WKB_GEOMETRY.storage_type(), data.arrays[1][0].data_type());
846        assert_eq!(data.arrays[2].len(), 2);
847        assert_eq!(data.arrays[2][0].data_type(), &DataType::Float64);
848    }
849
850    #[test]
851    fn args_array_array_array_array() {
852        let spec = BenchmarkArgs::ArrayArrayArrayArray(
853            BenchmarkArgSpec::Float64(1.0, 2.0),
854            BenchmarkArgSpec::Float64(3.0, 4.0),
855            BenchmarkArgSpec::Float64(5.0, 6.0),
856            BenchmarkArgSpec::Float64(7.0, 8.0),
857        );
858        assert_eq!(
859            spec.sedona_types(),
860            [
861                SedonaType::Arrow(DataType::Float64),
862                SedonaType::Arrow(DataType::Float64),
863                SedonaType::Arrow(DataType::Float64),
864                SedonaType::Arrow(DataType::Float64)
865            ]
866        );
867
868        let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
869        assert_eq!(data.num_batches, 2);
870        assert_eq!(data.arrays.len(), 4);
871        assert_eq!(data.scalars.len(), 0);
872        assert_eq!(data.arrays[0].len(), 2);
873        assert_eq!(data.arrays[0][0].data_type(), &DataType::Float64);
874        assert_eq!(data.arrays[1].len(), 2);
875        assert_eq!(data.arrays[1][0].data_type(), &DataType::Float64);
876        assert_eq!(data.arrays[2].len(), 2);
877        assert_eq!(data.arrays[2][0].data_type(), &DataType::Float64);
878        assert_eq!(data.arrays[3].len(), 2);
879        assert_eq!(data.arrays[3][0].data_type(), &DataType::Float64);
880    }
881
882    #[test]
883    fn arg_spec_raster() {
884        use sedona_raster::array::RasterStructArray;
885        use sedona_raster::traits::RasterRef;
886
887        let spec = BenchmarkArgSpec::Raster(10, 5);
888        assert_eq!(spec.sedona_type(), RASTER);
889        let data = spec.build_arrays(0, 2, ROWS_PER_BATCH).unwrap();
890        assert_eq!(data.len(), 2);
891        assert_eq!(data[0].data_type(), RASTER.storage_type());
892
893        let raster_array = data[0].as_any().downcast_ref::<StructArray>().unwrap();
894        let rasters = RasterStructArray::new(raster_array);
895        assert_eq!(rasters.len(), ROWS_PER_BATCH);
896        let raster = rasters.get(0).unwrap();
897        let metadata = raster.metadata();
898        assert_eq!(metadata.width(), 10);
899        assert_eq!(metadata.height(), 5);
900    }
901}