ruvector_math/product_manifold/
config.rs1use crate::error::{MathError, Result};
4
/// Geometry of a single manifold component, tagged with its curvature.
///
/// The `hyperbolic*` / `spherical*` constructors clamp the curvature so that its
/// sign matches the variant. Building a variant literal directly performs no
/// such check.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CurvatureType {
    /// Flat space; [`CurvatureType::curvature`] reports `0.0` for it.
    Euclidean,
    /// Negatively curved space.
    Hyperbolic {
        /// Curvature value; strictly negative when built via the constructors.
        curvature: f64,
    },
    /// Positively curved space.
    Spherical {
        /// Curvature value; strictly positive when built via the constructors.
        curvature: f64,
    },
}

impl CurvatureType {
    /// Largest (closest to zero) curvature the hyperbolic constructors allow.
    const HYPERBOLIC_CEILING: f64 = -1e-6;
    /// Smallest (closest to zero) curvature the spherical constructors allow.
    const SPHERICAL_FLOOR: f64 = 1e-6;

    /// Hyperbolic component with the standard curvature `-1.0`.
    pub fn hyperbolic() -> Self {
        Self::hyperbolic_with(-1.0)
    }

    /// Hyperbolic component with a caller-chosen curvature.
    ///
    /// Inputs above `-1e-6` (including zero and positive values) are clamped to
    /// `-1e-6`. A NaN input also yields `-1e-6`, because [`f64::min`] returns
    /// the non-NaN operand.
    pub fn hyperbolic_with(curvature: f64) -> Self {
        let clamped = f64::min(curvature, Self::HYPERBOLIC_CEILING);
        Self::Hyperbolic { curvature: clamped }
    }

    /// Spherical component with the standard curvature `1.0`.
    pub fn spherical() -> Self {
        Self::spherical_with(1.0)
    }

    /// Spherical component with a caller-chosen curvature.
    ///
    /// Inputs below `1e-6` (including zero and negative values) are clamped to
    /// `1e-6`. A NaN input also yields `1e-6`, because [`f64::max`] returns the
    /// non-NaN operand.
    pub fn spherical_with(curvature: f64) -> Self {
        let clamped = f64::max(curvature, Self::SPHERICAL_FLOOR);
        Self::Spherical { curvature: clamped }
    }

    /// The curvature carried by this component (`0.0` for [`CurvatureType::Euclidean`]).
    pub fn curvature(&self) -> f64 {
        match *self {
            Self::Hyperbolic { curvature } | Self::Spherical { curvature } => curvature,
            Self::Euclidean => 0.0,
        }
    }
}

57#[derive(Debug, Clone)]
59pub struct ProductManifoldConfig {
60 pub euclidean_dim: usize,
62 pub hyperbolic_dim: usize,
64 pub hyperbolic_curvature: f64,
66 pub spherical_dim: usize,
68 pub spherical_curvature: f64,
70 pub component_weights: (f64, f64, f64),
72}
73
74impl ProductManifoldConfig {
75 pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
82 Self {
83 euclidean_dim,
84 hyperbolic_dim,
85 hyperbolic_curvature: -1.0,
86 spherical_dim,
87 spherical_curvature: 1.0,
88 component_weights: (1.0, 1.0, 1.0),
89 }
90 }
91
92 pub fn euclidean(dim: usize) -> Self {
94 Self::new(dim, 0, 0)
95 }
96
97 pub fn hyperbolic(dim: usize) -> Self {
99 Self::new(0, dim, 0)
100 }
101
102 pub fn spherical(dim: usize) -> Self {
104 Self::new(0, 0, dim)
105 }
106
107 pub fn euclidean_hyperbolic(euclidean_dim: usize, hyperbolic_dim: usize) -> Self {
109 Self::new(euclidean_dim, hyperbolic_dim, 0)
110 }
111
112 pub fn with_hyperbolic_curvature(mut self, c: f64) -> Self {
114 self.hyperbolic_curvature = c.min(-1e-6);
115 self
116 }
117
118 pub fn with_spherical_curvature(mut self, c: f64) -> Self {
120 self.spherical_curvature = c.max(1e-6);
121 self
122 }
123
124 pub fn with_weights(mut self, euclidean: f64, hyperbolic: f64, spherical: f64) -> Self {
126 self.component_weights = (euclidean.max(0.0), hyperbolic.max(0.0), spherical.max(0.0));
127 self
128 }
129
130 pub fn total_dim(&self) -> usize {
132 self.euclidean_dim + self.hyperbolic_dim + self.spherical_dim
133 }
134
135 pub fn validate(&self) -> Result<()> {
137 if self.total_dim() == 0 {
138 return Err(MathError::invalid_parameter(
139 "dimensions",
140 "at least one component must have non-zero dimension",
141 ));
142 }
143
144 if self.hyperbolic_curvature >= 0.0 {
145 return Err(MathError::invalid_parameter(
146 "hyperbolic_curvature",
147 "must be negative",
148 ));
149 }
150
151 if self.spherical_curvature <= 0.0 {
152 return Err(MathError::invalid_parameter(
153 "spherical_curvature",
154 "must be positive",
155 ));
156 }
157
158 Ok(())
159 }
160
161 pub fn component_ranges(
163 &self,
164 ) -> (
165 std::ops::Range<usize>,
166 std::ops::Range<usize>,
167 std::ops::Range<usize>,
168 ) {
169 let e_end = self.euclidean_dim;
170 let h_end = e_end + self.hyperbolic_dim;
171 let s_end = h_end + self.spherical_dim;
172
173 (0..e_end, e_end..h_end, h_end..s_end)
174 }
175}
176
177impl Default for ProductManifoldConfig {
178 fn default() -> Self {
179 Self::new(64, 16, 8)
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_creation() {
        let config = ProductManifoldConfig::new(32, 16, 8);

        assert_eq!(config.euclidean_dim, 32);
        assert_eq!(config.hyperbolic_dim, 16);
        assert_eq!(config.spherical_dim, 8);
        assert_eq!(config.total_dim(), 56);
    }

    #[test]
    fn test_component_ranges() {
        let config = ProductManifoldConfig::new(10, 5, 3);
        let (e, h, s) = config.component_ranges();

        assert_eq!(e, 0..10);
        assert_eq!(h, 10..15);
        assert_eq!(s, 15..18);
    }

    #[test]
    fn test_validation() {
        let config = ProductManifoldConfig::new(0, 0, 0);
        assert!(config.validate().is_err());

        let config = ProductManifoldConfig::new(10, 5, 0);
        assert!(config.validate().is_ok());
    }

    /// Curvature builders clamp wrong-sign inputs to a tiny value of the right sign.
    #[test]
    fn test_curvature_clamping() {
        assert_eq!(CurvatureType::hyperbolic().curvature(), -1.0);
        assert_eq!(CurvatureType::spherical().curvature(), 1.0);
        assert_eq!(CurvatureType::hyperbolic_with(2.0).curvature(), -1e-6);
        assert_eq!(CurvatureType::spherical_with(-2.0).curvature(), 1e-6);
        assert_eq!(CurvatureType::Euclidean.curvature(), 0.0);

        let config = ProductManifoldConfig::new(1, 1, 1)
            .with_hyperbolic_curvature(0.5)
            .with_spherical_curvature(-0.5);
        assert_eq!(config.hyperbolic_curvature, -1e-6);
        assert_eq!(config.spherical_curvature, 1e-6);
        assert!(config.validate().is_ok());
    }

    /// `with_weights` never stores a negative weight.
    #[test]
    fn test_weights_clamped_non_negative() {
        let config = ProductManifoldConfig::new(1, 1, 1).with_weights(-1.0, 0.5, 2.0);
        assert_eq!(config.component_weights, (0.0, 0.5, 2.0));
    }

    /// Curvatures assigned directly with the wrong sign are caught by `validate`.
    #[test]
    fn test_validation_rejects_wrong_curvature_sign() {
        let mut config = ProductManifoldConfig::new(4, 4, 4);
        config.hyperbolic_curvature = 0.0;
        assert!(config.validate().is_err());

        let mut config = ProductManifoldConfig::new(4, 4, 4);
        config.spherical_curvature = -1.0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_default_dimensions() {
        let config = ProductManifoldConfig::default();
        assert_eq!(config.component_ranges(), (0..64, 64..80, 80..88));
        assert_eq!(config.total_dim(), 88);
        assert!(config.validate().is_ok());
    }
}