Skip to main content

sedona_spatial_join/
spatial_predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::sync::Arc;
18
19use datafusion_common::{JoinSide, Result};
20use datafusion_physical_expr::projection::update_expr;
21use datafusion_physical_expr::PhysicalExpr;
22use datafusion_physical_plan::projection::ProjectionExpr;
23use sedona_common::sedona_internal_err;
24
25/// Spatial predicate is the join condition of a spatial join. It can be a distance predicate,
26/// a relation predicate, or a KNN predicate.
27#[derive(Debug, Clone)]
28pub enum SpatialPredicate {
29    Distance(DistancePredicate),
30    Relation(RelationPredicate),
31    KNearestNeighbors(KNNPredicate),
32}
33
34impl std::fmt::Display for SpatialPredicate {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            SpatialPredicate::Distance(predicate) => write!(f, "{predicate}"),
38            SpatialPredicate::Relation(predicate) => write!(f, "{predicate}"),
39            SpatialPredicate::KNearestNeighbors(predicate) => write!(f, "{predicate}"),
40        }
41    }
42}
43
44/// Distance-based spatial join predicate.
45///
46/// This predicate represents a spatial join condition based on distance between geometries.
47/// It is used to find pairs of geometries from left and right tables where the distance
48/// between them is less than a specified threshold.
49///
50/// # Example SQL
51/// ```sql
52/// SELECT * FROM left_table l JOIN right_table r
53/// ON ST_Distance(l.geom, r.geom) < 100.0
54/// ```
55///
56/// # Fields
57/// * `left` - Expression to evaluate the left side geometry
58/// * `right` - Expression to evaluate the right side geometry
59/// * `distance` - Expression to evaluate the distance threshold
60/// * `distance_side` - Which side the distance expression belongs to (for column references)
61#[derive(Debug, Clone)]
62pub struct DistancePredicate {
63    /// The expression for evaluating the geometry value on the left side. The expression
64    /// should be evaluated directly on the left side batches.
65    pub left: Arc<dyn PhysicalExpr>,
66    /// The expression for evaluating the geometry value on the right side. The expression
67    /// should be evaluated directly on the right side batches.
68    pub right: Arc<dyn PhysicalExpr>,
69    /// The expression for evaluating the distance value. The expression
70    /// should be evaluated directly on the left or right side batches according to distance_side.
71    pub distance: Arc<dyn PhysicalExpr>,
72    /// The side of the distance expression. It could be JoinSide::None if the distance expression
73    /// is not a column reference. The most common case is that the distance expression is a
74    /// literal value.
75    pub distance_side: JoinSide,
76}
77
78impl DistancePredicate {
79    /// Creates a new distance predicate.
80    ///
81    /// # Arguments
82    /// * `left` - Expression for the left side geometry
83    /// * `right` - Expression for the right side geometry
84    /// * `distance` - Expression for the distance threshold
85    /// * `distance_side` - Which side (Left, Right, or None) the distance expression belongs to
86    pub fn new(
87        left: Arc<dyn PhysicalExpr>,
88        right: Arc<dyn PhysicalExpr>,
89        distance: Arc<dyn PhysicalExpr>,
90        distance_side: JoinSide,
91    ) -> Self {
92        Self {
93            left,
94            right,
95            distance,
96            distance_side,
97        }
98    }
99}
100
101impl std::fmt::Display for DistancePredicate {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        write!(
104            f,
105            "ST_Distance({}, {}) < {}",
106            self.left, self.right, self.distance
107        )
108    }
109}
110
111/// Spatial relation predicate for topological relationships.
112///
113/// This predicate represents a spatial join condition based on topological relationships
114/// between geometries, such as intersects, contains, within, etc. It follows the
115/// DE-9IM (Dimensionally Extended 9-Intersection Model) spatial relations.
116///
117/// # Example SQL
118/// ```sql
119/// SELECT * FROM buildings b JOIN parcels p
120/// ON ST_Intersects(b.geometry, p.geometry)
121/// ```
122///
123/// # Supported Relations
124/// * `Intersects` - Geometries share at least one point
125/// * `Contains` - Left geometry contains the right geometry
126/// * `Within` - Left geometry is within the right geometry
127/// * `Covers` - Left geometry covers the right geometry
128/// * `CoveredBy` - Left geometry is covered by the right geometry
129/// * `Touches` - Geometries touch at their boundaries
130/// * `Crosses` - Geometries cross each other
131/// * `Overlaps` - Geometries overlap
132/// * `Equals` - Geometries are spatially equal
133#[derive(Debug, Clone)]
134pub struct RelationPredicate {
135    /// The expression for evaluating the geometry value on the left side. The expression
136    /// should be evaluated directly on the left side batches.
137    pub left: Arc<dyn PhysicalExpr>,
138    /// The expression for evaluating the geometry value on the right side. The expression
139    /// should be evaluated directly on the right side batches.
140    pub right: Arc<dyn PhysicalExpr>,
141    /// The spatial relation type.
142    pub relation_type: SpatialRelationType,
143}
144
145impl RelationPredicate {
146    /// Creates a new spatial relation predicate.
147    ///
148    /// # Arguments
149    /// * `left` - Expression for the left side geometry
150    /// * `right` - Expression for the right side geometry
151    /// * `relation_type` - The type of spatial relationship to test
152    pub fn new(
153        left: Arc<dyn PhysicalExpr>,
154        right: Arc<dyn PhysicalExpr>,
155        relation_type: SpatialRelationType,
156    ) -> Self {
157        Self {
158            left,
159            right,
160            relation_type,
161        }
162    }
163}
164
165impl std::fmt::Display for RelationPredicate {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        write!(
168            f,
169            "ST_{}({}, {})",
170            self.relation_type, self.left, self.right
171        )
172    }
173}
174
/// The topological relationship evaluated by a relation predicate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpatialRelationType {
    /// Geometries share at least one point.
    Intersects,
    /// Left geometry contains the right geometry.
    Contains,
    /// Left geometry is within the right geometry.
    Within,
    /// Left geometry covers the right geometry.
    Covers,
    /// Left geometry is covered by the right geometry.
    CoveredBy,
    /// Geometries touch at their boundaries.
    Touches,
    /// Geometries cross each other.
    Crosses,
    /// Geometries overlap.
    Overlaps,
    /// Geometries are spatially equal.
    Equals,
}

impl SpatialRelationType {
    /// Looks up the relation named by a lowercase spatial function name.
    ///
    /// Matching is exact. Both `"st_coveredby"` and `"st_covered_by"` map to `CoveredBy`.
    ///
    /// # Returns
    /// `Some(relation)` for a recognized name, `None` otherwise.
    pub fn from_name(name: &str) -> Option<Self> {
        let relation = match name {
            "st_intersects" => Self::Intersects,
            "st_contains" => Self::Contains,
            "st_within" => Self::Within,
            "st_covers" => Self::Covers,
            "st_coveredby" | "st_covered_by" => Self::CoveredBy,
            "st_touches" => Self::Touches,
            "st_crosses" => Self::Crosses,
            "st_overlaps" => Self::Overlaps,
            "st_equals" => Self::Equals,
            _ => return None,
        };
        Some(relation)
    }

    /// Returns the relation that holds once the two operands are exchanged.
    ///
    /// `Contains`/`Within` and `Covers`/`CoveredBy` swap with each other. Every other
    /// relation is symmetric and maps to itself.
    pub fn invert(&self) -> Self {
        match *self {
            Self::Contains => Self::Within,
            Self::Within => Self::Contains,
            Self::Covers => Self::CoveredBy,
            Self::CoveredBy => Self::Covers,
            symmetric @ (Self::Intersects
            | Self::Touches
            | Self::Crosses
            | Self::Overlaps
            | Self::Equals) => symmetric,
        }
    }
}

impl std::fmt::Display for SpatialRelationType {
    /// Writes the lowercase relation name (e.g. `coveredby`), ignoring width/fill flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Intersects => "intersects",
            Self::Contains => "contains",
            Self::Within => "within",
            Self::Covers => "covers",
            Self::CoveredBy => "coveredby",
            Self::Touches => "touches",
            Self::Crosses => "crosses",
            Self::Overlaps => "overlaps",
            Self::Equals => "equals",
        };
        f.write_str(name)
    }
}
250
251/// K-Nearest Neighbors (KNN) spatial join predicate.
252///
253/// This predicate represents a spatial join that finds the k nearest neighbors
254/// from the right side (object) table for each geometry in the left side (query) table.
255/// It's commonly used for proximity analysis and spatial recommendations.
256///
257/// # Example SQL
258/// ```sql
259/// SELECT * FROM restaurants r
260/// JOIN TABLE(ST_KNN(r.location, h.location, 5, false)) AS knn
261/// ON r.id = knn.restaurant_id
262/// ```
263///
264/// # Algorithm
265/// For each geometry in the left (query) side:
266/// 1. Find the k nearest geometries from the right (object) side
267/// 2. Use spatial index for efficient nearest neighbor search
268/// 3. Handle tie-breaking when multiple geometries have the same distance
269///
270/// # Performance Considerations
271/// * Uses R-tree spatial index for efficient search
272/// * Performance depends on k value and spatial distribution
273/// * Tie-breaking may require additional distance calculations
274///
275/// # Limitations
276/// * Currently only supports planar (Euclidean) distance calculations
277/// * Spheroid distance (use_spheroid=true) is not yet implemented
278#[derive(Debug, Clone)]
279pub struct KNNPredicate {
280    /// The expression for evaluating the geometry value on the left side (queries side).
281    /// The expression should be evaluated directly on the left side batches.
282    pub left: Arc<dyn PhysicalExpr>,
283    /// The expression for evaluating the geometry value on the right side (object side).
284    /// The expression should be evaluated directly on the right side batches.
285    pub right: Arc<dyn PhysicalExpr>,
286    /// The number of nearest neighbors to find (literal value).
287    pub k: u32,
288    /// Whether to use spheroid distance calculation or planar distance (literal value).
289    /// Currently must be false as spheroid distance is not yet implemented.
290    pub use_spheroid: bool,
291    /// Which execution plan side (Left or Right) the probe expression belongs to.
292    /// This is used to correctly assign build/probe plans in execution.
293    pub probe_side: JoinSide,
294}
295
296impl KNNPredicate {
297    /// Creates a new K-Nearest Neighbors predicate.
298    ///
299    /// # Arguments
300    /// * `left` - Expression for the left side (query) geometry
301    /// * `right` - Expression for the right side (object) geometry
302    /// * `k` - Number of nearest neighbors to find (literal value)
303    /// * `use_spheroid` - Whether to use spheroid distance (literal value, currently must be false)
304    /// * `probe_side` - Which execution plan side the probe expression belongs to, cannot be None
305    pub fn new(
306        left: Arc<dyn PhysicalExpr>,
307        right: Arc<dyn PhysicalExpr>,
308        k: u32,
309        use_spheroid: bool,
310        probe_side: JoinSide,
311    ) -> Self {
312        assert!(matches!(probe_side, JoinSide::Left | JoinSide::Right));
313        Self {
314            left,
315            right,
316            k,
317            use_spheroid,
318            probe_side,
319        }
320    }
321}
322
323impl std::fmt::Display for KNNPredicate {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        write!(
326            f,
327            "ST_KNN({}, {}, {}, {})",
328            self.left, self.right, self.k, self.use_spheroid
329        )
330    }
331}
332
333/// Common operations needed by the planner/executor to keep spatial predicates valid
334/// when join inputs are swapped or projected.
335pub trait SpatialPredicateTrait: Sized {
336    /// Returns a semantically equivalent predicate after the join children are swapped.
337    ///
338    /// Used by `SpatialJoinExec::swap_inputs` to keep the predicate aligned with the new
339    /// left/right inputs.
340    fn swap_for_swapped_children(&self) -> Self;
341
342    /// Rewrites the predicate to reference projected child expressions.
343    ///
344    /// Returns `Ok(None)` when the predicate cannot be expressed using the projected inputs
345    /// (so projection pushdown must be skipped).
346    fn update_for_child_projections(
347        &self,
348        projected_left_exprs: &[ProjectionExpr],
349        projected_right_exprs: &[ProjectionExpr],
350    ) -> Result<Option<Self>>;
351}
352
353impl SpatialPredicateTrait for SpatialPredicate {
354    fn swap_for_swapped_children(&self) -> Self {
355        match self {
356            SpatialPredicate::Relation(pred) => {
357                SpatialPredicate::Relation(pred.swap_for_swapped_children())
358            }
359            SpatialPredicate::Distance(pred) => {
360                SpatialPredicate::Distance(pred.swap_for_swapped_children())
361            }
362            SpatialPredicate::KNearestNeighbors(pred) => {
363                SpatialPredicate::KNearestNeighbors(pred.swap_for_swapped_children())
364            }
365        }
366    }
367
368    fn update_for_child_projections(
369        &self,
370        projected_left_exprs: &[ProjectionExpr],
371        projected_right_exprs: &[ProjectionExpr],
372    ) -> Result<Option<Self>> {
373        match self {
374            SpatialPredicate::Relation(pred) => Ok(pred
375                .update_for_child_projections(projected_left_exprs, projected_right_exprs)?
376                .map(SpatialPredicate::Relation)),
377            SpatialPredicate::Distance(pred) => Ok(pred
378                .update_for_child_projections(projected_left_exprs, projected_right_exprs)?
379                .map(SpatialPredicate::Distance)),
380            SpatialPredicate::KNearestNeighbors(pred) => Ok(pred
381                .update_for_child_projections(projected_left_exprs, projected_right_exprs)?
382                .map(SpatialPredicate::KNearestNeighbors)),
383        }
384    }
385}
386
387impl SpatialPredicateTrait for RelationPredicate {
388    fn swap_for_swapped_children(&self) -> Self {
389        Self {
390            left: Arc::clone(&self.right),
391            right: Arc::clone(&self.left),
392            relation_type: self.relation_type.invert(),
393        }
394    }
395
396    fn update_for_child_projections(
397        &self,
398        projected_left_exprs: &[ProjectionExpr],
399        projected_right_exprs: &[ProjectionExpr],
400    ) -> Result<Option<Self>> {
401        let Some(left) = update_expr(&self.left, projected_left_exprs, false)? else {
402            return Ok(None);
403        };
404        let Some(right) = update_expr(&self.right, projected_right_exprs, false)? else {
405            return Ok(None);
406        };
407
408        Ok(Some(Self {
409            left,
410            right,
411            relation_type: self.relation_type,
412        }))
413    }
414}
415
416impl SpatialPredicateTrait for DistancePredicate {
417    fn swap_for_swapped_children(&self) -> Self {
418        Self {
419            left: Arc::clone(&self.right),
420            right: Arc::clone(&self.left),
421            distance: Arc::clone(&self.distance),
422            distance_side: self.distance_side.negate(),
423        }
424    }
425
426    fn update_for_child_projections(
427        &self,
428        projected_left_exprs: &[ProjectionExpr],
429        projected_right_exprs: &[ProjectionExpr],
430    ) -> Result<Option<Self>> {
431        let Some(left) = update_expr(&self.left, projected_left_exprs, false)? else {
432            return Ok(None);
433        };
434        let Some(right) = update_expr(&self.right, projected_right_exprs, false)? else {
435            return Ok(None);
436        };
437
438        let distance = match self.distance_side {
439            JoinSide::Left => {
440                let Some(distance) = update_expr(&self.distance, projected_left_exprs, false)?
441                else {
442                    return Ok(None);
443                };
444                distance
445            }
446            JoinSide::Right => {
447                let Some(distance) = update_expr(&self.distance, projected_right_exprs, false)?
448                else {
449                    return Ok(None);
450                };
451                distance
452            }
453            JoinSide::None => Arc::clone(&self.distance),
454        };
455
456        Ok(Some(Self {
457            left,
458            right,
459            distance,
460            distance_side: self.distance_side,
461        }))
462    }
463}
464
465impl SpatialPredicateTrait for KNNPredicate {
466    fn swap_for_swapped_children(&self) -> Self {
467        Self {
468            // Keep query/object expressions stable; only flip which child is considered probe.
469            left: Arc::clone(&self.left),
470            right: Arc::clone(&self.right),
471            k: self.k,
472            use_spheroid: self.use_spheroid,
473            probe_side: self.probe_side.negate(),
474        }
475    }
476
477    fn update_for_child_projections(
478        &self,
479        projected_left_exprs: &[ProjectionExpr],
480        projected_right_exprs: &[ProjectionExpr],
481    ) -> Result<Option<Self>> {
482        let (query_exprs, object_exprs) = match self.probe_side {
483            JoinSide::Left => (projected_left_exprs, projected_right_exprs),
484            JoinSide::Right => (projected_right_exprs, projected_left_exprs),
485            JoinSide::None => {
486                return sedona_internal_err!("KNN join requires explicit probe_side designation")
487            }
488        };
489
490        let Some(left) = update_expr(&self.left, query_exprs, false)? else {
491            return Ok(None);
492        };
493        let Some(right) = update_expr(&self.right, object_exprs, false)? else {
494            return Ok(None);
495        };
496
497        Ok(Some(Self {
498            left,
499            right,
500            k: self.k,
501            use_spheroid: self.use_spheroid,
502            probe_side: self.probe_side,
503        }))
504    }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::expressions::{Column, Literal};

    /// Builds a bare column reference expression.
    fn proj_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Builds a projection entry that outputs `expr` under `alias`.
    fn proj_expr(expr: Arc<dyn PhysicalExpr>, alias: &str) -> ProjectionExpr {
        ProjectionExpr {
            expr,
            alias: alias.to_string(),
        }
    }

    /// Asserts that `expr` is a `Column` with the given name and index.
    fn assert_is_column(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
        let col = expr
            .as_any()
            .downcast_ref::<Column>()
            .expect("expected Column");
        assert_eq!(col.name(), name);
        assert_eq!(col.index(), index);
    }

    #[test]
    fn relation_rewrite_success() -> Result<()> {
        let on = SpatialPredicate::Relation(RelationPredicate {
            left: proj_col("a", 1),
            right: proj_col("x", 2),
            relation_type: SpatialRelationType::Intersects,
        });

        let projected_left_exprs = vec![proj_expr(proj_col("a", 1), "a_new")];
        let projected_right_exprs = vec![proj_expr(proj_col("x", 2), "x_new")];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Relation(updated) = updated else {
            unreachable!("expected relation")
        };
        assert_is_column(&updated.left, "a_new", 0);
        assert_is_column(&updated.right, "x_new", 0);
        Ok(())
    }

    #[test]
    fn relation_rewrite_column_index_unchanged() -> Result<()> {
        let on = SpatialPredicate::Relation(RelationPredicate {
            left: proj_col("a", 0),
            right: proj_col("x", 0),
            relation_type: SpatialRelationType::Intersects,
        });

        let projected_left_exprs = vec![proj_expr(proj_col("a", 0), "a_new")];
        let projected_right_exprs = vec![proj_expr(proj_col("x", 0), "x_new")];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Relation(updated) = updated else {
            unreachable!("expected relation")
        };
        assert_is_column(&updated.left, "a_new", 0);
        assert_is_column(&updated.right, "x_new", 0);
        Ok(())
    }

    #[test]
    fn relation_rewrite_none_when_missing() -> Result<()> {
        let on = SpatialPredicate::Relation(RelationPredicate {
            left: proj_col("a", 1),
            right: proj_col("x", 0),
            relation_type: SpatialRelationType::Intersects,
        });

        let projected_left_exprs = vec![proj_expr(proj_col("a", 0), "a0")];
        let projected_right_exprs = vec![proj_expr(proj_col("x", 0), "x0")];

        assert!(on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .is_none());
        Ok(())
    }

    #[test]
    fn distance_rewrite_distance_side_left() -> Result<()> {
        let on = SpatialPredicate::Distance(DistancePredicate {
            left: proj_col("geom", 0),
            right: proj_col("geom", 0),
            distance: proj_col("dist", 1),
            distance_side: JoinSide::Left,
        });

        let projected_left_exprs = vec![
            proj_expr(proj_col("geom", 0), "geom_out"),
            proj_expr(proj_col("dist", 1), "dist_out"),
        ];
        let projected_right_exprs = vec![proj_expr(proj_col("geom", 0), "geom_r")];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Distance(updated) = updated else {
            unreachable!("expected distance")
        };
        assert_is_column(&updated.left, "geom_out", 0);
        assert_is_column(&updated.right, "geom_r", 0);
        assert_is_column(&updated.distance, "dist_out", 1);
        assert_eq!(updated.distance_side, JoinSide::Left);
        Ok(())
    }

    #[test]
    fn distance_rewrite_distance_side_right() -> Result<()> {
        let on = SpatialPredicate::Distance(DistancePredicate::new(
            proj_col("geom", 0),
            proj_col("geom", 0),
            proj_col("dist", 1),
            JoinSide::Right,
        ));

        let projected_left_exprs = vec![proj_expr(proj_col("geom", 0), "geom_l")];
        let projected_right_exprs = vec![
            proj_expr(proj_col("dist", 1), "dist_out"),
            proj_expr(proj_col("geom", 0), "geom_out"),
        ];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Distance(updated) = updated else {
            unreachable!("expected distance")
        };
        assert_is_column(&updated.left, "geom_l", 0);
        assert_is_column(&updated.right, "geom_out", 1);
        assert_is_column(&updated.distance, "dist_out", 0);
        assert_eq!(updated.distance_side, JoinSide::Right);
        Ok(())
    }

    #[test]
    fn distance_rewrite_none_when_distance_column_missing() -> Result<()> {
        let on = SpatialPredicate::Distance(DistancePredicate::new(
            proj_col("geom", 0),
            proj_col("geom", 0),
            proj_col("dist", 1),
            JoinSide::Left,
        ));

        // The left projection keeps the geometry but drops the distance column.
        let projected_left_exprs = vec![proj_expr(proj_col("geom", 0), "geom_out")];
        let projected_right_exprs = vec![proj_expr(proj_col("geom", 0), "geom_r")];

        assert!(on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .is_none());
        Ok(())
    }

    #[test]
    fn distance_rewrite_distance_side_none_keeps_literal() -> Result<()> {
        let distance_lit: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Float64(Some(1.0))));

        let on = SpatialPredicate::Distance(DistancePredicate {
            left: proj_col("geom", 2),
            right: proj_col("geom", 1),
            distance: Arc::clone(&distance_lit),
            distance_side: JoinSide::None,
        });

        let projected_left_exprs = vec![proj_expr(proj_col("geom", 2), "geom_out")];
        let projected_right_exprs = vec![proj_expr(proj_col("geom", 1), "geom_r")];

        let updated = on
            .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)?
            .unwrap();

        let SpatialPredicate::Distance(updated) = updated else {
            unreachable!("expected distance")
        };
        assert_is_column(&updated.left, "geom_out", 0);
        assert_is_column(&updated.right, "geom_r", 0);
        assert!(Arc::ptr_eq(&updated.distance, &distance_lit));
        assert_eq!(updated.distance_side, JoinSide::None);
        Ok(())
    }

    #[test]
    fn knn_rewrite_success_probe_left_and_right() -> Result<()> {
        let base = SpatialPredicate::KNearestNeighbors(KNNPredicate {
            left: proj_col("probe", 1),
            right: proj_col("build", 2),
            k: 10,
            use_spheroid: false,
            probe_side: JoinSide::Left,
        });

        let left_exprs = vec![proj_expr(proj_col("probe", 1), "probe_out")];
        let right_exprs = vec![proj_expr(proj_col("build", 2), "build_out")];

        let updated = base
            .update_for_child_projections(&left_exprs, &right_exprs)?
            .unwrap();
        let SpatialPredicate::KNearestNeighbors(updated) = updated else {
            unreachable!("expected knn")
        };
        assert_is_column(&updated.left, "probe_out", 0);
        assert_is_column(&updated.right, "build_out", 0);
        assert_eq!(updated.probe_side, JoinSide::Left);

        let base = SpatialPredicate::KNearestNeighbors(KNNPredicate {
            left: proj_col("probe", 1),
            right: proj_col("build", 2),
            k: 10,
            use_spheroid: false,
            probe_side: JoinSide::Right,
        });

        // For probe_side=Right: predicate.left (probe) is rewritten using right projections,
        // and predicate.right (build) is rewritten using left projections.
        let left_exprs = vec![proj_expr(proj_col("build", 2), "build_out_l")];
        let right_exprs = vec![proj_expr(proj_col("probe", 1), "probe_out_r")];
        let updated = base
            .update_for_child_projections(&left_exprs, &right_exprs)?
            .unwrap();
        let SpatialPredicate::KNearestNeighbors(updated) = updated else {
            unreachable!("expected knn")
        };
        assert_is_column(&updated.left, "probe_out_r", 0);
        assert_is_column(&updated.right, "build_out_l", 0);
        assert_eq!(updated.probe_side, JoinSide::Right);

        Ok(())
    }

    #[test]
    fn knn_rewrite_errors_on_none_probe_side() {
        let on = SpatialPredicate::KNearestNeighbors(KNNPredicate {
            left: proj_col("probe", 0),
            right: proj_col("build", 0),
            k: 10,
            use_spheroid: false,
            probe_side: JoinSide::None,
        });

        let left_exprs = vec![proj_expr(proj_col("probe", 0), "probe_out")];
        let right_exprs = vec![proj_expr(proj_col("build", 0), "build_out")];

        let err = on
            .update_for_child_projections(&left_exprs, &right_exprs)
            .expect_err("expected error");
        let msg = err.to_string();
        assert!(msg.contains("KNN join requires explicit probe_side designation"));
    }

    #[test]
    #[should_panic]
    fn knn_new_rejects_none_probe_side() {
        let _ = KNNPredicate::new(
            proj_col("probe", 0),
            proj_col("build", 0),
            1,
            false,
            JoinSide::None,
        );
    }

    #[test]
    fn relation_swap_inverts_relation_type() {
        let on = SpatialPredicate::Relation(RelationPredicate::new(
            proj_col("a", 0),
            proj_col("x", 1),
            SpatialRelationType::Contains,
        ));

        let SpatialPredicate::Relation(swapped) = on.swap_for_swapped_children() else {
            unreachable!("expected relation")
        };
        assert_is_column(&swapped.left, "x", 1);
        assert_is_column(&swapped.right, "a", 0);
        assert_eq!(swapped.relation_type, SpatialRelationType::Within);
    }

    #[test]
    fn distance_swap_negates_distance_side() {
        for (side, expected) in [
            (JoinSide::Left, JoinSide::Right),
            (JoinSide::Right, JoinSide::Left),
            (JoinSide::None, JoinSide::None),
        ] {
            let distance = proj_col("dist", 2);
            let on = SpatialPredicate::Distance(DistancePredicate::new(
                proj_col("geom_l", 0),
                proj_col("geom_r", 1),
                Arc::clone(&distance),
                side,
            ));

            let SpatialPredicate::Distance(swapped) = on.swap_for_swapped_children() else {
                unreachable!("expected distance")
            };
            assert_is_column(&swapped.left, "geom_r", 1);
            assert_is_column(&swapped.right, "geom_l", 0);
            assert!(Arc::ptr_eq(&swapped.distance, &distance));
            assert_eq!(swapped.distance_side, expected);
        }
    }

    #[test]
    fn knn_swap_flips_probe_side_and_keeps_expressions() -> Result<()> {
        let on = SpatialPredicate::KNearestNeighbors(KNNPredicate::new(
            proj_col("probe", 1),
            proj_col("build", 2),
            5,
            false,
            JoinSide::Left,
        ));

        let swapped = on.swap_for_swapped_children();
        let SpatialPredicate::KNearestNeighbors(knn) = &swapped else {
            unreachable!("expected knn")
        };
        assert_is_column(&knn.left, "probe", 1);
        assert_is_column(&knn.right, "build", 2);
        assert_eq!(knn.k, 5);
        assert!(!knn.use_spheroid);
        assert_eq!(knn.probe_side, JoinSide::Right);

        // After the swap the query expression is rewritten with the right child's
        // projections and the object expression with the left child's.
        let left_exprs = vec![proj_expr(proj_col("build", 2), "build_out")];
        let right_exprs = vec![proj_expr(proj_col("probe", 1), "probe_out")];
        let updated = swapped
            .update_for_child_projections(&left_exprs, &right_exprs)?
            .unwrap();
        let SpatialPredicate::KNearestNeighbors(updated) = updated else {
            unreachable!("expected knn")
        };
        assert_is_column(&updated.left, "probe_out", 0);
        assert_is_column(&updated.right, "build_out", 0);
        assert_eq!(updated.probe_side, JoinSide::Right);
        Ok(())
    }

    #[test]
    fn relation_type_from_name_and_invert() {
        let all = [
            SpatialRelationType::Intersects,
            SpatialRelationType::Contains,
            SpatialRelationType::Within,
            SpatialRelationType::Covers,
            SpatialRelationType::CoveredBy,
            SpatialRelationType::Touches,
            SpatialRelationType::Crosses,
            SpatialRelationType::Overlaps,
            SpatialRelationType::Equals,
        ];
        for relation in all {
            assert_eq!(relation.invert().invert(), relation);
            // The Display name prefixed with `st_` maps back to the same relation.
            assert_eq!(
                SpatialRelationType::from_name(&format!("st_{relation}")),
                Some(relation)
            );
        }
        assert_eq!(
            SpatialRelationType::Contains.invert(),
            SpatialRelationType::Within
        );
        assert_eq!(
            SpatialRelationType::Covers.invert(),
            SpatialRelationType::CoveredBy
        );
        assert_eq!(
            SpatialRelationType::Intersects.invert(),
            SpatialRelationType::Intersects
        );
        assert_eq!(
            SpatialRelationType::from_name("st_covered_by"),
            Some(SpatialRelationType::CoveredBy)
        );
        assert_eq!(SpatialRelationType::from_name("st_distance"), None);
    }

    #[test]
    fn predicates_display_as_function_calls() {
        let relation = SpatialPredicate::Relation(RelationPredicate::new(
            proj_col("a", 0),
            proj_col("x", 1),
            SpatialRelationType::Intersects,
        ));
        assert_eq!(relation.to_string(), "ST_intersects(a@0, x@1)");

        let knn = SpatialPredicate::KNearestNeighbors(KNNPredicate::new(
            proj_col("probe", 0),
            proj_col("build", 1),
            5,
            false,
            JoinSide::Left,
        ));
        assert_eq!(knn.to_string(), "ST_KNN(probe@0, build@1, 5, false)");

        let distance = SpatialPredicate::Distance(DistancePredicate::new(
            proj_col("g", 0),
            proj_col("h", 0),
            proj_col("d", 1),
            JoinSide::Right,
        ));
        assert_eq!(distance.to_string(), "ST_Distance(g@0, h@0) < d@1");
    }
}