use std::fmt::Display;
use std::sync::Arc;
use std::sync::OnceLock;

use datafusion::config::{ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, Visit};
use datafusion::prelude::SessionConfig;
use datafusion_common::Result;
use datafusion_common::{config_err, config_namespace};
use regex::Regex;

use crate::sedona_internal_err;

/// Default threshold carried by [`ExecutionMode::Speculative`].
///
/// This is the value used by the default `execution_mode` option and by the
/// bare `auto` setting (i.e. `auto` without an explicit `[<n>]` suffix).
pub const DEFAULT_SPECULATIVE_THRESHOLD: usize = 1000;

/// Default value of [`GeosOptions::min_points_for_build_preparation`].
pub const DEFAULT_MIN_POINTS_FOR_BUILD_PREPARATION: usize = 50;

35pub fn add_sedona_option_extension(config: SessionConfig) -> SessionConfig {
37 config.with_option_extension(SedonaOptions::default())
38}
39
40config_namespace! {
41 pub struct SedonaOptions {
43 pub spatial_join: SpatialJoinOptions, default = SpatialJoinOptions::default()
45
46 pub crs_provider: CrsProviderOption, default = CrsProviderOption::default()
48 }
49}
50
51config_namespace! {
52 pub struct SpatialJoinOptions {
58 pub enable: bool, default = true
60
61 pub spatial_library: SpatialLibrary, default = SpatialLibrary::Tg
63
64 pub geos: GeosOptions, default = GeosOptions::default()
66
67 pub tg: TgOptions, default = TgOptions::default()
69
70 pub execution_mode: ExecutionMode, default = ExecutionMode::Speculative(DEFAULT_SPECULATIVE_THRESHOLD)
72
73 pub concurrent_build_side_collection: bool, default = true
76
77 pub knn_include_tie_breakers: bool, default = false
79
80 pub repartition_probe_side: bool, default = true
84
85 pub max_index_side_bbox_samples: usize, default = 10000
88
89 pub min_index_side_bbox_samples: usize, default = 1000
92
93 pub target_index_side_bbox_sampling_rate: f64, default = 0.01
96
97 pub spilled_batch_in_memory_size_threshold: usize, default = 0
102
103 pub parallel_refinement_chunk_size: usize, default = 8192
109
110 pub debug : SpatialJoinDebugOptions, default = SpatialJoinDebugOptions::default()
112 }
113}
114
115config_namespace! {
116 pub struct SpatialJoinDebugOptions {
118 pub num_spatial_partitions: NumSpatialPartitionsConfig, default = NumSpatialPartitionsConfig::Auto
120
121 pub memory_for_intermittent_usage: Option<usize>, default = None
123
124 pub force_spill: bool, default = false
126
127 pub random_seed: Option<u64>, default = None
129 }
130}
131
/// Value of the `num_spatial_partitions` debug option.
///
/// Rendered and parsed as either `auto` or a positive integer.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum NumSpatialPartitionsConfig {
    /// No explicit count was configured (`auto`).
    Auto,

    /// An explicit partition count; values parsed from a string are always greater than 0.
    Fixed(usize),
}

141impl ConfigField for NumSpatialPartitionsConfig {
142 fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
143 let value = match self {
144 NumSpatialPartitionsConfig::Auto => "auto".into(),
145 NumSpatialPartitionsConfig::Fixed(n) => format!("{n}"),
146 };
147 v.some(key, value, description);
148 }
149
150 fn set(&mut self, _key: &str, value: &str) -> Result<()> {
151 let value = value.to_lowercase();
152 let config = match value.as_str() {
153 "auto" => NumSpatialPartitionsConfig::Auto,
154 _ => match value.parse::<usize>() {
155 Ok(n) => {
156 if n > 0 {
157 NumSpatialPartitionsConfig::Fixed(n)
158 } else {
159 return Err(datafusion_common::DataFusionError::Configuration(
160 "num_spatial_partitions must be greater than 0".to_string(),
161 ));
162 }
163 }
164 Err(_) => {
165 return Err(datafusion_common::DataFusionError::Configuration(format!(
166 "Unknown num_spatial_partitions config: {value}. Expected formats: auto, <number>"
167 )));
168 }
169 },
170 };
171 *self = config;
172 Ok(())
173 }
174}
175
176config_namespace! {
177 pub struct GeosOptions {
179 pub min_points_for_build_preparation: usize, default = DEFAULT_MIN_POINTS_FOR_BUILD_PREPARATION
181 }
182}
183
184config_namespace! {
185 pub struct TgOptions {
187 pub index_type: TgIndexType, default = TgIndexType::YStripes
189 }
190}
191
192impl ConfigExtension for SedonaOptions {
193 const PREFIX: &'static str = "sedona";
194}
195
196impl ExtensionOptions for SedonaOptions {
197 fn as_any(&self) -> &dyn std::any::Any {
198 self
199 }
200
201 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
202 self
203 }
204
205 fn cloned(&self) -> Box<dyn ExtensionOptions> {
206 Box::new(self.clone())
207 }
208
209 fn set(&mut self, key: &str, value: &str) -> Result<()> {
210 <Self as ConfigField>::set(self, key, value)
211 }
212
213 fn entries(&self) -> Vec<ConfigEntry> {
214 struct Visitor(Vec<ConfigEntry>);
215
216 impl Visit for Visitor {
217 fn some<V: Display>(&mut self, key: &str, value: V, description: &'static str) {
218 self.0.push(ConfigEntry {
219 key: key.to_string(),
220 value: Some(value.to_string()),
221 description,
222 })
223 }
224
225 fn none(&mut self, key: &str, description: &'static str) {
226 self.0.push(ConfigEntry {
227 key: key.to_string(),
228 value: None,
229 description,
230 })
231 }
232 }
233
234 let mut v = Visitor(vec![]);
235 self.visit(&mut v, Self::PREFIX, "");
236 v.0
237 }
238}
239
/// Value of the spatial join `execution_mode` option.
///
/// The string forms are `prepare_none`, `prepare_build`, `prepare_probe` and
/// `auto[<n>]` (see the `ConfigField` implementation for this type).
///
/// NOTE(review): what each mode makes the join operator do is decided by its
/// consumers, which are not visible in this file.
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum ExecutionMode {
    /// Configured as `prepare_none`.
    PrepareNone,

    /// Configured as `prepare_build`.
    PrepareBuild,

    /// Configured as `prepare_probe`.
    PrepareProbe,

    /// Configured as `auto` or `auto[<n>]`; carries the threshold `n`, which is
    /// [`DEFAULT_SPECULATIVE_THRESHOLD`] when no number is given.
    Speculative(usize),
}

impl ExecutionMode {
    /// Returns a small integer identifying the variant.
    ///
    /// The mapping is `PrepareNone` = 0, `PrepareBuild` = 1, `PrepareProbe` = 2 and
    /// `Speculative` = 3; the threshold carried by `Speculative` is not encoded.
    pub fn to_usize(&self) -> usize {
        match self {
            ExecutionMode::PrepareNone => 0,
            ExecutionMode::PrepareBuild => 1,
            ExecutionMode::PrepareProbe => 2,
            ExecutionMode::Speculative(_) => 3,
        }
    }
}

281impl ConfigField for ExecutionMode {
282 fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
283 let value = match self {
284 ExecutionMode::PrepareNone => "prepare_none".into(),
285 ExecutionMode::PrepareBuild => "prepare_build".into(),
286 ExecutionMode::PrepareProbe => "prepare_probe".into(),
287 ExecutionMode::Speculative(n) => format!("auto[{n}]"),
288 };
289 v.some(key, value, description);
290 }
291
292 fn set(&mut self, _key: &str, value: &str) -> Result<()> {
293 let value = value.to_lowercase();
294 let mode = match value.as_str() {
295 "prepare_none" => ExecutionMode::PrepareNone,
296 "prepare_build" => ExecutionMode::PrepareBuild,
297 "prepare_probe" => ExecutionMode::PrepareProbe,
298 _ => {
299 let auto_regex = Regex::new(r"^auto(?:\[(\d+)\])?$").unwrap();
301
302 if let Some(captures) = auto_regex.captures(&value) {
303 let n = if let Some(number_match) = captures.get(1) {
305 match number_match.as_str().parse::<usize>() {
306 Ok(n) => {
307 if n == 0 {
308 return Err(datafusion_common::DataFusionError::Configuration(
309 "Invalid number in auto mode: 0 is not allowed".to_string(),
310 ));
311 }
312 n
313 }
314 Err(_) => {
315 return Err(datafusion_common::DataFusionError::Configuration(
316 format!(
317 "Invalid number in auto mode: {}",
318 number_match.as_str()
319 ),
320 ));
321 }
322 }
323 } else {
324 DEFAULT_SPECULATIVE_THRESHOLD };
326 ExecutionMode::Speculative(n)
327 } else {
328 return Err(datafusion_common::DataFusionError::Configuration(
329 format!("Unknown execution mode: {value}. Expected formats: prepare_none, prepare_build, prepare_probe, auto, auto[number]")
330 ));
331 }
332 }
333 };
334 *self = mode;
335 Ok(())
336 }
337}
338
/// Value of the spatial join `spatial_library` option.
///
/// The string forms are `geo`, `geos` and `tg`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum SpatialLibrary {
    /// Configured as `geo`.
    Geo,

    /// Configured as `geos`.
    Geos,

    /// Configured as `tg`.
    Tg,
}

352impl ConfigField for SpatialLibrary {
353 fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
354 let value = match self {
355 SpatialLibrary::Geo => "geo",
356 SpatialLibrary::Geos => "geos",
357 SpatialLibrary::Tg => "tg",
358 };
359 v.some(key, value, description);
360 }
361
362 fn set(&mut self, _key: &str, value: &str) -> Result<()> {
363 let value = value.to_lowercase();
364 let library = match value.as_str() {
365 "geo" => SpatialLibrary::Geo,
366 "geos" => SpatialLibrary::Geos,
367 "tg" => SpatialLibrary::Tg,
368 _ => {
369 return Err(datafusion_common::DataFusionError::Configuration(format!(
370 "Unknown spatial library: {value}. Expected: geo, geos, tg"
371 )));
372 }
373 };
374 *self = library;
375 Ok(())
376 }
377}
378
/// Value of the TG `index_type` option.
///
/// The string forms are `natural` and `ystripes`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum TgIndexType {
    /// Configured as `natural`.
    Natural,

    /// Configured as `ystripes`.
    YStripes,
}

389impl ConfigField for TgIndexType {
390 fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
391 let value = match self {
392 TgIndexType::Natural => "natural",
393 TgIndexType::YStripes => "ystripes",
394 };
395 v.some(key, value, description);
396 }
397
398 fn set(&mut self, _key: &str, value: &str) -> Result<()> {
399 let value = value.to_lowercase();
400 let index_type = match value.as_str() {
401 "natural" => TgIndexType::Natural,
402 "ystripes" => TgIndexType::YStripes,
403 _ => {
404 return Err(datafusion_common::DataFusionError::Configuration(format!(
405 "Unknown TG index type: {value}. Expected: natural, ystripes"
406 )));
407 }
408 };
409 *self = index_type;
410 Ok(())
411 }
412}
413
414pub trait CrsProvider: std::fmt::Debug + Send + Sync {
422 fn to_projjson(&self, crs_string: &str) -> Result<String>;
423}
424
425#[derive(Debug, Clone)]
428pub struct CrsProviderOption(Arc<dyn CrsProvider>);
429
430impl CrsProviderOption {
431 pub fn new(inner: Arc<dyn CrsProvider>) -> Self {
433 CrsProviderOption(inner)
434 }
435
436 pub fn to_projjson(&self, crs_string: &str) -> Result<String> {
438 self.0.to_projjson(crs_string)
439 }
440}
441
442impl Default for CrsProviderOption {
443 fn default() -> Self {
444 Self(Arc::new(DefaultCrsProvider {}))
445 }
446}
447
448impl PartialEq for CrsProviderOption {
449 fn eq(&self, other: &Self) -> bool {
450 Arc::ptr_eq(&self.0, &other.0)
451 }
452}
453
454impl ConfigField for CrsProviderOption {
455 fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
456 v.some(key, format!("{:?}", self.0), description);
457 }
458
459 fn set(&mut self, key: &str, _value: &str) -> Result<()> {
460 config_err!("Can't set {key} from SQL")
461 }
462}
463
/// Placeholder provider used by `CrsProviderOption::default()`.
#[derive(Debug)]
struct DefaultCrsProvider {}

467impl CrsProvider for DefaultCrsProvider {
468 fn to_projjson(&self, crs_string: &str) -> Result<String> {
469 sedona_internal_err!(
470 "Can't convert {crs_string} to PROJJSON CRS (no CrsProvider registered)"
471 )
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::config::ConfigField;

    #[test]
    fn test_execution_mode_parsing_basic_modes() {
        let mut mode = ExecutionMode::PrepareNone;

        // Each plain mode name maps to its variant.
        assert!(mode.set("", "prepare_none").is_ok());
        assert_eq!(mode, ExecutionMode::PrepareNone);

        assert!(mode.set("", "prepare_build").is_ok());
        assert_eq!(mode, ExecutionMode::PrepareBuild);

        assert!(mode.set("", "prepare_probe").is_ok());
        assert_eq!(mode, ExecutionMode::PrepareProbe);
    }

    #[test]
    fn test_execution_mode_parsing_auto_modes() {
        let mut mode = ExecutionMode::PrepareNone;

        // Bare "auto" falls back to the default threshold.
        assert!(mode.set("", "auto").is_ok());
        assert_eq!(mode, ExecutionMode::Speculative(1000));

        // "auto[n]" carries the explicit threshold.
        assert!(mode.set("", "auto[10]").is_ok());
        assert_eq!(mode, ExecutionMode::Speculative(10));

        assert!(mode.set("", "auto[500]").is_ok());
        assert_eq!(mode, ExecutionMode::Speculative(500));

        assert!(mode.set("", "auto[1]").is_ok());
        assert_eq!(mode, ExecutionMode::Speculative(1));
    }

    #[test]
    fn test_execution_mode_parsing_case_insensitive() {
        let mut mode = ExecutionMode::PrepareNone;

        // Input is lower-cased before matching.
        assert!(mode.set("", "PREPARE_NONE").is_ok());
        assert_eq!(mode, ExecutionMode::PrepareNone);

        assert!(mode.set("", "PREPARE_BUILD").is_ok());
        assert_eq!(mode, ExecutionMode::PrepareBuild);

        assert!(mode.set("", "PREPARE_PROBE").is_ok());
        assert_eq!(mode, ExecutionMode::PrepareProbe);

        assert!(mode.set("", "AUTO").is_ok());
        assert_eq!(mode, ExecutionMode::Speculative(1000));

        assert!(mode.set("", "Auto[50]").is_ok());
        assert_eq!(mode, ExecutionMode::Speculative(50));
    }

    #[test]
    fn test_execution_mode_parsing_invalid_formats() {
        let mut mode = ExecutionMode::PrepareNone;

        // Unknown names, a zero threshold and malformed bracket forms are all rejected.
        assert!(mode.set("", "invalid").is_err());
        assert!(mode.set("", "").is_err());
        assert!(mode.set("", "auto[0]").is_err());
        assert!(mode.set("", "auto[]").is_err());
        assert!(mode.set("", "auto[abc]").is_err());
        assert!(mode.set("", "auto[10").is_err());
        assert!(mode.set("", "auto10]").is_err());
        assert!(mode.set("", "auto[10][20]").is_err());
        assert!(mode.set("", "auto 10").is_err());
        assert!(mode.set("", "auto:10").is_err());
        assert!(mode.set("", "auto(10)").is_err());
    }

    #[test]
    fn test_spatial_library_parsing() {
        let mut library = SpatialLibrary::Tg;

        assert!(library.set("", "geo").is_ok());
        assert_eq!(library, SpatialLibrary::Geo);

        assert!(library.set("", "GEOS").is_ok());
        assert_eq!(library, SpatialLibrary::Geos);

        assert!(library.set("", "Tg").is_ok());
        assert_eq!(library, SpatialLibrary::Tg);

        // A rejected value leaves the current setting untouched.
        assert!(library.set("", "s2").is_err());
        assert!(library.set("", "").is_err());
        assert_eq!(library, SpatialLibrary::Tg);
    }

    #[test]
    fn test_tg_index_type_parsing() {
        let mut index_type = TgIndexType::YStripes;

        assert!(index_type.set("", "natural").is_ok());
        assert_eq!(index_type, TgIndexType::Natural);

        assert!(index_type.set("", "Natural").is_ok());
        assert_eq!(index_type, TgIndexType::Natural);

        assert!(index_type.set("", "ystripes").is_ok());
        assert_eq!(index_type, TgIndexType::YStripes);

        assert!(index_type.set("", "YStripes").is_ok());
        assert_eq!(index_type, TgIndexType::YStripes);
    }

    #[test]
    fn test_tg_index_type_parsing_invalid_formats() {
        let mut index_type = TgIndexType::YStripes;

        assert!(index_type.set("", "unindexed").is_err());
        assert!(index_type.set("", "invalid").is_err());
        assert!(index_type.set("", "").is_err());
    }

    #[test]
    fn test_num_spatial_partitions_config_parsing() {
        let mut config = NumSpatialPartitionsConfig::Auto;

        assert!(config.set("", "auto").is_ok());
        assert_eq!(config, NumSpatialPartitionsConfig::Auto);

        assert!(config.set("", "10").is_ok());
        assert_eq!(config, NumSpatialPartitionsConfig::Fixed(10));

        assert!(config.set("", "0").is_err());
        assert!(config.set("", "invalid").is_err());
        assert!(config.set("", "fixed[10]").is_err());
    }

    #[test]
    fn test_default_crs_provider_returns_error() {
        let provider = CrsProviderOption::default();
        let result = provider.to_projjson("EPSG:4326");
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Can't convert EPSG:4326 to PROJJSON CRS"),
            "Unexpected error message: {err_msg}"
        );
        assert!(
            err_msg.contains("no CrsProvider registered"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[test]
    fn test_crs_provider_option_set_from_sql_returns_error() {
        let mut option = CrsProviderOption::default();
        let result = option.set("sedona.crs_provider", "some_value");
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Can't set sedona.crs_provider from SQL"),
            "Unexpected error message: {err_msg}"
        );
    }
}