ruvector_math/product_manifold/
manifold.rs1use crate::error::{MathError, Result};
4use crate::spherical::SphericalSpace;
5use crate::utils::{dot, norm, EPS};
6use super::config::ProductManifoldConfig;
7
8#[derive(Debug, Clone)]
10pub struct ProductManifold {
11 config: ProductManifoldConfig,
12 spherical: Option<SphericalSpace>,
13}
14
15impl ProductManifold {
16 pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
23 let config = ProductManifoldConfig::new(euclidean_dim, hyperbolic_dim, spherical_dim);
24 let spherical = if spherical_dim > 0 {
25 Some(SphericalSpace::new(spherical_dim))
26 } else {
27 None
28 };
29
30 Self { config, spherical }
31 }
32
33 pub fn from_config(config: ProductManifoldConfig) -> Self {
35 let spherical = if config.spherical_dim > 0 {
36 Some(SphericalSpace::new(config.spherical_dim))
37 } else {
38 None
39 };
40
41 Self { config, spherical }
42 }
43
44 pub fn config(&self) -> &ProductManifoldConfig {
46 &self.config
47 }
48
49 pub fn dim(&self) -> usize {
51 self.config.total_dim()
52 }
53
54 pub fn euclidean_component<'a>(&self, point: &'a [f64]) -> &'a [f64] {
56 let (e_range, _, _) = self.config.component_ranges();
57 &point[e_range]
58 }
59
60 pub fn hyperbolic_component<'a>(&self, point: &'a [f64]) -> &'a [f64] {
62 let (_, h_range, _) = self.config.component_ranges();
63 &point[h_range]
64 }
65
66 pub fn spherical_component<'a>(&self, point: &'a [f64]) -> &'a [f64] {
68 let (_, _, s_range) = self.config.component_ranges();
69 &point[s_range]
70 }
71
72 pub fn project(&self, point: &[f64]) -> Result<Vec<f64>> {
78 if point.len() != self.dim() {
79 return Err(MathError::dimension_mismatch(self.dim(), point.len()));
80 }
81
82 let mut result = point.to_vec();
83 let (_e_range, h_range, s_range) = self.config.component_ranges();
84
85 if !h_range.is_empty() {
88 let h_part = &mut result[h_range.clone()];
89 let h_norm: f64 = h_part.iter().map(|&x| x * x).sum::<f64>().sqrt();
90
91 if h_norm >= 1.0 - EPS {
92 let scale = (1.0 - EPS) / h_norm;
93 for x in h_part.iter_mut() {
94 *x *= scale;
95 }
96 }
97 }
98
99 if !s_range.is_empty() {
101 let s_part = &mut result[s_range.clone()];
102 let s_norm: f64 = s_part.iter().map(|&x| x * x).sum::<f64>().sqrt();
103
104 if s_norm > EPS {
105 for x in s_part.iter_mut() {
106 *x /= s_norm;
107 }
108 } else {
109 s_part[0] = 1.0;
111 for x in s_part[1..].iter_mut() {
112 *x = 0.0;
113 }
114 }
115 }
116
117 Ok(result)
118 }
119
120 #[inline]
124 pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
125 if x.len() != self.dim() || y.len() != self.dim() {
126 return Err(MathError::dimension_mismatch(self.dim(), x.len()));
127 }
128
129 let (w_e, w_h, w_s) = self.config.component_weights;
130 let (e_range, h_range, s_range) = self.config.component_ranges();
131
132 let mut dist_sq = 0.0;
133
134 if !e_range.is_empty() && w_e > 0.0 {
136 let d_e = self.euclidean_distance_sq(&x[e_range.clone()], &y[e_range.clone()]);
137 dist_sq += w_e * d_e;
138 }
139
140 if !h_range.is_empty() && w_h > 0.0 {
142 let x_h = &x[h_range.clone()];
143 let y_h = &y[h_range.clone()];
144 let d_h = self.poincare_distance(x_h, y_h)?;
145 dist_sq += w_h * d_h * d_h;
146 }
147
148 if !s_range.is_empty() && w_s > 0.0 {
150 let x_s = &x[s_range.clone()];
151 let y_s = &y[s_range.clone()];
152 let d_s = self.spherical_distance(x_s, y_s)?;
153 dist_sq += w_s * d_s * d_s;
154 }
155
156 Ok(dist_sq.sqrt())
157 }
158
159 #[inline(always)]
161 fn euclidean_distance_sq(&self, x: &[f64], y: &[f64]) -> f64 {
162 let len = x.len();
163 let chunks = len / 4;
164 let remainder = len % 4;
165
166 let mut sum0 = 0.0f64;
167 let mut sum1 = 0.0f64;
168 let mut sum2 = 0.0f64;
169 let mut sum3 = 0.0f64;
170
171 for i in 0..chunks {
173 let base = i * 4;
174 let d0 = x[base] - y[base];
175 let d1 = x[base + 1] - y[base + 1];
176 let d2 = x[base + 2] - y[base + 2];
177 let d3 = x[base + 3] - y[base + 3];
178 sum0 += d0 * d0;
179 sum1 += d1 * d1;
180 sum2 += d2 * d2;
181 sum3 += d3 * d3;
182 }
183
184 let base = chunks * 4;
186 for i in 0..remainder {
187 let d = x[base + i] - y[base + i];
188 sum0 += d * d;
189 }
190
191 sum0 + sum1 + sum2 + sum3
192 }
193
194 #[inline]
200 fn poincare_distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
201 let len = x.len();
202 let chunks = len / 4;
203 let remainder = len % 4;
204
205 let mut x_norm_sq = 0.0f64;
207 let mut y_norm_sq = 0.0f64;
208 let mut diff_sq = 0.0f64;
209
210 for i in 0..chunks {
212 let base = i * 4;
213
214 let x0 = x[base];
215 let x1 = x[base + 1];
216 let x2 = x[base + 2];
217 let x3 = x[base + 3];
218
219 let y0 = y[base];
220 let y1 = y[base + 1];
221 let y2 = y[base + 2];
222 let y3 = y[base + 3];
223
224 x_norm_sq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
225 y_norm_sq += y0 * y0 + y1 * y1 + y2 * y2 + y3 * y3;
226
227 let d0 = x0 - y0;
228 let d1 = x1 - y1;
229 let d2 = x2 - y2;
230 let d3 = x3 - y3;
231 diff_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
232 }
233
234 let base = chunks * 4;
236 for i in 0..remainder {
237 let xi = x[base + i];
238 let yi = y[base + i];
239 x_norm_sq += xi * xi;
240 y_norm_sq += yi * yi;
241 let d = xi - yi;
242 diff_sq += d * d;
243 }
244
245 let denom = (1.0 - x_norm_sq).max(EPS) * (1.0 - y_norm_sq).max(EPS);
246 let arg = 1.0 + 2.0 * diff_sq / denom;
247
248 let c = (-self.config.hyperbolic_curvature).sqrt();
250 Ok(arg.max(1.0).acosh() / c)
251 }
252
253 fn spherical_distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
255 let cos_angle = dot(x, y).clamp(-1.0, 1.0);
256 let c = self.config.spherical_curvature.sqrt();
257 Ok(cos_angle.acos() / c)
258 }
259
260 pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
262 if x.len() != self.dim() || v.len() != self.dim() {
263 return Err(MathError::dimension_mismatch(self.dim(), x.len()));
264 }
265
266 let mut result = vec![0.0; self.dim()];
267 let (e_range, h_range, s_range) = self.config.component_ranges();
268
269 for i in e_range.clone() {
271 result[i] = x[i] + v[i];
272 }
273
274 if !h_range.is_empty() {
276 let x_h = &x[h_range.clone()];
277 let v_h = &v[h_range.clone()];
278 let exp_h = self.poincare_exp_map(x_h, v_h)?;
279 for (i, val) in h_range.clone().zip(exp_h.iter()) {
280 result[i] = *val;
281 }
282 }
283
284 if !s_range.is_empty() {
286 let x_s = &x[s_range.clone()];
287 let v_s = &v[s_range.clone()];
288 let exp_s = self.spherical_exp_map(x_s, v_s)?;
289 for (i, val) in s_range.clone().zip(exp_s.iter()) {
290 result[i] = *val;
291 }
292 }
293
294 self.project(&result)
295 }
296
297 fn poincare_exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
299 let c = -self.config.hyperbolic_curvature;
300 let sqrt_c = c.sqrt();
301
302 let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
303 let v_norm: f64 = v.iter().map(|&vi| vi * vi).sum::<f64>().sqrt();
304
305 if v_norm < EPS {
306 return Ok(x.to_vec());
307 }
308
309 let lambda_x = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
310 let norm_v = lambda_x * v_norm;
311
312 let t = (sqrt_c * norm_v).tanh() / (sqrt_c * v_norm);
313
314 let tv: Vec<f64> = v.iter().map(|&vi| t * vi).collect();
316 self.mobius_add(x, &tv, c)
317 }
318
319 fn mobius_add(&self, x: &[f64], y: &[f64], c: f64) -> Result<Vec<f64>> {
321 let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
322 let y_norm_sq: f64 = y.iter().map(|&yi| yi * yi).sum();
323 let xy_dot: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
324
325 let num_coef = 1.0 + 2.0 * c * xy_dot + c * y_norm_sq;
326 let denom = 1.0 + 2.0 * c * xy_dot + c * c * x_norm_sq * y_norm_sq;
327
328 if denom.abs() < EPS {
329 return Ok(x.to_vec());
330 }
331
332 let y_coef = 1.0 - c * x_norm_sq;
333
334 let result: Vec<f64> = x
335 .iter()
336 .zip(y.iter())
337 .map(|(&xi, &yi)| (num_coef * xi + y_coef * yi) / denom)
338 .collect();
339
340 Ok(result)
341 }
342
343 fn spherical_exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
345 let v_norm = norm(v);
346
347 if v_norm < EPS {
348 return Ok(x.to_vec());
349 }
350
351 let cos_t = v_norm.cos();
352 let sin_t = v_norm.sin();
353
354 let result: Vec<f64> = x
355 .iter()
356 .zip(v.iter())
357 .map(|(&xi, &vi)| cos_t * xi + sin_t * vi / v_norm)
358 .collect();
359
360 let n = norm(&result);
362 if n > EPS {
363 Ok(result.iter().map(|&r| r / n).collect())
364 } else {
365 Ok(x.to_vec())
366 }
367 }
368
369 pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
371 if x.len() != self.dim() || y.len() != self.dim() {
372 return Err(MathError::dimension_mismatch(self.dim(), x.len()));
373 }
374
375 let mut result = vec![0.0; self.dim()];
376 let (e_range, h_range, s_range) = self.config.component_ranges();
377
378 for i in e_range.clone() {
380 result[i] = y[i] - x[i];
381 }
382
383 if !h_range.is_empty() {
385 let x_h = &x[h_range.clone()];
386 let y_h = &y[h_range.clone()];
387 let log_h = self.poincare_log_map(x_h, y_h)?;
388 for (i, val) in h_range.clone().zip(log_h.iter()) {
389 result[i] = *val;
390 }
391 }
392
393 if !s_range.is_empty() {
395 let x_s = &x[s_range.clone()];
396 let y_s = &y[s_range.clone()];
397 let log_s = self.spherical_log_map(x_s, y_s)?;
398 for (i, val) in s_range.clone().zip(log_s.iter()) {
399 result[i] = *val;
400 }
401 }
402
403 Ok(result)
404 }
405
406 fn poincare_log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
408 let c = -self.config.hyperbolic_curvature;
409
410 let neg_x: Vec<f64> = x.iter().map(|&xi| -xi).collect();
412 let diff = self.mobius_add(&neg_x, y, c)?;
413
414 let diff_norm: f64 = diff.iter().map(|&d| d * d).sum::<f64>().sqrt();
415
416 if diff_norm < EPS {
417 return Ok(vec![0.0; x.len()]);
418 }
419
420 let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
421 let lambda_x = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
422
423 let sqrt_c = c.sqrt();
424 let arctanh_arg = (sqrt_c * diff_norm).min(1.0 - EPS);
425 let scale = (2.0 / (lambda_x * sqrt_c)) * arctanh_arg.atanh() / diff_norm;
426
427 Ok(diff.iter().map(|&d| scale * d).collect())
428 }
429
430 fn spherical_log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
432 let cos_theta = dot(x, y).clamp(-1.0, 1.0);
433 let theta = cos_theta.acos();
434
435 if theta < EPS {
436 return Ok(vec![0.0; x.len()]);
437 }
438
439 if (theta - std::f64::consts::PI).abs() < EPS {
440 return Err(MathError::numerical_instability("Antipodal points"));
441 }
442
443 let scale = theta / theta.sin();
444
445 Ok(x
446 .iter()
447 .zip(y.iter())
448 .map(|(&xi, &yi)| scale * (yi - cos_theta * xi))
449 .collect())
450 }
451
452 pub fn frechet_mean(&self, points: &[Vec<f64>], weights: Option<&[f64]>) -> Result<Vec<f64>> {
454 if points.is_empty() {
455 return Err(MathError::empty_input("points"));
456 }
457
458 let n = points.len();
459 let uniform = 1.0 / n as f64;
460 let weights: Vec<f64> = match weights {
461 Some(w) => {
462 let sum: f64 = w.iter().sum();
463 w.iter().map(|&wi| wi / sum).collect()
464 }
465 None => vec![uniform; n],
466 };
467
468 let mut mean = vec![0.0; self.dim()];
470 for (p, &w) in points.iter().zip(weights.iter()) {
471 for (mi, &pi) in mean.iter_mut().zip(p.iter()) {
472 *mi += w * pi;
473 }
474 }
475 mean = self.project(&mean)?;
476
477 for _ in 0..100 {
479 let mut gradient = vec![0.0; self.dim()];
480
481 for (p, &w) in points.iter().zip(weights.iter()) {
482 if let Ok(log_v) = self.log_map(&mean, p) {
483 for (gi, &li) in gradient.iter_mut().zip(log_v.iter()) {
484 *gi += w * li;
485 }
486 }
487 }
488
489 let grad_norm = norm(&gradient);
490 if grad_norm < 1e-8 {
491 break;
492 }
493
494 mean = self.exp_map(&mean, &gradient)?;
496 }
497
498 Ok(mean)
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of a slice, used to check projection invariants.
    fn l2(v: &[f64]) -> f64 {
        v.iter().map(|&a| a * a).sum::<f64>().sqrt()
    }

    #[test]
    fn test_product_manifold_creation() {
        let m = ProductManifold::new(32, 16, 8);
        let cfg = m.config();

        assert_eq!(m.dim(), 56);
        assert_eq!(cfg.euclidean_dim, 32);
        assert_eq!(cfg.hyperbolic_dim, 16);
        assert_eq!(cfg.spherical_dim, 8);
    }

    #[test]
    fn test_projection() {
        let m = ProductManifold::new(2, 2, 3);
        let raw = [1.0, 2.0, 2.0, 0.0, 3.0, 4.0, 0.0];

        let p = m.project(&raw).unwrap();

        assert!(l2(m.hyperbolic_component(&p)) < 1.0);
        assert!((l2(m.spherical_component(&p)) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_only_distance() {
        let m = ProductManifold::new(3, 0, 0);

        let d = m.distance(&[0.0, 0.0, 0.0], &[3.0, 4.0, 0.0]).unwrap();

        assert!((d - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_product_distance() {
        let m = ProductManifold::new(2, 2, 3);

        let a = m.project(&[0.0, 0.0, 0.1, 0.0, 1.0, 0.0, 0.0]).unwrap();
        let b = m.project(&[1.0, 1.0, 0.0, 0.1, 0.0, 1.0, 0.0]).unwrap();

        assert!(m.distance(&a, &b).unwrap() > 0.0);
    }

    #[test]
    fn test_exp_log_inverse() {
        let m = ProductManifold::new(2, 0, 0);
        let start = [1.0, 2.0];
        let target = [3.0, 4.0];

        let tangent = m.log_map(&start, &target).unwrap();
        let recovered = m.exp_map(&start, &tangent).unwrap();

        for (want, got) in target.iter().zip(&recovered) {
            assert!((want - got).abs() < 1e-6);
        }
    }
}
576}