//! Poincaré-ball helpers over `f32` slices: distance, exponential and logarithmic
//! maps, Möbius addition / scalar multiplication, midpoint and centroid.

// NOTE(review): blanket `dead_code` allowance — presumably because not every helper is
// used by every consumer of this module; confirm it is still required.
#![allow(dead_code)]

/// Curvature value the tests in this file pass to every curvature-taking function.
pub const DEFAULT_CURVATURE: f32 = -1.0;

/// Small tolerance: added to denominators before dividing and used as the
/// "norm is effectively zero" threshold throughout this file.
const EPS: f32 = 1e-7;

/// Largest norm a point is allowed to keep. It sits slightly below 1 so that
/// expressions such as `1 - norm²` stay strictly positive after clamping.
const MAX_NORM: f32 = 1.0 - 1e-5;
35
36pub fn poincare_distance(u: &[f32], v: &[f32], curvature: f32) -> f32 {
52 debug_assert_eq!(u.len(), v.len(), "Vector length mismatch");
53 debug_assert!(curvature < 0.0, "Curvature must be negative for hyperbolic space");
54
55 let sqrt_c = (-curvature).sqrt();
56
57 let norm_u_sq = squared_norm(u);
58 let norm_v_sq = squared_norm(v);
59
60 let norm_u_sq = norm_u_sq.min(MAX_NORM * MAX_NORM);
62 let norm_v_sq = norm_v_sq.min(MAX_NORM * MAX_NORM);
63
64 let diff_sq = squared_distance(u, v);
65
66 let denominator = (1.0 - norm_u_sq) * (1.0 - norm_v_sq);
67 let argument = 1.0 + 2.0 * diff_sq / (denominator + EPS);
68
69 let arcosh_val = (argument + (argument * argument - 1.0).max(0.0).sqrt()).ln();
71
72 arcosh_val / sqrt_c
73}
74
75pub fn exp_map(v: &[f32], curvature: f32) -> Vec<f32> {
89 let sqrt_c = (-curvature).sqrt();
90 let norm_v = l2_norm(v);
91
92 if norm_v < EPS {
93 return vec![0.0; v.len()];
94 }
95
96 let scale = (sqrt_c * norm_v / 2.0).tanh() / (sqrt_c * norm_v);
97
98 v.iter().map(|&x| x * scale).collect()
99}
100
101pub fn exp_map_at(x: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
110 let exp_v = exp_map(v, curvature);
111 mobius_add(x, &exp_v, curvature)
112}
113
114pub fn log_map(y: &[f32], curvature: f32) -> Vec<f32> {
127 let sqrt_c = (-curvature).sqrt();
128 let norm_y = l2_norm(y).min(MAX_NORM);
129
130 if norm_y < EPS {
131 return vec![0.0; y.len()];
132 }
133
134 let scale = (2.0 / sqrt_c) * (sqrt_c * norm_y).atanh() / norm_y;
135
136 y.iter().map(|&x| x * scale).collect()
137}
138
139pub fn log_map_at(x: &[f32], y: &[f32], curvature: f32) -> Vec<f32> {
148 let neg_x: Vec<f32> = x.iter().map(|&v| -v).collect();
149 let diff = mobius_add(&neg_x, y, curvature);
150 log_map(&diff, curvature)
151}
152
153pub fn mobius_add(u: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
170 debug_assert_eq!(u.len(), v.len(), "Vector length mismatch");
171
172 let c = -curvature;
173 let norm_u_sq = squared_norm(u);
174 let norm_v_sq = squared_norm(v);
175 let dot_uv = dot_product(u, v);
176
177 let numerator_u_coef = 1.0 + 2.0 * c * dot_uv + c * norm_v_sq;
178 let numerator_v_coef = 1.0 - c * norm_u_sq;
179 let denominator = 1.0 + 2.0 * c * dot_uv + c * c * norm_u_sq * norm_v_sq;
180
181 let mut result = Vec::with_capacity(u.len());
182 for i in 0..u.len() {
183 let value = (numerator_u_coef * u[i] + numerator_v_coef * v[i]) / (denominator + EPS);
184 result.push(value);
185 }
186
187 project_to_ball(&mut result);
189 result
190}
191
192pub fn mobius_scalar_mul(r: f32, x: &[f32], curvature: f32) -> Vec<f32> {
201 let sqrt_c = (-curvature).sqrt();
202 let norm_x = l2_norm(x).min(MAX_NORM);
203
204 if norm_x < EPS {
205 return vec![0.0; x.len()];
206 }
207
208 let scale = (r * (sqrt_c * norm_x).atanh()).tanh() / (sqrt_c * norm_x);
209
210 x.iter().map(|&v| v * scale).collect()
211}
212
213pub fn hyperbolic_midpoint(u: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
218 let log_v = log_map_at(u, v, curvature);
220
221 let half_log: Vec<f32> = log_v.iter().map(|&x| x * 0.5).collect();
223
224 exp_map_at(u, &half_log, curvature)
226}
227
228pub fn hyperbolic_centroid(points: &[&[f32]], curvature: f32) -> Option<Vec<f32>> {
232 if points.is_empty() {
233 return None;
234 }
235
236 let dim = points[0].len();
237
238 let mut centroid = vec![0.0; dim];
240 for point in points {
241 for (i, &v) in point.iter().enumerate() {
242 centroid[i] += v;
243 }
244 }
245 for x in centroid.iter_mut() {
246 *x /= points.len() as f32;
247 }
248 project_to_ball(&mut centroid);
249
250 for _ in 0..10 {
253 let mut grad = vec![0.0; dim];
254
255 for point in points {
256 let log_p = log_map_at(¢roid, point, curvature);
257 for (i, &v) in log_p.iter().enumerate() {
258 grad[i] += v;
259 }
260 }
261
262 for x in grad.iter_mut() {
264 *x /= points.len() as f32;
265 }
266
267 centroid = exp_map_at(¢roid, &grad, curvature);
269 }
270
271 Some(centroid)
272}
273
/// Converts a Euclidean vector into a point of the Poincaré ball.
///
/// This is a thin wrapper: the vector is handed to [`exp_map`] unchanged, together
/// with `curvature`, and the result is returned as-is.
pub fn euclidean_to_poincare(euclidean: &[f32], curvature: f32) -> Vec<f32> {
    exp_map(euclidean, curvature)
}
280
/// Converts a point of the Poincaré ball back into a Euclidean vector.
///
/// This is a thin wrapper: the point is handed to [`log_map`] unchanged, together
/// with `curvature`, and the result is returned as-is.
pub fn poincare_to_euclidean(poincare: &[f32], curvature: f32) -> Vec<f32> {
    log_map(poincare, curvature)
}
287
288fn project_to_ball(v: &mut [f32]) {
290 let norm = l2_norm(v);
291 if norm >= MAX_NORM {
292 let scale = MAX_NORM / norm;
293 for x in v.iter_mut() {
294 *x *= scale;
295 }
296 }
297}
298
/// Returns the sum of the squared components of `v`.
#[inline]
fn squared_norm(v: &[f32]) -> f32 {
    v.iter().map(|&component| component * component).sum::<f32>()
}
304
305#[inline]
307fn l2_norm(v: &[f32]) -> f32 {
308 squared_norm(v).sqrt()
309}
310
/// Returns the sum of squared component-wise differences between `u` and `v`.
///
/// Components are paired with `zip`, so only the common prefix is considered when the
/// slices differ in length.
#[inline]
fn squared_distance(u: &[f32], v: &[f32]) -> f32 {
    u.iter()
        .zip(v)
        .map(|(&left, &right)| (left - right) * (left - right))
        .sum::<f32>()
}
322
/// Returns the sum of component-wise products of `u` and `v`.
///
/// Components are paired with `zip`, so only the common prefix is considered when the
/// slices differ in length.
#[inline]
fn dot_product(u: &[f32], v: &[f32]) -> f32 {
    u.iter()
        .zip(v)
        .map(|(&left, &right)| left * right)
        .sum::<f32>()
}
328
329pub fn conformal_factor(x: &[f32]) -> f32 {
334 let norm_sq = squared_norm(x).min(MAX_NORM * MAX_NORM);
335 2.0 / (1.0 - norm_sq)
336}
337
338pub fn is_in_ball(x: &[f32]) -> bool {
340 squared_norm(x) < 1.0
341}
342
343pub fn hyperbolic_angle(u: &[f32], v: &[f32]) -> f32 {
345 let norm_u = l2_norm(u);
346 let norm_v = l2_norm(v);
347
348 if norm_u < EPS || norm_v < EPS {
349 return 0.0;
350 }
351
352 let cos_angle = dot_product(u, v) / (norm_u * norm_v);
353 cos_angle.clamp(-1.0, 1.0).acos()
354}
355
// Unit tests for the public helpers above. Floating-point comparisons use
// `assert_relative_eq!` from the `approx` crate with explicit tolerances.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// The distance from a point to itself is zero.
    #[test]
    fn test_poincare_distance_same_point() {
        let u = vec![0.1, 0.2, 0.3];
        let dist = poincare_distance(&u, &u, DEFAULT_CURVATURE);
        assert_relative_eq!(dist, 0.0, epsilon = 1e-5);
    }

    /// The distance from the origin to a different point is strictly positive.
    #[test]
    fn test_poincare_distance_origin() {
        let origin = vec![0.0, 0.0, 0.0];
        let v = vec![0.5, 0.0, 0.0];
        let dist = poincare_distance(&origin, &v, DEFAULT_CURVATURE);
        assert!(dist > 0.0);
    }

    /// `log_map` undoes `exp_map` at the origin.
    #[test]
    fn test_exp_log_inverse() {
        let v = vec![0.5, 0.3, 0.1];
        let exp_v = exp_map(&v, DEFAULT_CURVATURE);
        let log_exp_v = log_map(&exp_v, DEFAULT_CURVATURE);

        for (a, b) in v.iter().zip(log_exp_v.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-4);
        }
    }

    /// Möbius-adding the zero vector leaves the left operand unchanged.
    #[test]
    fn test_mobius_add_zero() {
        let u = vec![0.1, 0.2, 0.3];
        let zero = vec![0.0, 0.0, 0.0];

        let result = mobius_add(&u, &zero, DEFAULT_CURVATURE);
        for (a, b) in u.iter().zip(result.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-5);
        }
    }

    /// The Möbius sum of two points with norm 0.8 still has norm below 1.
    #[test]
    fn test_mobius_add_stays_in_ball() {
        let u = vec![0.8, 0.0, 0.0];
        let v = vec![0.0, 0.8, 0.0];

        let result = mobius_add(&u, &v, DEFAULT_CURVATURE);
        let norm = l2_norm(&result);
        assert!(norm < 1.0);
    }

    /// The midpoint lies between both endpoints and is equidistant from them.
    #[test]
    fn test_hyperbolic_midpoint() {
        let u = vec![0.1, 0.0, 0.0];
        let v = vec![0.5, 0.0, 0.0];

        let mid = hyperbolic_midpoint(&u, &v, DEFAULT_CURVATURE);

        // Both endpoints sit on the first axis, so the midpoint must fall between them.
        assert!(mid[0] > u[0] && mid[0] < v[0]);

        let dist_u = poincare_distance(&u, &mid, DEFAULT_CURVATURE);
        let dist_v = poincare_distance(&v, &mid, DEFAULT_CURVATURE);
        assert_relative_eq!(dist_u, dist_v, epsilon = 1e-3);
    }

    /// Converting to the ball and back reproduces the original vector.
    #[test]
    fn test_euclidean_poincare_conversion() {
        let euclidean = vec![0.3, 0.2, 0.1];

        let poincare = euclidean_to_poincare(&euclidean, DEFAULT_CURVATURE);
        assert!(is_in_ball(&poincare));

        let back = poincare_to_euclidean(&poincare, DEFAULT_CURVATURE);
        for (a, b) in euclidean.iter().zip(back.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-4);
        }
    }

    /// The conformal factor is 2 at the origin and grows near the boundary.
    #[test]
    fn test_conformal_factor() {
        let origin = vec![0.0, 0.0, 0.0];
        assert_relative_eq!(conformal_factor(&origin), 2.0, epsilon = 1e-5);

        let near_boundary = vec![0.99, 0.0, 0.0];
        // 2 / (1 - 0.99²) is roughly 100, comfortably above the threshold.
        assert!(conformal_factor(&near_boundary) > 10.0);
    }

    /// Points are inside the ball exactly when their squared norm is below 1.
    #[test]
    fn test_is_in_ball() {
        assert!(is_in_ball(&[0.0, 0.0, 0.0]));
        assert!(is_in_ball(&[0.5, 0.5, 0.0]));
        // A point exactly on the unit sphere is not inside.
        assert!(!is_in_ball(&[1.0, 0.0, 0.0]));
        // 3 * 0.6² = 1.08, which exceeds 1.
        assert!(!is_in_ball(&[0.6, 0.6, 0.6]));
    }
}