sevensense_vector/
hyperbolic.rs

//! Hyperbolic geometry operations for hierarchical embeddings.
//!
//! This module implements operations in the Poincare ball model of hyperbolic space,
//! which is particularly useful for representing hierarchical relationships
//! in embeddings (e.g., taxonomy trees, part-whole relationships).
//!
//! ## Poincare Ball Model
//!
//! The Poincare ball (curvature -1) is the open unit ball B^n = {x in R^n : ||x|| < 1}
//! equipped with the Riemannian metric:
//!
//! `g_x = (2 / (1 - ||x||^2))^2 * I`
//!
//! The metric blows up near the boundary, so distances grow exponentially
//! as points approach it. This makes the model well suited to tree-like structures.
//! For a general curvature `c < 0` the ball has radius `1 / sqrt(-c)`.
//!
//! ## Key Operations
//!
//! - `exp_map`: Map a tangent vector at the origin into the ball
//! - `log_map`: Map a point of the ball back to the tangent space at the origin
//! - `mobius_add`: Mobius (gyrovector) addition, the ball's analogue of vector addition
//! - `poincare_distance`: Geodesic distance on the manifold

// Many of these utilities are kept for future use and are not called yet;
// silence the dead-code lint for the whole module instead of item by item.
#![allow(dead_code)]

/// Curvature to use when a caller has no specific value in mind.
///
/// Functions in this module take a negative curvature (hyperbolic space);
/// `-1.0` gives the standard unit Poincare ball.
pub const DEFAULT_CURVATURE: f32 = -1.0;

/// Threshold below which a norm is treated as zero, guarding divisions by
/// vanishing quantities.
const EPS: f32 = 1e-7;

/// Safety radius: points whose norm reaches this value are pulled back inward
/// so that `1 - ||x||^2` never becomes zero.
const MAX_NORM: f32 = 1.0 - 1e-5;

36/// Compute the Poincare distance between two points in the Poincare ball.
37///
38/// The geodesic distance in the Poincare ball model is:
39///
40/// d(u, v) = (1/sqrt(-c)) * arcosh(1 + 2 * ||u - v||^2 / ((1 - ||u||^2) * (1 - ||v||^2)))
41///
42/// where c is the (negative) curvature.
43///
44/// # Arguments
45/// * `u` - First point in the Poincare ball
46/// * `v` - Second point in the Poincare ball
47/// * `curvature` - Curvature of the space (negative for hyperbolic)
48///
49/// # Returns
50/// The geodesic distance between u and v.
51pub fn poincare_distance(u: &[f32], v: &[f32], curvature: f32) -> f32 {
52    debug_assert_eq!(u.len(), v.len(), "Vector length mismatch");
53    debug_assert!(curvature < 0.0, "Curvature must be negative for hyperbolic space");
54
55    let sqrt_c = (-curvature).sqrt();
56
57    let norm_u_sq = squared_norm(u);
58    let norm_v_sq = squared_norm(v);
59
60    // Clamp norms to ensure they're inside the ball
61    let norm_u_sq = norm_u_sq.min(MAX_NORM * MAX_NORM);
62    let norm_v_sq = norm_v_sq.min(MAX_NORM * MAX_NORM);
63
64    let diff_sq = squared_distance(u, v);
65
66    let denominator = (1.0 - norm_u_sq) * (1.0 - norm_v_sq);
67    let argument = 1.0 + 2.0 * diff_sq / (denominator + EPS);
68
69    // arcosh(x) = ln(x + sqrt(x^2 - 1))
70    let arcosh_val = (argument + (argument * argument - 1.0).max(0.0).sqrt()).ln();
71
72    arcosh_val / sqrt_c
73}
74
75/// Exponential map: project from tangent space at origin to the Poincare ball.
76///
77/// Maps a Euclidean vector v from the tangent space T_0 B^n at the origin
78/// to a point on the Poincare ball.
79///
80/// exp_0(v) = tanh(sqrt(-c) * ||v|| / 2) * v / (sqrt(-c) * ||v||)
81///
82/// # Arguments
83/// * `v` - Vector in tangent space
84/// * `curvature` - Curvature of the space (negative for hyperbolic)
85///
86/// # Returns
87/// Point in the Poincare ball.
88pub fn exp_map(v: &[f32], curvature: f32) -> Vec<f32> {
89    let sqrt_c = (-curvature).sqrt();
90    let norm_v = l2_norm(v);
91
92    if norm_v < EPS {
93        return vec![0.0; v.len()];
94    }
95
96    let scale = (sqrt_c * norm_v / 2.0).tanh() / (sqrt_c * norm_v);
97
98    v.iter().map(|&x| x * scale).collect()
99}
100
101/// Exponential map from an arbitrary base point.
102///
103/// exp_x(v) = mobius_add(x, exp_0(v), c)
104///
105/// # Arguments
106/// * `x` - Base point in the Poincare ball
107/// * `v` - Vector in tangent space at x
108/// * `curvature` - Curvature of the space
109pub fn exp_map_at(x: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
110    let exp_v = exp_map(v, curvature);
111    mobius_add(x, &exp_v, curvature)
112}
113
114/// Logarithmic map: project from Poincare ball to tangent space at origin.
115///
116/// Inverse of the exponential map.
117///
118/// log_0(y) = (2 / sqrt(-c)) * arctanh(sqrt(-c) * ||y||) * y / ||y||
119///
120/// # Arguments
121/// * `y` - Point in the Poincare ball
122/// * `curvature` - Curvature of the space (negative for hyperbolic)
123///
124/// # Returns
125/// Vector in tangent space at origin.
126pub fn log_map(y: &[f32], curvature: f32) -> Vec<f32> {
127    let sqrt_c = (-curvature).sqrt();
128    let norm_y = l2_norm(y).min(MAX_NORM);
129
130    if norm_y < EPS {
131        return vec![0.0; y.len()];
132    }
133
134    let scale = (2.0 / sqrt_c) * (sqrt_c * norm_y).atanh() / norm_y;
135
136    y.iter().map(|&x| x * scale).collect()
137}
138
139/// Logarithmic map from an arbitrary base point.
140///
141/// log_x(y) = log_0(mobius_add(-x, y, c))
142///
143/// # Arguments
144/// * `x` - Base point in the Poincare ball
145/// * `y` - Target point in the Poincare ball
146/// * `curvature` - Curvature of the space
147pub fn log_map_at(x: &[f32], y: &[f32], curvature: f32) -> Vec<f32> {
148    let neg_x: Vec<f32> = x.iter().map(|&v| -v).collect();
149    let diff = mobius_add(&neg_x, y, curvature);
150    log_map(&diff, curvature)
151}
152
153/// Mobius addition (gyrovector addition).
154///
155/// The Mobius addition is the binary operation in the Poincare ball
156/// that generalizes vector addition. It can be seen as parallel transport
157/// followed by addition.
158///
159/// u ⊕ v = ((1 + 2c<u,v> + c||v||^2)u + (1 - c||u||^2)v) /
160///         (1 + 2c<u,v> + c^2||u||^2||v||^2)
161///
162/// # Arguments
163/// * `u` - First point in the Poincare ball
164/// * `v` - Second point in the Poincare ball
165/// * `curvature` - Curvature of the space (negative for hyperbolic)
166///
167/// # Returns
168/// Result of Mobius addition u ⊕ v.
169pub fn mobius_add(u: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
170    debug_assert_eq!(u.len(), v.len(), "Vector length mismatch");
171
172    let c = -curvature;
173    let norm_u_sq = squared_norm(u);
174    let norm_v_sq = squared_norm(v);
175    let dot_uv = dot_product(u, v);
176
177    let numerator_u_coef = 1.0 + 2.0 * c * dot_uv + c * norm_v_sq;
178    let numerator_v_coef = 1.0 - c * norm_u_sq;
179    let denominator = 1.0 + 2.0 * c * dot_uv + c * c * norm_u_sq * norm_v_sq;
180
181    let mut result = Vec::with_capacity(u.len());
182    for i in 0..u.len() {
183        let value = (numerator_u_coef * u[i] + numerator_v_coef * v[i]) / (denominator + EPS);
184        result.push(value);
185    }
186
187    // Project back into the ball if needed
188    project_to_ball(&mut result);
189    result
190}
191
192/// Mobius scalar multiplication.
193///
194/// r ⊗ x = (1/sqrt(c)) * tanh(r * arctanh(sqrt(c) * ||x||)) * x / ||x||
195///
196/// # Arguments
197/// * `r` - Scalar multiplier
198/// * `x` - Point in the Poincare ball
199/// * `curvature` - Curvature of the space
200pub fn mobius_scalar_mul(r: f32, x: &[f32], curvature: f32) -> Vec<f32> {
201    let sqrt_c = (-curvature).sqrt();
202    let norm_x = l2_norm(x).min(MAX_NORM);
203
204    if norm_x < EPS {
205        return vec![0.0; x.len()];
206    }
207
208    let scale = (r * (sqrt_c * norm_x).atanh()).tanh() / (sqrt_c * norm_x);
209
210    x.iter().map(|&v| v * scale).collect()
211}
212
213/// Compute the hyperbolic midpoint of two points.
214///
215/// The midpoint is the point on the geodesic between u and v
216/// that is equidistant from both.
217pub fn hyperbolic_midpoint(u: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
218    // log_u(v) gives direction and distance to v from u
219    let log_v = log_map_at(u, v, curvature);
220
221    // Scale by 0.5 to get halfway
222    let half_log: Vec<f32> = log_v.iter().map(|&x| x * 0.5).collect();
223
224    // Map back to the ball
225    exp_map_at(u, &half_log, curvature)
226}
227
228/// Compute the hyperbolic centroid of multiple points.
229///
230/// This is the Einstein (Frechet) mean in hyperbolic space.
231pub fn hyperbolic_centroid(points: &[&[f32]], curvature: f32) -> Option<Vec<f32>> {
232    if points.is_empty() {
233        return None;
234    }
235
236    let dim = points[0].len();
237
238    // Start with the Euclidean centroid projected onto the ball
239    let mut centroid = vec![0.0; dim];
240    for point in points {
241        for (i, &v) in point.iter().enumerate() {
242            centroid[i] += v;
243        }
244    }
245    for x in centroid.iter_mut() {
246        *x /= points.len() as f32;
247    }
248    project_to_ball(&mut centroid);
249
250    // Iteratively refine using gradient descent
251    // (simplified version - could use Riemannian gradient descent for better accuracy)
252    for _ in 0..10 {
253        let mut grad = vec![0.0; dim];
254
255        for point in points {
256            let log_p = log_map_at(&centroid, point, curvature);
257            for (i, &v) in log_p.iter().enumerate() {
258                grad[i] += v;
259            }
260        }
261
262        // Average gradient
263        for x in grad.iter_mut() {
264            *x /= points.len() as f32;
265        }
266
267        // Update centroid
268        centroid = exp_map_at(&centroid, &grad, curvature);
269    }
270
271    Some(centroid)
272}
273
274/// Convert a Euclidean embedding to a Poincare ball embedding.
275///
276/// Uses the exponential map at the origin.
277pub fn euclidean_to_poincare(euclidean: &[f32], curvature: f32) -> Vec<f32> {
278    exp_map(euclidean, curvature)
279}
280
281/// Convert a Poincare ball embedding to Euclidean space.
282///
283/// Uses the logarithmic map at the origin.
284pub fn poincare_to_euclidean(poincare: &[f32], curvature: f32) -> Vec<f32> {
285    log_map(poincare, curvature)
286}
287
288/// Project a point into the Poincare ball if it lies outside.
289fn project_to_ball(v: &mut [f32]) {
290    let norm = l2_norm(v);
291    if norm >= MAX_NORM {
292        let scale = MAX_NORM / norm;
293        for x in v.iter_mut() {
294            *x *= scale;
295        }
296    }
297}
298
/// Squared Euclidean length `||v||^2`.
#[inline]
fn squared_norm(v: &[f32]) -> f32 {
    v.iter().map(|&component| component * component).sum()
}

/// Euclidean length `||v||`.
#[inline]
fn l2_norm(v: &[f32]) -> f32 {
    f32::sqrt(squared_norm(v))
}

/// Squared Euclidean distance `||u - v||^2`.
///
/// Components are paired with `zip`, so trailing elements of the longer slice
/// are ignored.
#[inline]
fn squared_distance(u: &[f32], v: &[f32]) -> f32 {
    u.iter()
        .zip(v)
        .map(|(&a, &b)| {
            let delta = a - b;
            delta * delta
        })
        .sum()
}

/// Euclidean inner product `<u, v>`; trailing elements of the longer slice are ignored.
#[inline]
fn dot_product(u: &[f32], v: &[f32]) -> f32 {
    u.iter().zip(v).map(|(&a, &b)| a * b).sum()
}

329/// Conformal factor at a point (metric scaling).
330///
331/// The conformal factor lambda(x) = 2 / (1 - ||x||^2)
332/// determines how much distances are scaled at point x.
333pub fn conformal_factor(x: &[f32]) -> f32 {
334    let norm_sq = squared_norm(x).min(MAX_NORM * MAX_NORM);
335    2.0 / (1.0 - norm_sq)
336}
337
338/// Check if a point is inside the Poincare ball.
339pub fn is_in_ball(x: &[f32]) -> bool {
340    squared_norm(x) < 1.0
341}
342
343/// Compute hyperbolic angle between vectors in tangent space.
344pub fn hyperbolic_angle(u: &[f32], v: &[f32]) -> f32 {
345    let norm_u = l2_norm(u);
346    let norm_v = l2_norm(v);
347
348    if norm_u < EPS || norm_v < EPS {
349        return 0.0;
350    }
351
352    let cos_angle = dot_product(u, v) / (norm_u * norm_v);
353    cos_angle.clamp(-1.0, 1.0).acos()
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Assert that two slices agree component-wise within `tol`.
    fn assert_components_close(expected: &[f32], actual: &[f32], tol: f32) {
        for (e, a) in expected.iter().zip(actual) {
            assert_relative_eq!(e, a, epsilon = tol);
        }
    }

    #[test]
    fn test_poincare_distance_same_point() {
        let p = [0.1, 0.2, 0.3];
        assert_relative_eq!(poincare_distance(&p, &p, DEFAULT_CURVATURE), 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_poincare_distance_origin() {
        let dist = poincare_distance(&[0.0; 3], &[0.5, 0.0, 0.0], DEFAULT_CURVATURE);
        assert!(dist > 0.0);
    }

    #[test]
    fn test_exp_log_inverse() {
        let tangent = [0.5, 0.3, 0.1];
        let round_trip = log_map(&exp_map(&tangent, DEFAULT_CURVATURE), DEFAULT_CURVATURE);
        assert_components_close(&tangent, &round_trip, 1e-4);
    }

    #[test]
    fn test_mobius_add_zero() {
        let u = [0.1, 0.2, 0.3];
        let sum = mobius_add(&u, &[0.0; 3], DEFAULT_CURVATURE);
        assert_components_close(&u, &sum, 1e-5);
    }

    #[test]
    fn test_mobius_add_stays_in_ball() {
        let sum = mobius_add(&[0.8, 0.0, 0.0], &[0.0, 0.8, 0.0], DEFAULT_CURVATURE);
        assert!(l2_norm(&sum) < 1.0);
    }

    #[test]
    fn test_hyperbolic_midpoint() {
        let (u, v) = ([0.1, 0.0, 0.0], [0.5, 0.0, 0.0]);
        let mid = hyperbolic_midpoint(&u, &v, DEFAULT_CURVATURE);

        // The midpoint sits strictly between the endpoints along the x axis...
        assert!(u[0] < mid[0] && mid[0] < v[0]);

        // ...and is (approximately) equidistant from both.
        let to_u = poincare_distance(&u, &mid, DEFAULT_CURVATURE);
        let to_v = poincare_distance(&v, &mid, DEFAULT_CURVATURE);
        assert_relative_eq!(to_u, to_v, epsilon = 1e-3);
    }

    #[test]
    fn test_euclidean_poincare_conversion() {
        let euclidean = [0.3, 0.2, 0.1];

        let lifted = euclidean_to_poincare(&euclidean, DEFAULT_CURVATURE);
        assert!(is_in_ball(&lifted));

        let recovered = poincare_to_euclidean(&lifted, DEFAULT_CURVATURE);
        assert_components_close(&euclidean, &recovered, 1e-4);
    }

    #[test]
    fn test_conformal_factor() {
        assert_relative_eq!(conformal_factor(&[0.0; 3]), 2.0, epsilon = 1e-5);

        // The factor blows up as the boundary is approached.
        assert!(conformal_factor(&[0.99, 0.0, 0.0]) > 10.0);
    }

    #[test]
    fn test_is_in_ball() {
        let inside: [[f32; 3]; 2] = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0]];
        // The second point has norm > 1.
        let outside: [[f32; 3]; 2] = [[1.0, 0.0, 0.0], [0.6, 0.6, 0.6]];

        assert!(inside.iter().all(|p| is_in_ball(p)));
        assert!(outside.iter().all(|p| !is_in_ball(p)));
    }
}