ruvector_math/product_manifold/
config.rs1use crate::error::{MathError, Result};
4
/// Geometry of a single manifold component together with its curvature value.
///
/// `Euclidean` is flat and carries no data. The other two variants store their
/// curvature explicitly. The variants can be built with any `f64`, while the
/// `hyperbolic*` / `spherical*` constructors clamp the value to the expected sign.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CurvatureType {
    /// Flat component; reports a curvature of `0.0`.
    Euclidean,
    /// Negatively curved component.
    Hyperbolic {
        /// Curvature value, expected to be negative.
        curvature: f64,
    },
    /// Positively curved component.
    Spherical {
        /// Curvature value, expected to be positive.
        curvature: f64,
    },
}
21
22impl CurvatureType {
23 pub fn hyperbolic() -> Self {
25 Self::Hyperbolic { curvature: -1.0 }
26 }
27
28 pub fn hyperbolic_with(curvature: f64) -> Self {
30 Self::Hyperbolic {
31 curvature: curvature.min(-1e-6),
32 }
33 }
34
35 pub fn spherical() -> Self {
37 Self::Spherical { curvature: 1.0 }
38 }
39
40 pub fn spherical_with(curvature: f64) -> Self {
42 Self::Spherical {
43 curvature: curvature.max(1e-6),
44 }
45 }
46
47 pub fn curvature(&self) -> f64 {
49 match self {
50 Self::Euclidean => 0.0,
51 Self::Hyperbolic { curvature } => *curvature,
52 Self::Spherical { curvature } => *curvature,
53 }
54 }
55}
56
/// Configuration of a product manifold made of Euclidean, hyperbolic and
/// spherical components.
///
/// Points are laid out as one flat coordinate vector: first the Euclidean
/// coordinates, then the hyperbolic ones, then the spherical ones
/// (see `ProductManifoldConfig::component_ranges`).
#[derive(Debug, Clone)]
pub struct ProductManifoldConfig {
    /// Number of coordinates in the Euclidean component (may be `0`).
    pub euclidean_dim: usize,
    /// Number of coordinates in the hyperbolic component (may be `0`).
    pub hyperbolic_dim: usize,
    /// Curvature of the hyperbolic component; `validate` requires it to be negative.
    pub hyperbolic_curvature: f64,
    /// Number of coordinates in the spherical component (may be `0`).
    pub spherical_dim: usize,
    /// Curvature of the spherical component; `validate` requires it to be positive.
    pub spherical_curvature: f64,
    /// Per-component weights, ordered `(euclidean, hyperbolic, spherical)`.
    pub component_weights: (f64, f64, f64),
}
73
74impl ProductManifoldConfig {
75 pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
82 Self {
83 euclidean_dim,
84 hyperbolic_dim,
85 hyperbolic_curvature: -1.0,
86 spherical_dim,
87 spherical_curvature: 1.0,
88 component_weights: (1.0, 1.0, 1.0),
89 }
90 }
91
92 pub fn euclidean(dim: usize) -> Self {
94 Self::new(dim, 0, 0)
95 }
96
97 pub fn hyperbolic(dim: usize) -> Self {
99 Self::new(0, dim, 0)
100 }
101
102 pub fn spherical(dim: usize) -> Self {
104 Self::new(0, 0, dim)
105 }
106
107 pub fn euclidean_hyperbolic(euclidean_dim: usize, hyperbolic_dim: usize) -> Self {
109 Self::new(euclidean_dim, hyperbolic_dim, 0)
110 }
111
112 pub fn with_hyperbolic_curvature(mut self, c: f64) -> Self {
114 self.hyperbolic_curvature = c.min(-1e-6);
115 self
116 }
117
118 pub fn with_spherical_curvature(mut self, c: f64) -> Self {
120 self.spherical_curvature = c.max(1e-6);
121 self
122 }
123
124 pub fn with_weights(mut self, euclidean: f64, hyperbolic: f64, spherical: f64) -> Self {
126 self.component_weights = (euclidean.max(0.0), hyperbolic.max(0.0), spherical.max(0.0));
127 self
128 }
129
130 pub fn total_dim(&self) -> usize {
132 self.euclidean_dim + self.hyperbolic_dim + self.spherical_dim
133 }
134
135 pub fn validate(&self) -> Result<()> {
137 if self.total_dim() == 0 {
138 return Err(MathError::invalid_parameter(
139 "dimensions",
140 "at least one component must have non-zero dimension",
141 ));
142 }
143
144 if self.hyperbolic_curvature >= 0.0 {
145 return Err(MathError::invalid_parameter(
146 "hyperbolic_curvature",
147 "must be negative",
148 ));
149 }
150
151 if self.spherical_curvature <= 0.0 {
152 return Err(MathError::invalid_parameter(
153 "spherical_curvature",
154 "must be positive",
155 ));
156 }
157
158 Ok(())
159 }
160
161 pub fn component_ranges(&self) -> (std::ops::Range<usize>, std::ops::Range<usize>, std::ops::Range<usize>) {
163 let e_end = self.euclidean_dim;
164 let h_end = e_end + self.hyperbolic_dim;
165 let s_end = h_end + self.spherical_dim;
166
167 (0..e_end, e_end..h_end, h_end..s_end)
168 }
169}
170
171impl Default for ProductManifoldConfig {
172 fn default() -> Self {
173 Self::new(64, 16, 8)
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores each dimension verbatim and `total_dim` sums them.
    #[test]
    fn test_config_creation() {
        let cfg = ProductManifoldConfig::new(32, 16, 8);

        assert_eq!(
            (cfg.euclidean_dim, cfg.hyperbolic_dim, cfg.spherical_dim),
            (32, 16, 8)
        );
        assert_eq!(cfg.total_dim(), 56);
    }

    /// Components occupy consecutive index ranges: Euclidean, hyperbolic, spherical.
    #[test]
    fn test_component_ranges() {
        let (euclidean, hyperbolic, spherical) =
            ProductManifoldConfig::new(10, 5, 3).component_ranges();

        assert_eq!(euclidean, 0..10);
        assert_eq!(hyperbolic, 10..15);
        assert_eq!(spherical, 15..18);
    }

    /// An all-zero configuration is rejected; one with non-zero dimensions passes.
    #[test]
    fn test_validation() {
        let empty = ProductManifoldConfig::new(0, 0, 0);
        assert!(empty.validate().is_err());

        let mixed = ProductManifoldConfig::new(10, 5, 0);
        assert!(mixed.validate().is_ok());
    }
}