Skip to main content

rustworkx_core/geometry/
distances.rs

// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
12
13use num_traits::Float;
14use std::iter::Sum;
15
/// Error returned when two points have a different dimension.
///
/// Implements [`std::error::Error`] (and [`std::fmt::Display`]) so it can be propagated with `?`
/// into `Box<dyn Error>` or any error type that wraps arbitrary errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IncompatiblePointsError;

impl std::fmt::Display for IncompatiblePointsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("points have different dimensions")
    }
}

impl std::error::Error for IncompatiblePointsError {}
19
20/// Computes the L^`p` distance between `x` and `y`.
21///
22/// Works for any `p`>0. An [`IncompatiblePointsError`] is returned when `x` and `y` have different
23/// lengths.
24pub fn lp_distance<T>(x: &[T], y: &[T], p: i32) -> Result<T, IncompatiblePointsError>
25where
26    T: Float + Sum + From<i32>,
27{
28    if x.len() != y.len() {
29        Err(IncompatiblePointsError {})
30    } else {
31        Ok(x.iter()
32            .zip(y.iter())
33            .map(|(&a, &b)| T::powi(T::abs(b - a), p))
34            .sum::<T>()
35            .powf(T::one() / p.into()))
36    }
37}
38
39/// Computes the Euclidean distance between `x` and `y`.
40///
41/// An [`IncompatiblePointsError`] is returned when `x` and `y` have different lengths.
42pub fn euclidean_distance<T>(x: &[T], y: &[T]) -> Result<T, IncompatiblePointsError>
43where
44    T: Float + Sum + From<i32>,
45{
46    lp_distance(x, y, 2)
47}
48
49/// Computes the maximum distance (Chebyshev distance or L^infinity distance) between `x` and `y`.
50///
51/// An [`IncompatiblePointsError`] is returned when `x` and `y` have different lengths.
52pub fn maximum_distance<T: Float>(x: &[T], y: &[T]) -> Result<T, IncompatiblePointsError> {
53    if x.len() != y.len() {
54        Err(IncompatiblePointsError {})
55    } else {
56        Ok(x.iter()
57            .zip(y.iter())
58            .map(|(&a, &b)| T::abs(b - a))
59            .reduce(Float::max)
60            .unwrap_or(T::zero()))
61    }
62}
63
64/// Computes the Euclidean dot product between points the unit n-sphere, where n is the length of
65/// `angles1` and `angles2`.
66///
67/// No check is done on the lengths of `angles1` and `angles2`. The Euclidean dot product is also
68/// the cosine of the angular distance between `angles1` and `angles2`.
69fn euclidean_dot_product<T: Float>(angles1: &[T], angles2: &[T]) -> T {
70    let mut total = T::zero();
71    let mut sin_prod = T::one();
72    let d = angles1.len();
73    for (i, (t1, t2)) in angles1.iter().zip(angles2.iter()).enumerate() {
74        total = T::mul_add(sin_prod, t1.cos() * t2.cos(), total);
75        sin_prod = sin_prod * t1.sin() * t2.sin();
76        if i == d - 1 {
77            total = total + sin_prod;
78        }
79    }
80    total
81}
82
83/// Computes the distance between the points `angles1` and `angles2` on the unit n-sphere, where n
84/// is the length of `angles1` and `angles2`.
85///
86/// The last element of `angles1` and `angles2` is assumed to be in [0, 2pi] or [-pi, pi] (and the
87/// other elements are in [0, pi]). An [`IncompatiblePointsError`] is returned when `angles1` and
88/// `angles2` have different lengths.
89pub fn angular_distance<T: Float>(
90    angles1: &[T],
91    angles2: &[T],
92) -> Result<T, IncompatiblePointsError> {
93    if angles1.len() != angles2.len() {
94        Err(IncompatiblePointsError {})
95    } else {
96        Ok(euclidean_dot_product(angles1, angles2).acos().abs())
97    }
98}
99
100/// Computes the hyperbolic distance between two points in polar coordinates.
101///
102/// `r1` and `r2` are the distances to the origin and the last element of `angles1` and `angles2`
103/// is assumed to be in [0, 2pi] or [-pi, pi] (and the other elements are in [0, pi]). An
104/// [`IncompatiblePointsError`] is returned when `angles1` and `angles2` have different lengths.
105pub fn polar_hyperbolic_distance<T: Float>(
106    r1: T,
107    angles1: &[T],
108    r2: T,
109    angles2: &[T],
110) -> Result<T, IncompatiblePointsError> {
111    if angles1.len() != angles2.len() {
112        Err(IncompatiblePointsError {})
113    } else {
114        let arg = (r1 - r2).cosh()
115            + (T::one() - euclidean_dot_product(angles1, angles2)) * r1.sinh() * r2.sinh();
116        Ok(if arg < T::one() {
117            T::zero()
118        } else {
119            arg.acosh()
120        })
121    }
122}
123
124/// Computes the hyperbolic distance between the points `x1` and `x2` in the hyperboloid model.
125///
126/// The "time" coordinate (opposite sign in the metric) is inferred from the others and should not
127/// be included in `x1` and `x2`. An [`IncompatiblePointsError`] is returned when `x` and `y` have
128/// different lengths.
129pub fn hyperboloid_hyperbolic_distance<T: Float>(
130    x: &[T],
131    y: &[T],
132) -> Result<T, IncompatiblePointsError> {
133    if x.len() != y.len() {
134        Err(IncompatiblePointsError {})
135    } else {
136        let mut sum_x_squared = T::zero();
137        let mut sum_y_squared = T::zero();
138        let mut sum_xy = T::zero();
139        for (x_i, y_i) in x.iter().zip(y.iter()) {
140            sum_x_squared = T::mul_add(*x_i, *x_i, sum_x_squared);
141            sum_y_squared = T::mul_add(*y_i, *y_i, sum_y_squared);
142            sum_xy = T::mul_add(*x_i, *y_i, sum_xy);
143        }
144        let arg = (T::one() + sum_x_squared).sqrt() * (T::one() + sum_y_squared).sqrt() - sum_xy;
145        Ok(if arg < T::one() {
146            T::zero()
147        } else {
148            arg.acosh()
149        })
150    }
151}
152
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use super::{
        IncompatiblePointsError, angular_distance, euclidean_distance,
        hyperboloid_hyperbolic_distance, lp_distance, maximum_distance, polar_hyperbolic_distance,
    };

    #[test]
    fn test_l4_dist() {
        assert!(
            (lp_distance(&[1., 2., 3.], &[5., 3., 1.], 4).unwrap()
                - (256_f64 + 1. + 16.).sqrt().sqrt())
            .abs()
                < 1e-15
        );
    }

    #[test]
    fn test_l1_dist() {
        assert!(
            (lp_distance(&[1., 2., 3.], &[5., 3., 1.], 1).unwrap() - (4. + 1. + 2_f64)).abs()
                < 1e-15
        );
    }

    #[test]
    fn test_lp_dist_incompatible_error() {
        assert_eq!(
            lp_distance(&[0., 0.], &[0., 0., 0.], 1),
            Err(IncompatiblePointsError {})
        );
        assert_eq!(
            lp_distance(&[0., 0.], &[0., 0., 0.], 4),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_euclidean_dist() {
        assert!(
            (euclidean_distance(&[1., 2., 3.], &[5., 3., 1.]).unwrap() - (16_f64 + 1. + 4.).sqrt())
                .abs()
                < 1e-15
        );
    }

    #[test]
    fn test_euclidean_dist_incompatible_error() {
        assert_eq!(
            euclidean_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_maximum_dist() {
        assert!((maximum_distance(&[1., 2., 3.], &[5., 3., 1.]).unwrap() - 4_f64).abs() < 1e-15);
    }

    #[test]
    fn test_maximum_dist_incompatible_error() {
        assert_eq!(
            maximum_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_angular_dist() {
        assert!((angular_distance(&[0.3], &[0.5]).unwrap() - 0.2_f64).abs() < 1e-15);
        assert!((angular_distance(&[0.5 * PI, PI], &[0.5 * PI, 0.]).unwrap() - PI).abs() < 1e-15);
        assert!((angular_distance(&[0.2, 0., 1.], &[0., 0., 1.]).unwrap() - 0.2_f64).abs() < 1e-15);
        assert!(
            (angular_distance(&[2. * PI, 0., 1.], &[0., 0., 1.]).unwrap() - 0_f64).abs() < 1e-15
        );
    }

    #[test]
    fn test_angular_dist_incompatible_error() {
        assert_eq!(
            angular_distance(&[0.], &[0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_polar_hyperbolic_dist() {
        assert_eq!(
            polar_hyperbolic_distance(3., &[0.], 0.5, &[PI]).unwrap(),
            3.5
        );
    }

    #[test]
    fn test_polar_hyperbolic_dist_incompatible_error() {
        assert_eq!(
            polar_hyperbolic_distance(1., &[0.], 1., &[0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_hyperboloid_dist() {
        assert_eq!(
            hyperboloid_hyperbolic_distance(&[3_f64.sinh(), 0.], &[-0.5_f64.sinh(), 0.]).unwrap(),
            3.5
        );
    }

    #[test]
    fn test_hyperboloid_dist_inf() {
        assert!(
            hyperboloid_hyperbolic_distance(&[f64::INFINITY, 0.], &[0., 0.])
                .unwrap()
                .is_nan()
        );
    }

    #[test]
    fn test_hyperboloid_dist_length_error() {
        // Exercise the hyperboloid distance itself; the length check is implemented per function.
        assert_eq!(
            hyperboloid_hyperbolic_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }
}