Skip to main content

sedona_expr/
utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use std::sync::Arc;
19
20use datafusion_expr::Operator;
21use datafusion_physical_expr::{expressions::BinaryExpr, PhysicalExpr, ScalarFunctionExpr};
22
23/// Represents a parsed distance predicate with its constituent parts.
24///
25/// Distance predicates are spatial operations that determine whether two geometries
26/// are within a specified distance of each other. This struct holds the parsed
27/// components of such predicates for further processing.
28///
29/// ## Supported Distance Predicate Forms
30///
31/// This struct can represent the parsed components from any of these distance predicate forms:
32///
33/// 1. **Direct distance function**:
34///    - `st_dwithin(geom1, geom2, distance)` - Returns true if geometries are within the distance
35///
36/// 2. **Distance comparison (left-to-right)**:
37///    - `st_distance(geom1, geom2) <= distance` - Distance is less than or equal to threshold
38///    - `st_distance(geom1, geom2) < distance` - Distance is strictly less than threshold
39///
40/// 3. **Distance comparison (right-to-left)**:
41///    - `distance >= st_distance(geom1, geom2)` - Threshold is greater than or equal to distance
42///    - `distance > st_distance(geom1, geom2)` - Threshold is strictly greater than distance
43///
44/// All forms are logically equivalent but may appear differently in SQL queries. The parser
45/// normalizes them into this common structure for uniform processing.
46pub struct ParsedDistancePredicate {
47    /// The first geometry argument in the distance predicate
48    pub arg0: Arc<dyn PhysicalExpr>,
49    /// The second geometry argument in the distance predicate
50    pub arg1: Arc<dyn PhysicalExpr>,
51    /// The distance threshold argument (as a physical expression)
52    pub arg_distance: Arc<dyn PhysicalExpr>,
53}
54
55/// Parses a physical expression to extract distance predicate components.
56///
57/// This function recognizes and parses distance predicates in spatial queries.
58/// See [`ParsedDistancePredicate`] documentation for details on the supported
59/// distance predicate forms.
60///
61/// # Arguments
62///
63/// * `expr` - A physical expression that potentially represents a distance predicate
64///
65/// # Returns
66///
67/// * `Some(ParsedDistancePredicate)` - If the expression is a recognized distance predicate,
68///   returns the parsed components (two geometry arguments and the distance threshold)
69/// * `None` - If the expression is not a distance predicate or cannot be parsed
70///
71/// # Examples
72///
73/// The function can parse expressions like:
74/// - `st_dwithin(geometry_column, POINT(0 0), 100.0)`
75/// - `st_distance(geom_a, geom_b) <= 50.0`
76/// - `25.0 >= st_distance(geom_x, geom_y)`
77pub fn parse_distance_predicate(expr: &Arc<dyn PhysicalExpr>) -> Option<ParsedDistancePredicate> {
78    if let Some(binary_expr) = expr.as_any().downcast_ref::<BinaryExpr>() {
79        let left = binary_expr.left();
80        let right = binary_expr.right();
81        let (st_distance_expr, distance_bound_expr) = match *binary_expr.op() {
82            Operator::Lt | Operator::LtEq => (left, right),
83            Operator::Gt | Operator::GtEq => (right, left),
84            _ => return None,
85        };
86
87        if let Some(st_distance_expr) = st_distance_expr
88            .as_any()
89            .downcast_ref::<ScalarFunctionExpr>()
90        {
91            if st_distance_expr.fun().name() != "st_distance" {
92                return None;
93            }
94
95            let args = st_distance_expr.args();
96            assert!(args.len() >= 2);
97            Some(ParsedDistancePredicate {
98                arg0: Arc::clone(&args[0]),
99                arg1: Arc::clone(&args[1]),
100                arg_distance: Arc::clone(distance_bound_expr),
101            })
102        } else {
103            None
104        }
105    } else if let Some(st_dwithin_expr) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
106        if st_dwithin_expr.fun().name() != "st_dwithin" {
107            return None;
108        }
109
110        let args = st_dwithin_expr.args();
111        assert!(args.len() >= 3);
112        Some(ParsedDistancePredicate {
113            arg0: Arc::clone(&args[0]),
114            arg1: Arc::clone(&args[1]),
115            arg_distance: Arc::clone(&args[2]),
116        })
117    } else {
118        None
119    }
120}