ruvector_math/product_manifold/manifold.rs

//! Product manifold implementation: `E^e × H^h × S^s` with Euclidean, Poincaré-ball
//! hyperbolic, and spherical components.
2
3use crate::error::{MathError, Result};
4use crate::spherical::SphericalSpace;
5use crate::utils::{dot, norm, EPS};
6use super::config::ProductManifoldConfig;
7
/// Product manifold: M = E^e × H^h × S^s
///
/// A point is a single flat `f64` slice of length `config.total_dim()`. It is split
/// into its Euclidean, hyperbolic (Poincaré ball) and spherical parts by the index
/// ranges returned from `ProductManifoldConfig::component_ranges`.
#[derive(Debug, Clone)]
pub struct ProductManifold {
    /// Component dimensions, curvatures and distance weights.
    config: ProductManifoldConfig,
    /// Spherical-space helper; `Some` only when `spherical_dim > 0`.
    // NOTE(review): no method in this file reads this field — confirm it is still needed.
    spherical: Option<SphericalSpace>,
}
14
15impl ProductManifold {
16    /// Create a new product manifold
17    ///
18    /// # Arguments
19    /// * `euclidean_dim` - Dimension of Euclidean component
20    /// * `hyperbolic_dim` - Dimension of hyperbolic component (Poincaré ball)
21    /// * `spherical_dim` - Dimension of spherical component
22    pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
23        let config = ProductManifoldConfig::new(euclidean_dim, hyperbolic_dim, spherical_dim);
24        let spherical = if spherical_dim > 0 {
25            Some(SphericalSpace::new(spherical_dim))
26        } else {
27            None
28        };
29
30        Self { config, spherical }
31    }
32
33    /// Create from configuration
34    pub fn from_config(config: ProductManifoldConfig) -> Self {
35        let spherical = if config.spherical_dim > 0 {
36            Some(SphericalSpace::new(config.spherical_dim))
37        } else {
38            None
39        };
40
41        Self { config, spherical }
42    }
43
44    /// Get configuration
45    pub fn config(&self) -> &ProductManifoldConfig {
46        &self.config
47    }
48
49    /// Total dimension
50    pub fn dim(&self) -> usize {
51        self.config.total_dim()
52    }
53
54    /// Extract Euclidean component from point
55    pub fn euclidean_component<'a>(&self, point: &'a [f64]) -> &'a [f64] {
56        let (e_range, _, _) = self.config.component_ranges();
57        &point[e_range]
58    }
59
60    /// Extract hyperbolic component from point
61    pub fn hyperbolic_component<'a>(&self, point: &'a [f64]) -> &'a [f64] {
62        let (_, h_range, _) = self.config.component_ranges();
63        &point[h_range]
64    }
65
66    /// Extract spherical component from point
67    pub fn spherical_component<'a>(&self, point: &'a [f64]) -> &'a [f64] {
68        let (_, _, s_range) = self.config.component_ranges();
69        &point[s_range]
70    }
71
72    /// Project point onto the product manifold
73    ///
74    /// - Euclidean: no projection needed
75    /// - Hyperbolic: project into Poincaré ball
76    /// - Spherical: normalize to unit sphere
77    pub fn project(&self, point: &[f64]) -> Result<Vec<f64>> {
78        if point.len() != self.dim() {
79            return Err(MathError::dimension_mismatch(self.dim(), point.len()));
80        }
81
82        let mut result = point.to_vec();
83        let (_e_range, h_range, s_range) = self.config.component_ranges();
84
85        // Euclidean: no projection needed (kept as-is)
86        // Hyperbolic: project to Poincaré ball (||x|| < 1)
87        if !h_range.is_empty() {
88            let h_part = &mut result[h_range.clone()];
89            let h_norm: f64 = h_part.iter().map(|&x| x * x).sum::<f64>().sqrt();
90
91            if h_norm >= 1.0 - EPS {
92                let scale = (1.0 - EPS) / h_norm;
93                for x in h_part.iter_mut() {
94                    *x *= scale;
95                }
96            }
97        }
98
99        // Spherical: normalize to unit sphere
100        if !s_range.is_empty() {
101            let s_part = &mut result[s_range.clone()];
102            let s_norm: f64 = s_part.iter().map(|&x| x * x).sum::<f64>().sqrt();
103
104            if s_norm > EPS {
105                for x in s_part.iter_mut() {
106                    *x /= s_norm;
107                }
108            } else {
109                // Set to north pole
110                s_part[0] = 1.0;
111                for x in s_part[1..].iter_mut() {
112                    *x = 0.0;
113                }
114            }
115        }
116
117        Ok(result)
118    }
119
120    /// Compute distance in product manifold
121    ///
122    /// d(x, y)² = w_e d_E(x_e, y_e)² + w_h d_H(x_h, y_h)² + w_s d_S(x_s, y_s)²
123    #[inline]
124    pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
125        if x.len() != self.dim() || y.len() != self.dim() {
126            return Err(MathError::dimension_mismatch(self.dim(), x.len()));
127        }
128
129        let (w_e, w_h, w_s) = self.config.component_weights;
130        let (e_range, h_range, s_range) = self.config.component_ranges();
131
132        let mut dist_sq = 0.0;
133
134        // Euclidean distance with SIMD-friendly accumulation
135        if !e_range.is_empty() && w_e > 0.0 {
136            let d_e = self.euclidean_distance_sq(&x[e_range.clone()], &y[e_range.clone()]);
137            dist_sq += w_e * d_e;
138        }
139
140        // Hyperbolic (Poincaré) distance
141        if !h_range.is_empty() && w_h > 0.0 {
142            let x_h = &x[h_range.clone()];
143            let y_h = &y[h_range.clone()];
144            let d_h = self.poincare_distance(x_h, y_h)?;
145            dist_sq += w_h * d_h * d_h;
146        }
147
148        // Spherical distance
149        if !s_range.is_empty() && w_s > 0.0 {
150            let x_s = &x[s_range.clone()];
151            let y_s = &y[s_range.clone()];
152            let d_s = self.spherical_distance(x_s, y_s)?;
153            dist_sq += w_s * d_s * d_s;
154        }
155
156        Ok(dist_sq.sqrt())
157    }
158
159    /// SIMD-friendly squared Euclidean distance using 4-way unrolled accumulator
160    #[inline(always)]
161    fn euclidean_distance_sq(&self, x: &[f64], y: &[f64]) -> f64 {
162        let len = x.len();
163        let chunks = len / 4;
164        let remainder = len % 4;
165
166        let mut sum0 = 0.0f64;
167        let mut sum1 = 0.0f64;
168        let mut sum2 = 0.0f64;
169        let mut sum3 = 0.0f64;
170
171        // Process 4 elements at a time for SIMD vectorization
172        for i in 0..chunks {
173            let base = i * 4;
174            let d0 = x[base] - y[base];
175            let d1 = x[base + 1] - y[base + 1];
176            let d2 = x[base + 2] - y[base + 2];
177            let d3 = x[base + 3] - y[base + 3];
178            sum0 += d0 * d0;
179            sum1 += d1 * d1;
180            sum2 += d2 * d2;
181            sum3 += d3 * d3;
182        }
183
184        // Handle remainder
185        let base = chunks * 4;
186        for i in 0..remainder {
187            let d = x[base + i] - y[base + i];
188            sum0 += d * d;
189        }
190
191        sum0 + sum1 + sum2 + sum3
192    }
193
194    /// Poincaré ball distance
195    ///
196    /// d(x, y) = arcosh(1 + 2 ||x - y||² / ((1 - ||x||²)(1 - ||y||²)))
197    ///
198    /// Optimized with SIMD-friendly 4-way accumulator for computing norms
199    #[inline]
200    fn poincare_distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
201        let len = x.len();
202        let chunks = len / 4;
203        let remainder = len % 4;
204
205        // Compute all three values in one pass for better cache utilization
206        let mut x_norm_sq = 0.0f64;
207        let mut y_norm_sq = 0.0f64;
208        let mut diff_sq = 0.0f64;
209
210        // 4-way unrolled for SIMD
211        for i in 0..chunks {
212            let base = i * 4;
213
214            let x0 = x[base];
215            let x1 = x[base + 1];
216            let x2 = x[base + 2];
217            let x3 = x[base + 3];
218
219            let y0 = y[base];
220            let y1 = y[base + 1];
221            let y2 = y[base + 2];
222            let y3 = y[base + 3];
223
224            x_norm_sq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
225            y_norm_sq += y0 * y0 + y1 * y1 + y2 * y2 + y3 * y3;
226
227            let d0 = x0 - y0;
228            let d1 = x1 - y1;
229            let d2 = x2 - y2;
230            let d3 = x3 - y3;
231            diff_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
232        }
233
234        // Handle remainder
235        let base = chunks * 4;
236        for i in 0..remainder {
237            let xi = x[base + i];
238            let yi = y[base + i];
239            x_norm_sq += xi * xi;
240            y_norm_sq += yi * yi;
241            let d = xi - yi;
242            diff_sq += d * d;
243        }
244
245        let denom = (1.0 - x_norm_sq).max(EPS) * (1.0 - y_norm_sq).max(EPS);
246        let arg = 1.0 + 2.0 * diff_sq / denom;
247
248        // Apply curvature scaling
249        let c = (-self.config.hyperbolic_curvature).sqrt();
250        Ok(arg.max(1.0).acosh() / c)
251    }
252
253    /// Spherical distance (geodesic)
254    fn spherical_distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
255        let cos_angle = dot(x, y).clamp(-1.0, 1.0);
256        let c = self.config.spherical_curvature.sqrt();
257        Ok(cos_angle.acos() / c)
258    }
259
260    /// Exponential map at point x with tangent vector v
261    pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
262        if x.len() != self.dim() || v.len() != self.dim() {
263            return Err(MathError::dimension_mismatch(self.dim(), x.len()));
264        }
265
266        let mut result = vec![0.0; self.dim()];
267        let (e_range, h_range, s_range) = self.config.component_ranges();
268
269        // Euclidean: exp_x(v) = x + v
270        for i in e_range.clone() {
271            result[i] = x[i] + v[i];
272        }
273
274        // Hyperbolic (Poincaré) exp map
275        if !h_range.is_empty() {
276            let x_h = &x[h_range.clone()];
277            let v_h = &v[h_range.clone()];
278            let exp_h = self.poincare_exp_map(x_h, v_h)?;
279            for (i, val) in h_range.clone().zip(exp_h.iter()) {
280                result[i] = *val;
281            }
282        }
283
284        // Spherical exp map
285        if !s_range.is_empty() {
286            let x_s = &x[s_range.clone()];
287            let v_s = &v[s_range.clone()];
288            let exp_s = self.spherical_exp_map(x_s, v_s)?;
289            for (i, val) in s_range.clone().zip(exp_s.iter()) {
290                result[i] = *val;
291            }
292        }
293
294        self.project(&result)
295    }
296
297    /// Poincaré ball exponential map
298    fn poincare_exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
299        let c = -self.config.hyperbolic_curvature;
300        let sqrt_c = c.sqrt();
301
302        let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
303        let v_norm: f64 = v.iter().map(|&vi| vi * vi).sum::<f64>().sqrt();
304
305        if v_norm < EPS {
306            return Ok(x.to_vec());
307        }
308
309        let lambda_x = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
310        let norm_v = lambda_x * v_norm;
311
312        let t = (sqrt_c * norm_v).tanh() / (sqrt_c * v_norm);
313
314        // Möbius addition: x ⊕_c (t * v)
315        let tv: Vec<f64> = v.iter().map(|&vi| t * vi).collect();
316        self.mobius_add(x, &tv, c)
317    }
318
319    /// Möbius addition in Poincaré ball
320    fn mobius_add(&self, x: &[f64], y: &[f64], c: f64) -> Result<Vec<f64>> {
321        let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
322        let y_norm_sq: f64 = y.iter().map(|&yi| yi * yi).sum();
323        let xy_dot: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
324
325        let num_coef = 1.0 + 2.0 * c * xy_dot + c * y_norm_sq;
326        let denom = 1.0 + 2.0 * c * xy_dot + c * c * x_norm_sq * y_norm_sq;
327
328        if denom.abs() < EPS {
329            return Ok(x.to_vec());
330        }
331
332        let y_coef = 1.0 - c * x_norm_sq;
333
334        let result: Vec<f64> = x
335            .iter()
336            .zip(y.iter())
337            .map(|(&xi, &yi)| (num_coef * xi + y_coef * yi) / denom)
338            .collect();
339
340        Ok(result)
341    }
342
343    /// Spherical exponential map
344    fn spherical_exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
345        let v_norm = norm(v);
346
347        if v_norm < EPS {
348            return Ok(x.to_vec());
349        }
350
351        let cos_t = v_norm.cos();
352        let sin_t = v_norm.sin();
353
354        let result: Vec<f64> = x
355            .iter()
356            .zip(v.iter())
357            .map(|(&xi, &vi)| cos_t * xi + sin_t * vi / v_norm)
358            .collect();
359
360        // Normalize to sphere
361        let n = norm(&result);
362        if n > EPS {
363            Ok(result.iter().map(|&r| r / n).collect())
364        } else {
365            Ok(x.to_vec())
366        }
367    }
368
369    /// Logarithmic map at point x toward point y
370    pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
371        if x.len() != self.dim() || y.len() != self.dim() {
372            return Err(MathError::dimension_mismatch(self.dim(), x.len()));
373        }
374
375        let mut result = vec![0.0; self.dim()];
376        let (e_range, h_range, s_range) = self.config.component_ranges();
377
378        // Euclidean: log_x(y) = y - x
379        for i in e_range.clone() {
380            result[i] = y[i] - x[i];
381        }
382
383        // Hyperbolic log map
384        if !h_range.is_empty() {
385            let x_h = &x[h_range.clone()];
386            let y_h = &y[h_range.clone()];
387            let log_h = self.poincare_log_map(x_h, y_h)?;
388            for (i, val) in h_range.clone().zip(log_h.iter()) {
389                result[i] = *val;
390            }
391        }
392
393        // Spherical log map
394        if !s_range.is_empty() {
395            let x_s = &x[s_range.clone()];
396            let y_s = &y[s_range.clone()];
397            let log_s = self.spherical_log_map(x_s, y_s)?;
398            for (i, val) in s_range.clone().zip(log_s.iter()) {
399                result[i] = *val;
400            }
401        }
402
403        Ok(result)
404    }
405
406    /// Poincaré ball logarithmic map
407    fn poincare_log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
408        let c = -self.config.hyperbolic_curvature;
409
410        // -x ⊕_c y
411        let neg_x: Vec<f64> = x.iter().map(|&xi| -xi).collect();
412        let diff = self.mobius_add(&neg_x, y, c)?;
413
414        let diff_norm: f64 = diff.iter().map(|&d| d * d).sum::<f64>().sqrt();
415
416        if diff_norm < EPS {
417            return Ok(vec![0.0; x.len()]);
418        }
419
420        let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
421        let lambda_x = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
422
423        let sqrt_c = c.sqrt();
424        let arctanh_arg = (sqrt_c * diff_norm).min(1.0 - EPS);
425        let scale = (2.0 / (lambda_x * sqrt_c)) * arctanh_arg.atanh() / diff_norm;
426
427        Ok(diff.iter().map(|&d| scale * d).collect())
428    }
429
430    /// Spherical logarithmic map
431    fn spherical_log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
432        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
433        let theta = cos_theta.acos();
434
435        if theta < EPS {
436            return Ok(vec![0.0; x.len()]);
437        }
438
439        if (theta - std::f64::consts::PI).abs() < EPS {
440            return Err(MathError::numerical_instability("Antipodal points"));
441        }
442
443        let scale = theta / theta.sin();
444
445        Ok(x
446            .iter()
447            .zip(y.iter())
448            .map(|(&xi, &yi)| scale * (yi - cos_theta * xi))
449            .collect())
450    }
451
452    /// Compute Fréchet mean on product manifold
453    pub fn frechet_mean(&self, points: &[Vec<f64>], weights: Option<&[f64]>) -> Result<Vec<f64>> {
454        if points.is_empty() {
455            return Err(MathError::empty_input("points"));
456        }
457
458        let n = points.len();
459        let uniform = 1.0 / n as f64;
460        let weights: Vec<f64> = match weights {
461            Some(w) => {
462                let sum: f64 = w.iter().sum();
463                w.iter().map(|&wi| wi / sum).collect()
464            }
465            None => vec![uniform; n],
466        };
467
468        // Initialize with weighted Euclidean mean
469        let mut mean = vec![0.0; self.dim()];
470        for (p, &w) in points.iter().zip(weights.iter()) {
471            for (mi, &pi) in mean.iter_mut().zip(p.iter()) {
472                *mi += w * pi;
473            }
474        }
475        mean = self.project(&mean)?;
476
477        // Iterative refinement
478        for _ in 0..100 {
479            let mut gradient = vec![0.0; self.dim()];
480
481            for (p, &w) in points.iter().zip(weights.iter()) {
482                if let Ok(log_v) = self.log_map(&mean, p) {
483                    for (gi, &li) in gradient.iter_mut().zip(log_v.iter()) {
484                        *gi += w * li;
485                    }
486                }
487            }
488
489            let grad_norm = norm(&gradient);
490            if grad_norm < 1e-8 {
491                break;
492            }
493
494            // Step along geodesic (learning rate = 1.0)
495            mean = self.exp_map(&mean, &gradient)?;
496        }
497
498        Ok(mean)
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) norm of a slice.
    fn l2(v: &[f64]) -> f64 {
        v.iter().map(|&a| a * a).sum::<f64>().sqrt()
    }

    #[test]
    fn test_product_manifold_creation() {
        let m = ProductManifold::new(32, 16, 8);
        let cfg = &m.config;

        assert_eq!(m.dim(), 56);
        assert_eq!(
            (cfg.euclidean_dim, cfg.hyperbolic_dim, cfg.spherical_dim),
            (32, 16, 8)
        );
    }

    #[test]
    fn test_projection() {
        let m = ProductManifold::new(2, 2, 3);

        // Euclidean [1, 2]; hyperbolic [2, 0] lies outside the ball;
        // spherical [3, 4, 0] has norm 5.
        let raw = [1.0, 2.0, 2.0, 0.0, 3.0, 4.0, 0.0];
        let p = m.project(&raw).unwrap();

        assert!(l2(m.hyperbolic_component(&p)) < 1.0);
        assert!((l2(m.spherical_component(&p)) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_only_distance() {
        let m = ProductManifold::new(3, 0, 0);

        let d = m.distance(&[0.0, 0.0, 0.0], &[3.0, 4.0, 0.0]).unwrap();

        assert!((d - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_product_distance() {
        let m = ProductManifold::new(2, 2, 3);

        let x = m.project(&[0.0, 0.0, 0.1, 0.0, 1.0, 0.0, 0.0]).unwrap();
        let y = m.project(&[1.0, 1.0, 0.0, 0.1, 0.0, 1.0, 0.0]).unwrap();

        assert!(m.distance(&x, &y).unwrap() > 0.0);
    }

    #[test]
    fn test_exp_log_inverse() {
        // Euclidean-only manifold keeps the round trip simple.
        let m = ProductManifold::new(2, 0, 0);
        let x = [1.0, 2.0];
        let y = [3.0, 4.0];

        let v = m.log_map(&x, &y).unwrap();
        let back = m.exp_map(&x, &v).unwrap();

        assert!(y.iter().zip(&back).all(|(a, b)| (a - b).abs() < 1e-6));
    }
}