ruvector_math/product_manifold/
config.rs

//! Configuration for product manifolds `E^e × H^h × S^s` (Euclidean, hyperbolic and spherical factors)
2
3use crate::error::{MathError, Result};
4
/// Kind of constant curvature carried by one factor of a product manifold.
///
/// Sign convention: flat space has zero curvature, hyperbolic space negative
/// curvature and spherical space positive curvature.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CurvatureType {
    /// Flat Euclidean factor; its curvature is identically zero.
    Euclidean,
    /// Hyperbolic factor (negative curvature).
    Hyperbolic {
        /// Curvature value, expected to be negative (conventionally -1).
        curvature: f64,
    },
    /// Spherical factor (positive curvature).
    Spherical {
        /// Curvature value, expected to be positive (conventionally 1).
        curvature: f64,
    },
}
21
22impl CurvatureType {
23    /// Create hyperbolic component with default curvature -1
24    pub fn hyperbolic() -> Self {
25        Self::Hyperbolic { curvature: -1.0 }
26    }
27
28    /// Create hyperbolic component with custom curvature
29    pub fn hyperbolic_with(curvature: f64) -> Self {
30        Self::Hyperbolic {
31            curvature: curvature.min(-1e-6),
32        }
33    }
34
35    /// Create spherical component with default curvature 1
36    pub fn spherical() -> Self {
37        Self::Spherical { curvature: 1.0 }
38    }
39
40    /// Create spherical component with custom curvature
41    pub fn spherical_with(curvature: f64) -> Self {
42        Self::Spherical {
43            curvature: curvature.max(1e-6),
44        }
45    }
46
47    /// Get curvature value
48    pub fn curvature(&self) -> f64 {
49        match self {
50            Self::Euclidean => 0.0,
51            Self::Hyperbolic { curvature } => *curvature,
52            Self::Spherical { curvature } => *curvature,
53        }
54    }
55}
56
/// Parameters of a product manifold `E^e × H^h × S^s`.
///
/// Each factor is described by its dimension (and, for the curved factors,
/// its curvature); per-factor distances are combined with `component_weights`.
#[derive(Clone, Debug)]
pub struct ProductManifoldConfig {
    /// Number of coordinates in the Euclidean factor.
    pub euclidean_dim: usize,
    /// Number of coordinates in the hyperbolic factor (ambient dimension of the Poincaré ball).
    pub hyperbolic_dim: usize,
    /// Curvature of the hyperbolic factor; expected to be negative.
    pub hyperbolic_curvature: f64,
    /// Number of coordinates in the spherical factor (ambient dimension).
    pub spherical_dim: usize,
    /// Curvature of the spherical factor; expected to be positive.
    pub spherical_curvature: f64,
    /// Distance weights in (Euclidean, hyperbolic, spherical) order.
    pub component_weights: (f64, f64, f64),
}
73
74impl ProductManifoldConfig {
75    /// Create a new product manifold configuration
76    ///
77    /// # Arguments
78    /// * `euclidean_dim` - Dimension of Euclidean component E^e
79    /// * `hyperbolic_dim` - Dimension of hyperbolic component H^h
80    /// * `spherical_dim` - Dimension of spherical component S^s
81    pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
82        Self {
83            euclidean_dim,
84            hyperbolic_dim,
85            hyperbolic_curvature: -1.0,
86            spherical_dim,
87            spherical_curvature: 1.0,
88            component_weights: (1.0, 1.0, 1.0),
89        }
90    }
91
92    /// Create Euclidean-only configuration
93    pub fn euclidean(dim: usize) -> Self {
94        Self::new(dim, 0, 0)
95    }
96
97    /// Create hyperbolic-only configuration
98    pub fn hyperbolic(dim: usize) -> Self {
99        Self::new(0, dim, 0)
100    }
101
102    /// Create spherical-only configuration
103    pub fn spherical(dim: usize) -> Self {
104        Self::new(0, 0, dim)
105    }
106
107    /// Create Euclidean Ă— Hyperbolic configuration
108    pub fn euclidean_hyperbolic(euclidean_dim: usize, hyperbolic_dim: usize) -> Self {
109        Self::new(euclidean_dim, hyperbolic_dim, 0)
110    }
111
112    /// Set hyperbolic curvature
113    pub fn with_hyperbolic_curvature(mut self, c: f64) -> Self {
114        self.hyperbolic_curvature = c.min(-1e-6);
115        self
116    }
117
118    /// Set spherical curvature
119    pub fn with_spherical_curvature(mut self, c: f64) -> Self {
120        self.spherical_curvature = c.max(1e-6);
121        self
122    }
123
124    /// Set component weights for distance computation
125    pub fn with_weights(mut self, euclidean: f64, hyperbolic: f64, spherical: f64) -> Self {
126        self.component_weights = (euclidean.max(0.0), hyperbolic.max(0.0), spherical.max(0.0));
127        self
128    }
129
130    /// Total dimension of the product manifold
131    pub fn total_dim(&self) -> usize {
132        self.euclidean_dim + self.hyperbolic_dim + self.spherical_dim
133    }
134
135    /// Validate configuration
136    pub fn validate(&self) -> Result<()> {
137        if self.total_dim() == 0 {
138            return Err(MathError::invalid_parameter(
139                "dimensions",
140                "at least one component must have non-zero dimension",
141            ));
142        }
143
144        if self.hyperbolic_curvature >= 0.0 {
145            return Err(MathError::invalid_parameter(
146                "hyperbolic_curvature",
147                "must be negative",
148            ));
149        }
150
151        if self.spherical_curvature <= 0.0 {
152            return Err(MathError::invalid_parameter(
153                "spherical_curvature",
154                "must be positive",
155            ));
156        }
157
158        Ok(())
159    }
160
161    /// Get slice ranges for each component
162    pub fn component_ranges(&self) -> (std::ops::Range<usize>, std::ops::Range<usize>, std::ops::Range<usize>) {
163        let e_end = self.euclidean_dim;
164        let h_end = e_end + self.hyperbolic_dim;
165        let s_end = h_end + self.spherical_dim;
166
167        (0..e_end, e_end..h_end, h_end..s_end)
168    }
169}
170
171impl Default for ProductManifoldConfig {
172    fn default() -> Self {
173        // Default: 64-dim Euclidean + 16-dim Hyperbolic + 8-dim Spherical
174        Self::new(64, 16, 8)
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_creation() {
        let cfg = ProductManifoldConfig::new(32, 16, 8);

        let dims = (cfg.euclidean_dim, cfg.hyperbolic_dim, cfg.spherical_dim);
        assert_eq!(dims, (32, 16, 8));
        assert_eq!(cfg.total_dim(), 56);
    }

    #[test]
    fn test_component_ranges() {
        let (euclidean, hyperbolic, spherical) =
            ProductManifoldConfig::new(10, 5, 3).component_ranges();

        assert_eq!(euclidean, 0..10);
        assert_eq!(hyperbolic, 10..15);
        assert_eq!(spherical, 15..18);
    }

    #[test]
    fn test_validation() {
        // No component at all is rejected.
        assert!(ProductManifoldConfig::new(0, 0, 0).validate().is_err());

        // A Euclidean × hyperbolic product with default curvatures is accepted.
        assert!(ProductManifoldConfig::new(10, 5, 0).validate().is_ok());
    }
}