1use num_traits::Float;
14use std::iter::Sum;
15
/// Error returned when two points passed to a distance function do not have the
/// same number of coordinates (or angles), so no distance between them is defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IncompatiblePointsError;

impl std::fmt::Display for IncompatiblePointsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("points have different numbers of coordinates")
    }
}

// Implementing `Error` lets callers box this as `dyn Error` or convert it with `?`.
impl std::error::Error for IncompatiblePointsError {}
19
20pub fn lp_distance<T>(x: &[T], y: &[T], p: i32) -> Result<T, IncompatiblePointsError>
25where
26 T: Float + Sum + From<i32>,
27{
28 if x.len() != y.len() {
29 Err(IncompatiblePointsError {})
30 } else {
31 Ok(x.iter()
32 .zip(y.iter())
33 .map(|(&a, &b)| T::powi(T::abs(b - a), p))
34 .sum::<T>()
35 .powf(T::one() / p.into()))
36 }
37}
38
39pub fn euclidean_distance<T>(x: &[T], y: &[T]) -> Result<T, IncompatiblePointsError>
43where
44 T: Float + Sum + From<i32>,
45{
46 lp_distance(x, y, 2)
47}
48
49pub fn maximum_distance<T: Float>(x: &[T], y: &[T]) -> Result<T, IncompatiblePointsError> {
53 if x.len() != y.len() {
54 Err(IncompatiblePointsError {})
55 } else {
56 Ok(x.iter()
57 .zip(y.iter())
58 .map(|(&a, &b)| T::abs(b - a))
59 .reduce(Float::max)
60 .unwrap_or(T::zero()))
61 }
62}
63
64fn euclidean_dot_product<T: Float>(angles1: &[T], angles2: &[T]) -> T {
70 let mut total = T::zero();
71 let mut sin_prod = T::one();
72 let d = angles1.len();
73 for (i, (t1, t2)) in angles1.iter().zip(angles2.iter()).enumerate() {
74 total = T::mul_add(sin_prod, t1.cos() * t2.cos(), total);
75 sin_prod = sin_prod * t1.sin() * t2.sin();
76 if i == d - 1 {
77 total = total + sin_prod;
78 }
79 }
80 total
81}
82
83pub fn angular_distance<T: Float>(
90 angles1: &[T],
91 angles2: &[T],
92) -> Result<T, IncompatiblePointsError> {
93 if angles1.len() != angles2.len() {
94 Err(IncompatiblePointsError {})
95 } else {
96 Ok(euclidean_dot_product(angles1, angles2).acos().abs())
97 }
98}
99
100pub fn polar_hyperbolic_distance<T: Float>(
106 r1: T,
107 angles1: &[T],
108 r2: T,
109 angles2: &[T],
110) -> Result<T, IncompatiblePointsError> {
111 if angles1.len() != angles2.len() {
112 Err(IncompatiblePointsError {})
113 } else {
114 let arg = (r1 - r2).cosh()
115 + (T::one() - euclidean_dot_product(angles1, angles2)) * r1.sinh() * r2.sinh();
116 Ok(if arg < T::one() {
117 T::zero()
118 } else {
119 arg.acosh()
120 })
121 }
122}
123
124pub fn hyperboloid_hyperbolic_distance<T: Float>(
130 x: &[T],
131 y: &[T],
132) -> Result<T, IncompatiblePointsError> {
133 if x.len() != y.len() {
134 Err(IncompatiblePointsError {})
135 } else {
136 let mut sum_x_squared = T::zero();
137 let mut sum_y_squared = T::zero();
138 let mut sum_xy = T::zero();
139 for (x_i, y_i) in x.iter().zip(y.iter()) {
140 sum_x_squared = T::mul_add(*x_i, *x_i, sum_x_squared);
141 sum_y_squared = T::mul_add(*y_i, *y_i, sum_y_squared);
142 sum_xy = T::mul_add(*x_i, *y_i, sum_xy);
143 }
144 let arg = (T::one() + sum_x_squared).sqrt() * (T::one() + sum_y_squared).sqrt() - sum_xy;
145 Ok(if arg < T::one() {
146 T::zero()
147 } else {
148 arg.acosh()
149 })
150 }
151}
152
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use super::{
        IncompatiblePointsError, angular_distance, euclidean_distance,
        hyperboloid_hyperbolic_distance, lp_distance, maximum_distance, polar_hyperbolic_distance,
    };

    #[test]
    fn test_l4_dist() {
        assert!(
            (lp_distance(&[1., 2., 3.], &[5., 3., 1.], 4).unwrap()
                - (256_f64 + 1. + 16.).sqrt().sqrt())
            .abs()
                < 1e-15
        );
    }

    #[test]
    fn test_l1_dist() {
        assert!(
            (lp_distance(&[1., 2., 3.], &[5., 3., 1.], 1).unwrap() - (4. + 1. + 2_f64)).abs()
                < 1e-15
        );
    }

    #[test]
    fn test_lp_dist_incompatible_error() {
        assert_eq!(
            lp_distance(&[0., 0.], &[0., 0., 0.], 1),
            Err(IncompatiblePointsError {})
        );
        assert_eq!(
            lp_distance(&[0., 0.], &[0., 0., 0.], 4),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_euclidean_dist() {
        assert!(
            (euclidean_distance(&[1., 2., 3.], &[5., 3., 1.]).unwrap() - (16_f64 + 1. + 4.).sqrt())
                .abs()
                < 1e-15
        );
    }

    #[test]
    fn test_euclidean_dist_incompatible_error() {
        assert_eq!(
            euclidean_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_maximum_dist() {
        assert!((maximum_distance(&[1., 2., 3.], &[5., 3., 1.]).unwrap() - 4_f64).abs() < 1e-15);
    }

    #[test]
    fn test_maximum_dist_empty() {
        // Two empty points are at distance zero.
        assert_eq!(maximum_distance::<f64>(&[], &[]).unwrap(), 0.0);
    }

    #[test]
    fn test_maximum_dist_incompatible_error() {
        assert_eq!(
            maximum_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_angular_dist() {
        assert!((angular_distance(&[0.3], &[0.5]).unwrap() - 0.2_f64).abs() < 1e-15);
        assert!((angular_distance(&[0.5 * PI, PI], &[0.5 * PI, 0.]).unwrap() - PI).abs() < 1e-15);
        assert!((angular_distance(&[0.2, 0., 1.], &[0., 0., 1.]).unwrap() - 0.2_f64).abs() < 1e-15);
        assert!(
            (angular_distance(&[2. * PI, 0., 1.], &[0., 0., 1.]).unwrap() - 0_f64).abs() < 1e-15
        );
    }

    #[test]
    fn test_angular_dist_incompatible_error() {
        assert_eq!(
            angular_distance(&[0.], &[0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_polar_hyperbolic_dist() {
        assert_eq!(
            polar_hyperbolic_distance(3., &[0.], 0.5, &[PI]).unwrap(),
            3.5
        );
    }

    #[test]
    fn test_polar_hyperbolic_dist_incompatible_error() {
        assert_eq!(
            polar_hyperbolic_distance(1., &[0.], 1., &[0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_hyperboloid_dist() {
        assert_eq!(
            hyperboloid_hyperbolic_distance(&[3_f64.sinh(), 0.], &[-0.5_f64.sinh(), 0.]).unwrap(),
            3.5
        );
    }

    #[test]
    fn test_hyperboloid_dist_inf() {
        assert!(
            hyperboloid_hyperbolic_distance(&[f64::INFINITY, 0.], &[0., 0.])
                .unwrap()
                .is_nan()
        );
    }

    #[test]
    fn test_hyperboloid_dist_length_error() {
        // This previously called `euclidean_distance`, leaving the hyperboloid
        // length check untested.
        assert_eq!(
            hyperboloid_hyperbolic_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }
}