1use std::sync::Arc;
18
19use datafusion_common::{JoinSide, Result};
20use datafusion_physical_expr::projection::update_expr;
21use datafusion_physical_expr::PhysicalExpr;
22use datafusion_physical_plan::projection::ProjectionExpr;
23use sedona_common::sedona_internal_err;
24
/// A spatial predicate used as the condition of a spatial join.
///
/// Each variant wraps a concrete predicate type holding the physical
/// expressions that are evaluated against the join's children.
#[derive(Debug, Clone)]
pub enum SpatialPredicate {
    /// Distance predicate, displayed as `ST_Distance(left, right) < distance`.
    Distance(DistancePredicate),
    /// Topological relation predicate, displayed as `ST_<relation>(left, right)`.
    Relation(RelationPredicate),
    /// K-nearest-neighbors predicate, displayed as `ST_KNN(left, right, k, use_spheroid)`.
    KNearestNeighbors(KNNPredicate),
}
33
34impl std::fmt::Display for SpatialPredicate {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 SpatialPredicate::Distance(predicate) => write!(f, "{predicate}"),
38 SpatialPredicate::Relation(predicate) => write!(f, "{predicate}"),
39 SpatialPredicate::KNearestNeighbors(predicate) => write!(f, "{predicate}"),
40 }
41 }
42}
43
/// Join predicate comparing the distance between a geometry from each join
/// child against a threshold expression.
///
/// Displayed as `ST_Distance(left, right) < distance`.
#[derive(Debug, Clone)]
pub struct DistancePredicate {
    /// Geometry expression evaluated against the left child of the join.
    pub left: Arc<dyn PhysicalExpr>,
    /// Geometry expression evaluated against the right child of the join.
    pub right: Arc<dyn PhysicalExpr>,
    /// Distance threshold expression.
    pub distance: Arc<dyn PhysicalExpr>,
    /// Child of the join that `distance` is evaluated against.
    ///
    /// `JoinSide::None` means the expression references neither child (for
    /// example a literal), and projection rewrites leave it untouched.
    pub distance_side: JoinSide,
}
77
78impl DistancePredicate {
79 pub fn new(
87 left: Arc<dyn PhysicalExpr>,
88 right: Arc<dyn PhysicalExpr>,
89 distance: Arc<dyn PhysicalExpr>,
90 distance_side: JoinSide,
91 ) -> Self {
92 Self {
93 left,
94 right,
95 distance,
96 distance_side,
97 }
98 }
99}
100
101impl std::fmt::Display for DistancePredicate {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 write!(
104 f,
105 "ST_Distance({}, {}) < {}",
106 self.left, self.right, self.distance
107 )
108 }
109}
110
/// Join predicate testing a topological relationship between a geometry from
/// each join child.
///
/// Displayed as `ST_<relation>(left, right)`, e.g. `ST_intersects(...)`.
#[derive(Debug, Clone)]
pub struct RelationPredicate {
    /// Geometry expression evaluated against the left child of the join.
    pub left: Arc<dyn PhysicalExpr>,
    /// Geometry expression evaluated against the right child of the join.
    pub right: Arc<dyn PhysicalExpr>,
    /// Relationship tested with `left` as the first operand and `right` as the second.
    pub relation_type: SpatialRelationType,
}
144
145impl RelationPredicate {
146 pub fn new(
153 left: Arc<dyn PhysicalExpr>,
154 right: Arc<dyn PhysicalExpr>,
155 relation_type: SpatialRelationType,
156 ) -> Self {
157 Self {
158 left,
159 right,
160 relation_type,
161 }
162 }
163}
164
165impl std::fmt::Display for RelationPredicate {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 write!(
168 f,
169 "ST_{}({}, {})",
170 self.relation_type, self.left, self.right
171 )
172 }
173}
174
/// Topological relationship tested by a [`RelationPredicate`].
///
/// Each variant corresponds to an `st_*` SQL function name (see
/// `SpatialRelationType::from_name`). `Hash` is derived alongside `Eq`, so
/// the type can be used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpatialRelationType {
    /// `st_intersects`
    Intersects,
    /// `st_contains`
    Contains,
    /// `st_within`
    Within,
    /// `st_covers`
    Covers,
    /// `st_coveredby` / `st_covered_by`
    CoveredBy,
    /// `st_touches`
    Touches,
    /// `st_crosses`
    Crosses,
    /// `st_overlaps`
    Overlaps,
    /// `st_equals`
    Equals,
}
188
189impl SpatialRelationType {
190 pub fn from_name(name: &str) -> Option<Self> {
199 match name {
200 "st_intersects" => Some(SpatialRelationType::Intersects),
201 "st_contains" => Some(SpatialRelationType::Contains),
202 "st_within" => Some(SpatialRelationType::Within),
203 "st_covers" => Some(SpatialRelationType::Covers),
204 "st_coveredby" | "st_covered_by" => Some(SpatialRelationType::CoveredBy),
205 "st_touches" => Some(SpatialRelationType::Touches),
206 "st_crosses" => Some(SpatialRelationType::Crosses),
207 "st_overlaps" => Some(SpatialRelationType::Overlaps),
208 "st_equals" => Some(SpatialRelationType::Equals),
209 _ => None,
210 }
211 }
212
213 pub fn invert(&self) -> Self {
221 match self {
222 SpatialRelationType::Intersects => SpatialRelationType::Intersects,
223 SpatialRelationType::Covers => SpatialRelationType::CoveredBy,
224 SpatialRelationType::CoveredBy => SpatialRelationType::Covers,
225 SpatialRelationType::Contains => SpatialRelationType::Within,
226 SpatialRelationType::Within => SpatialRelationType::Contains,
227 SpatialRelationType::Touches => SpatialRelationType::Touches,
228 SpatialRelationType::Crosses => SpatialRelationType::Crosses,
229 SpatialRelationType::Overlaps => SpatialRelationType::Overlaps,
230 SpatialRelationType::Equals => SpatialRelationType::Equals,
231 }
232 }
233}
234
235impl std::fmt::Display for SpatialRelationType {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 match self {
238 SpatialRelationType::Intersects => write!(f, "intersects"),
239 SpatialRelationType::Contains => write!(f, "contains"),
240 SpatialRelationType::Within => write!(f, "within"),
241 SpatialRelationType::Covers => write!(f, "covers"),
242 SpatialRelationType::CoveredBy => write!(f, "coveredby"),
243 SpatialRelationType::Touches => write!(f, "touches"),
244 SpatialRelationType::Crosses => write!(f, "crosses"),
245 SpatialRelationType::Overlaps => write!(f, "overlaps"),
246 SpatialRelationType::Equals => write!(f, "equals"),
247 }
248 }
249}
250
/// K-nearest-neighbors join predicate, displayed as
/// `ST_KNN(left, right, k, use_spheroid)`.
///
/// Unlike the other predicates, `left` and `right` are tied to the probe and
/// object roles rather than to the physical join children. `left` is
/// evaluated against the child named by `probe_side`, and `right` against the
/// other child.
#[derive(Debug, Clone)]
pub struct KNNPredicate {
    /// Query (probe) geometry expression, evaluated against the `probe_side` child.
    pub left: Arc<dyn PhysicalExpr>,
    /// Object geometry expression, evaluated against the child opposite `probe_side`.
    pub right: Arc<dyn PhysicalExpr>,
    /// Number of nearest neighbors requested (the `k` of KNN).
    pub k: u32,
    /// Fourth `ST_KNN` argument. Presumably selects spheroidal rather than
    /// planar distance — confirm against the KNN implementation.
    pub use_spheroid: bool,
    /// Physical child acting as the probe (query) side. Must be `Left` or
    /// `Right`; `new` asserts this, and projection rewrites reject `None`.
    pub probe_side: JoinSide,
}
295
296impl KNNPredicate {
297 pub fn new(
306 left: Arc<dyn PhysicalExpr>,
307 right: Arc<dyn PhysicalExpr>,
308 k: u32,
309 use_spheroid: bool,
310 probe_side: JoinSide,
311 ) -> Self {
312 assert!(matches!(probe_side, JoinSide::Left | JoinSide::Right));
313 Self {
314 left,
315 right,
316 k,
317 use_spheroid,
318 probe_side,
319 }
320 }
321}
322
323impl std::fmt::Display for KNNPredicate {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 write!(
326 f,
327 "ST_KNN({}, {}, {}, {})",
328 self.left, self.right, self.k, self.use_spheroid
329 )
330 }
331}
332
/// Adjustments a spatial predicate needs when the join plan around it is rewritten.
pub trait SpatialPredicateTrait: Sized {
    /// Returns the equivalent predicate for a join whose left and right
    /// children have been swapped.
    fn swap_for_swapped_children(&self) -> Self;

    /// Rewrites the predicate's expressions so they refer to the outputs of
    /// projections applied to the join's children.
    ///
    /// `projected_left_exprs` and `projected_right_exprs` are the projection
    /// expressions on the left and right children. Returns `Ok(None)` when an
    /// expression cannot be rewritten in terms of the projected outputs (as
    /// reported by `update_expr`).
    ///
    /// # Errors
    ///
    /// Propagates errors from expression rewriting. Implementations may also
    /// reject invalid predicate state.
    fn update_for_child_projections(
        &self,
        projected_left_exprs: &[ProjectionExpr],
        projected_right_exprs: &[ProjectionExpr],
    ) -> Result<Option<Self>>;
}
352
353impl SpatialPredicateTrait for SpatialPredicate {
354 fn swap_for_swapped_children(&self) -> Self {
355 match self {
356 SpatialPredicate::Relation(pred) => {
357 SpatialPredicate::Relation(pred.swap_for_swapped_children())
358 }
359 SpatialPredicate::Distance(pred) => {
360 SpatialPredicate::Distance(pred.swap_for_swapped_children())
361 }
362 SpatialPredicate::KNearestNeighbors(pred) => {
363 SpatialPredicate::KNearestNeighbors(pred.swap_for_swapped_children())
364 }
365 }
366 }
367
368 fn update_for_child_projections(
369 &self,
370 projected_left_exprs: &[ProjectionExpr],
371 projected_right_exprs: &[ProjectionExpr],
372 ) -> Result<Option<Self>> {
373 match self {
374 SpatialPredicate::Relation(pred) => Ok(pred
375 .update_for_child_projections(projected_left_exprs, projected_right_exprs)?
376 .map(SpatialPredicate::Relation)),
377 SpatialPredicate::Distance(pred) => Ok(pred
378 .update_for_child_projections(projected_left_exprs, projected_right_exprs)?
379 .map(SpatialPredicate::Distance)),
380 SpatialPredicate::KNearestNeighbors(pred) => Ok(pred
381 .update_for_child_projections(projected_left_exprs, projected_right_exprs)?
382 .map(SpatialPredicate::KNearestNeighbors)),
383 }
384 }
385}
386
387impl SpatialPredicateTrait for RelationPredicate {
388 fn swap_for_swapped_children(&self) -> Self {
389 Self {
390 left: Arc::clone(&self.right),
391 right: Arc::clone(&self.left),
392 relation_type: self.relation_type.invert(),
393 }
394 }
395
396 fn update_for_child_projections(
397 &self,
398 projected_left_exprs: &[ProjectionExpr],
399 projected_right_exprs: &[ProjectionExpr],
400 ) -> Result<Option<Self>> {
401 let Some(left) = update_expr(&self.left, projected_left_exprs, false)? else {
402 return Ok(None);
403 };
404 let Some(right) = update_expr(&self.right, projected_right_exprs, false)? else {
405 return Ok(None);
406 };
407
408 Ok(Some(Self {
409 left,
410 right,
411 relation_type: self.relation_type,
412 }))
413 }
414}
415
416impl SpatialPredicateTrait for DistancePredicate {
417 fn swap_for_swapped_children(&self) -> Self {
418 Self {
419 left: Arc::clone(&self.right),
420 right: Arc::clone(&self.left),
421 distance: Arc::clone(&self.distance),
422 distance_side: self.distance_side.negate(),
423 }
424 }
425
426 fn update_for_child_projections(
427 &self,
428 projected_left_exprs: &[ProjectionExpr],
429 projected_right_exprs: &[ProjectionExpr],
430 ) -> Result<Option<Self>> {
431 let Some(left) = update_expr(&self.left, projected_left_exprs, false)? else {
432 return Ok(None);
433 };
434 let Some(right) = update_expr(&self.right, projected_right_exprs, false)? else {
435 return Ok(None);
436 };
437
438 let distance = match self.distance_side {
439 JoinSide::Left => {
440 let Some(distance) = update_expr(&self.distance, projected_left_exprs, false)?
441 else {
442 return Ok(None);
443 };
444 distance
445 }
446 JoinSide::Right => {
447 let Some(distance) = update_expr(&self.distance, projected_right_exprs, false)?
448 else {
449 return Ok(None);
450 };
451 distance
452 }
453 JoinSide::None => Arc::clone(&self.distance),
454 };
455
456 Ok(Some(Self {
457 left,
458 right,
459 distance,
460 distance_side: self.distance_side,
461 }))
462 }
463}
464
465impl SpatialPredicateTrait for KNNPredicate {
466 fn swap_for_swapped_children(&self) -> Self {
467 Self {
468 left: Arc::clone(&self.left),
470 right: Arc::clone(&self.right),
471 k: self.k,
472 use_spheroid: self.use_spheroid,
473 probe_side: self.probe_side.negate(),
474 }
475 }
476
477 fn update_for_child_projections(
478 &self,
479 projected_left_exprs: &[ProjectionExpr],
480 projected_right_exprs: &[ProjectionExpr],
481 ) -> Result<Option<Self>> {
482 let (query_exprs, object_exprs) = match self.probe_side {
483 JoinSide::Left => (projected_left_exprs, projected_right_exprs),
484 JoinSide::Right => (projected_right_exprs, projected_left_exprs),
485 JoinSide::None => {
486 return sedona_internal_err!("KNN join requires explicit probe_side designation")
487 }
488 };
489
490 let Some(left) = update_expr(&self.left, query_exprs, false)? else {
491 return Ok(None);
492 };
493 let Some(right) = update_expr(&self.right, object_exprs, false)? else {
494 return Ok(None);
495 };
496
497 Ok(Some(Self {
498 left,
499 right,
500 k: self.k,
501 use_spheroid: self.use_spheroid,
502 probe_side: self.probe_side,
503 }))
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::expressions::{Column, Literal};

    /// Shorthand for a `Column` physical expression with the given name and index.
    fn proj_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Builds a projection entry that outputs `expr` under `alias`.
    fn proj_expr(expr: Arc<dyn PhysicalExpr>, alias: &str) -> ProjectionExpr {
        ProjectionExpr {
            expr,
            alias: alias.to_string(),
        }
    }

    /// Asserts that `expr` is a `Column` with exactly this name and index.
    fn assert_is_column(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
        let col = expr
            .as_any()
            .downcast_ref::<Column>()
            .expect("expected Column");
        assert_eq!(col.name(), name);
        assert_eq!(col.index(), index);
    }

    // Each side projects its single referenced column to position 0 under a
    // new alias; the rewritten predicate must point at the projected outputs.
    #[test]
    fn relation_rewrite_success() -> Result<()> {
        let on = SpatialPredicate::Relation(RelationPredicate {
            left: proj_col("a", 1),
            right: proj_col("x", 2),
            relation_type: SpatialRelationType::Intersects,
        });

        let projected_left_exprs = vec![proj_expr(proj_col("a", 1), "a_new")];
        let projected_right_exprs = vec![proj_expr(proj_col("x", 2), "x_new")];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Relation(updated) = updated else {
            unreachable!("expected relation")
        };
        assert_is_column(&updated.left, "a_new", 0);
        assert_is_column(&updated.right, "x_new", 0);
        Ok(())
    }

    // Even when the column index does not move, the rewritten columns must pick
    // up the projection aliases.
    #[test]
    fn relation_rewrite_column_index_unchanged() -> Result<()> {
        let on = SpatialPredicate::Relation(RelationPredicate {
            left: proj_col("a", 0),
            right: proj_col("x", 0),
            relation_type: SpatialRelationType::Intersects,
        });

        let projected_left_exprs = vec![proj_expr(proj_col("a", 0), "a_new")];
        let projected_right_exprs = vec![proj_expr(proj_col("x", 0), "x_new")];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Relation(updated) = updated else {
            unreachable!("expected relation")
        };
        assert_is_column(&updated.left, "a_new", 0);
        assert_is_column(&updated.right, "x_new", 0);
        Ok(())
    }

    // The left predicate column `a@1` is not among the left projections (only
    // `a@0` is), so the rewrite must report `None` rather than guess.
    #[test]
    fn relation_rewrite_none_when_missing() -> Result<()> {
        let on = SpatialPredicate::Relation(RelationPredicate {
            left: proj_col("a", 1),
            right: proj_col("x", 0),
            relation_type: SpatialRelationType::Intersects,
        });

        let projected_left_exprs = vec![proj_expr(proj_col("a", 0), "a0")];
        let projected_right_exprs = vec![proj_expr(proj_col("x", 0), "x0")];

        assert!(on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .is_none());
        Ok(())
    }

    // With `distance_side == Left`, the distance column is rewritten using the
    // left child's projections alongside the left geometry.
    #[test]
    fn distance_rewrite_distance_side_left() -> Result<()> {
        let on = SpatialPredicate::Distance(DistancePredicate {
            left: proj_col("geom", 0),
            right: proj_col("geom", 0),
            distance: proj_col("dist", 1),
            distance_side: JoinSide::Left,
        });

        let projected_left_exprs = vec![
            proj_expr(proj_col("geom", 0), "geom_out"),
            proj_expr(proj_col("dist", 1), "dist_out"),
        ];
        let projected_right_exprs = vec![proj_expr(proj_col("geom", 0), "geom_r")];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Distance(updated) = updated else {
            unreachable!("expected distance")
        };
        assert_is_column(&updated.left, "geom_out", 0);
        assert_is_column(&updated.right, "geom_r", 0);
        assert_is_column(&updated.distance, "dist_out", 1);
        assert_eq!(updated.distance_side, JoinSide::Left);
        Ok(())
    }

    // With `distance_side == None`, the distance expression is kept as the very
    // same `Arc` (pointer-equal), not rebuilt.
    #[test]
    fn distance_rewrite_distance_side_none_keeps_literal() -> Result<()> {
        let distance_lit: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Float64(Some(1.0))));

        let on = SpatialPredicate::Distance(DistancePredicate {
            left: proj_col("geom", 2),
            right: proj_col("geom", 1),
            distance: Arc::clone(&distance_lit),
            distance_side: JoinSide::None,
        });

        let projected_left_exprs = vec![proj_expr(proj_col("geom", 2), "geom_out")];
        let projected_right_exprs = vec![proj_expr(proj_col("geom", 1), "geom_r")];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Distance(updated) = updated else {
            unreachable!("expected distance")
        };
        assert_is_column(&updated.left, "geom_out", 0);
        assert_is_column(&updated.right, "geom_r", 0);
        assert!(Arc::ptr_eq(&updated.distance, &distance_lit));
        assert_eq!(updated.distance_side, JoinSide::None);
        Ok(())
    }

    // KNN `left`/`right` follow the probe/object roles: with `probe_side` Left,
    // `left` is rewritten from the left child's projections; with `probe_side`
    // Right, `left` is rewritten from the right child's projections instead.
    #[test]
    fn knn_rewrite_success_probe_left_and_right() -> Result<()> {
        let base = SpatialPredicate::KNearestNeighbors(KNNPredicate {
            left: proj_col("probe", 1),
            right: proj_col("build", 2),
            k: 10,
            use_spheroid: false,
            probe_side: JoinSide::Left,
        });

        let left_exprs = vec![proj_expr(proj_col("probe", 1), "probe_out")];
        let right_exprs = vec![proj_expr(proj_col("build", 2), "build_out")];

        let updated = base
            .update_for_child_projections(&left_exprs, &right_exprs)?
            .unwrap();
        let SpatialPredicate::KNearestNeighbors(updated) = updated else {
            unreachable!("expected knn")
        };
        assert_is_column(&updated.left, "probe_out", 0);
        assert_is_column(&updated.right, "build_out", 0);
        assert_eq!(updated.probe_side, JoinSide::Left);

        let base = SpatialPredicate::KNearestNeighbors(KNNPredicate {
            left: proj_col("probe", 1),
            right: proj_col("build", 2),
            k: 10,
            use_spheroid: false,
            probe_side: JoinSide::Right,
        });

        let left_exprs = vec![proj_expr(proj_col("build", 2), "build_out_l")];
        // The probe geometry now lives in the right child, so it is projected
        // there, while the build geometry comes from the left child.
        let right_exprs = vec![proj_expr(proj_col("probe", 1), "probe_out_r")];
        let updated = base
            .update_for_child_projections(&left_exprs, &right_exprs)?
            .unwrap();
        let SpatialPredicate::KNearestNeighbors(updated) = updated else {
            unreachable!("expected knn")
        };
        assert_is_column(&updated.left, "probe_out_r", 0);
        assert_is_column(&updated.right, "build_out_l", 0);
        assert_eq!(updated.probe_side, JoinSide::Right);

        Ok(())
    }

    // A KNN predicate built directly via its public fields with `probe_side ==
    // None` (bypassing `KNNPredicate::new`'s assert) must produce an error, not a rewrite.
    #[test]
    fn knn_rewrite_errors_on_none_probe_side() {
        let on = SpatialPredicate::KNearestNeighbors(KNNPredicate {
            left: proj_col("probe", 0),
            right: proj_col("build", 0),
            k: 10,
            use_spheroid: false,
            probe_side: JoinSide::None,
        });

        let left_exprs = vec![proj_expr(proj_col("probe", 0), "probe_out")];
        let right_exprs = vec![proj_expr(proj_col("build", 0), "build_out")];

        let err = on
            .update_for_child_projections(&left_exprs, &right_exprs)
            .expect_err("expected error");
        let msg = err.to_string();
        assert!(msg.contains("KNN join requires explicit probe_side designation"));
    }
}