1use crate::error::{MathError, Result};
20use crate::utils::{dot, normalize, norm, EPS};
21
/// Tuning knobs for the iterative routines of [`SphericalSpace`].
///
/// Derives `PartialEq` so that configurations can be compared directly
/// (for example in assertions).
#[derive(Debug, Clone, PartialEq)]
pub struct SphericalConfig {
    /// Upper bound on the number of update steps performed by the iterative
    /// mean computation.
    pub max_iterations: usize,
    /// Convergence threshold: iteration stops once the norm of the update
    /// direction falls below this value.
    pub threshold: f64,
}
30
31impl Default for SphericalConfig {
32 fn default() -> Self {
33 Self {
34 max_iterations: 100,
35 threshold: 1e-8,
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct SphericalSpace {
43 dim: usize,
45 config: SphericalConfig,
47}
48
49impl SphericalSpace {
50 pub fn new(ambient_dim: usize) -> Self {
55 Self {
56 dim: ambient_dim.max(1),
57 config: SphericalConfig::default(),
58 }
59 }
60
61 pub fn with_config(mut self, config: SphericalConfig) -> Self {
63 self.config = config;
64 self
65 }
66
67 pub fn ambient_dim(&self) -> usize {
69 self.dim
70 }
71
72 pub fn intrinsic_dim(&self) -> usize {
74 self.dim.saturating_sub(1)
75 }
76
77 pub fn project(&self, point: &[f64]) -> Result<Vec<f64>> {
79 if point.len() != self.dim {
80 return Err(MathError::dimension_mismatch(self.dim, point.len()));
81 }
82
83 let n = norm(point);
84 if n < EPS {
85 let mut result = vec![0.0; self.dim];
87 result[0] = 1.0;
88 return Ok(result);
89 }
90
91 Ok(normalize(point))
92 }
93
94 pub fn is_on_sphere(&self, point: &[f64]) -> bool {
96 if point.len() != self.dim {
97 return false;
98 }
99 let n = norm(point);
100 (n - 1.0).abs() < 1e-6
101 }
102
103 pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
107 if x.len() != self.dim || y.len() != self.dim {
108 return Err(MathError::dimension_mismatch(self.dim, x.len()));
109 }
110
111 let cos_angle = dot(x, y).clamp(-1.0, 1.0);
112 Ok(cos_angle.acos())
113 }
114
115 pub fn squared_distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
117 let d = self.distance(x, y)?;
118 Ok(d * d)
119 }
120
121 pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
125 if x.len() != self.dim || v.len() != self.dim {
126 return Err(MathError::dimension_mismatch(self.dim, x.len()));
127 }
128
129 let v_norm = norm(v);
130
131 if v_norm < EPS {
132 return Ok(x.to_vec());
133 }
134
135 let cos_t = v_norm.cos();
136 let sin_t = v_norm.sin();
137
138 let result: Vec<f64> = x
139 .iter()
140 .zip(v.iter())
141 .map(|(&xi, &vi)| cos_t * xi + sin_t * vi / v_norm)
142 .collect();
143
144 Ok(normalize(&result))
146 }
147
148 pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
153 if x.len() != self.dim || y.len() != self.dim {
154 return Err(MathError::dimension_mismatch(self.dim, x.len()));
155 }
156
157 let cos_theta = dot(x, y).clamp(-1.0, 1.0);
158 let theta = cos_theta.acos();
159
160 if theta < EPS {
161 return Ok(vec![0.0; self.dim]);
163 }
164
165 if (theta - std::f64::consts::PI).abs() < EPS {
166 return Err(MathError::numerical_instability(
168 "Antipodal points have undefined log map",
169 ));
170 }
171
172 let scale = theta / theta.sin();
173
174 let result: Vec<f64> = x
175 .iter()
176 .zip(y.iter())
177 .map(|(&xi, &yi)| scale * (yi - cos_theta * xi))
178 .collect();
179
180 Ok(result)
181 }
182
183 pub fn parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
187 if x.len() != self.dim || y.len() != self.dim || v.len() != self.dim {
188 return Err(MathError::dimension_mismatch(self.dim, x.len()));
189 }
190
191 let cos_theta = dot(x, y).clamp(-1.0, 1.0);
192
193 if (cos_theta - 1.0).abs() < EPS {
194 return Ok(v.to_vec());
196 }
197
198 let theta = cos_theta.acos();
199
200 let u: Vec<f64> = x
202 .iter()
203 .zip(y.iter())
204 .map(|(&xi, &yi)| yi - cos_theta * xi)
205 .collect();
206 let u = normalize(&u);
207
208 let v_u = dot(v, &u);
210
211 let result: Vec<f64> = (0..self.dim)
213 .map(|i| {
214 let v_perp = v[i] - v_u * u[i] - dot(v, x) * x[i];
215 v_perp
216 + v_u * (-theta.sin() * x[i] + theta.cos() * u[i])
217 - dot(v, x) * (theta.cos() * x[i] + theta.sin() * u[i])
218 })
219 .collect();
220
221 Ok(result)
222 }
223
224 pub fn frechet_mean(&self, points: &[Vec<f64>], weights: Option<&[f64]>) -> Result<Vec<f64>> {
228 if points.is_empty() {
229 return Err(MathError::empty_input("points"));
230 }
231
232 let n = points.len();
233 let uniform_weight = 1.0 / n as f64;
234 let weights: Vec<f64> = match weights {
235 Some(w) => {
236 let sum: f64 = w.iter().sum();
237 w.iter().map(|&wi| wi / sum).collect()
238 }
239 None => vec![uniform_weight; n],
240 };
241
242 let mut mean: Vec<f64> = vec![0.0; self.dim];
244 for (p, &w) in points.iter().zip(weights.iter()) {
245 for (mi, &pi) in mean.iter_mut().zip(p.iter()) {
246 *mi += w * pi;
247 }
248 }
249 mean = self.project(&mean)?;
250
251 for _ in 0..self.config.max_iterations {
253 let mut gradient = vec![0.0; self.dim];
255
256 for (p, &w) in points.iter().zip(weights.iter()) {
257 if let Ok(log_v) = self.log_map(&mean, p) {
258 for (gi, &li) in gradient.iter_mut().zip(log_v.iter()) {
259 *gi += w * li;
260 }
261 }
262 }
263
264 let grad_norm = norm(&gradient);
265 if grad_norm < self.config.threshold {
266 break;
267 }
268
269 mean = self.exp_map(&mean, &gradient)?;
271 }
272
273 Ok(mean)
274 }
275
276 pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>> {
280 if x.len() != self.dim || y.len() != self.dim {
281 return Err(MathError::dimension_mismatch(self.dim, x.len()));
282 }
283
284 let t = t.clamp(0.0, 1.0);
285
286 let cos_theta = dot(x, y).clamp(-1.0, 1.0);
287 let theta = cos_theta.acos();
288
289 if theta < EPS {
290 return Ok(x.to_vec());
291 }
292
293 let sin_theta = theta.sin();
294 let a = ((1.0 - t) * theta).sin() / sin_theta;
295 let b = (t * theta).sin() / sin_theta;
296
297 let result: Vec<f64> = x
298 .iter()
299 .zip(y.iter())
300 .map(|(&xi, &yi)| a * xi + b * yi)
301 .collect();
302
303 Ok(normalize(&result))
305 }
306
307 pub fn sample_uniform(&self, rng: &mut impl rand::Rng) -> Vec<f64> {
309 use rand_distr::{Distribution, StandardNormal};
310
311 let point: Vec<f64> = (0..self.dim)
312 .map(|_| StandardNormal.sample(rng))
313 .collect();
314
315 normalize(&point)
316 }
317
318 pub fn mean_direction(&self, points: &[Vec<f64>]) -> Result<Vec<f64>> {
322 if points.is_empty() {
323 return Err(MathError::empty_input("points"));
324 }
325
326 let mut sum = vec![0.0; self.dim];
327 for p in points {
328 if p.len() != self.dim {
329 return Err(MathError::dimension_mismatch(self.dim, p.len()));
330 }
331 for (si, &pi) in sum.iter_mut().zip(p.iter()) {
332 *si += pi;
333 }
334 }
335
336 Ok(normalize(&sum))
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length computed locally, independent of the helpers the
    /// code under test relies on.
    fn length(v: &[f64]) -> f64 {
        v.iter().map(|&c| c * c).sum::<f64>().sqrt()
    }

    /// Projecting an arbitrary vector must yield a unit vector.
    #[test]
    fn test_project_onto_sphere() {
        let space = SphericalSpace::new(3);

        let unit = space.project(&[3.0, 4.0, 0.0]).unwrap();

        assert!((length(&unit) - 1.0).abs() < 1e-10);
    }

    /// Two orthogonal unit vectors are a quarter turn apart.
    #[test]
    fn test_geodesic_distance() {
        let space = SphericalSpace::new(3);
        let (e1, e2) = ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);

        let quarter_turn = std::f64::consts::PI / 2.0;
        let measured = space.distance(&e1, &e2).unwrap();

        assert!((measured - quarter_turn).abs() < 1e-10);
    }

    /// `exp_map` must undo `log_map` taken at the same base point.
    #[test]
    fn test_exp_log_inverse() {
        let space = SphericalSpace::new(3);
        let base = [1.0, 0.0, 0.0];
        let target = space.project(&[1.0, 1.0, 0.0]).unwrap();

        let tangent = space.log_map(&base, &target).unwrap();
        let round_trip = space.exp_map(&base, &tangent).unwrap();

        for (expected, actual) in target.iter().zip(&round_trip) {
            assert!((expected - actual).abs() < 1e-6, "Exp-log inverse failed");
        }
    }

    /// The halfway point lies on the sphere and is equidistant from both ends.
    #[test]
    fn test_geodesic_interpolation() {
        let space = SphericalSpace::new(3);
        let (start, end) = ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);

        let halfway = space.geodesic(&start, &end, 0.5).unwrap();

        assert!((length(&halfway) - 1.0).abs() < 1e-10);

        let from_start = space.distance(&start, &halfway).unwrap();
        let to_end = space.distance(&halfway, &end).unwrap();
        assert!((from_start - to_end).abs() < 1e-10);
    }

    /// Points clustered around the first axis must average close to it.
    #[test]
    fn test_frechet_mean() {
        let space = SphericalSpace::new(3);

        let cluster: Vec<Vec<f64>> = [
            [0.9, 0.1, 0.0],
            [0.9, -0.1, 0.0],
            [0.9, 0.0, 0.1],
            [0.9, 0.0, -0.1],
        ]
        .iter()
        .map(|raw| space.project(raw).unwrap())
        .collect();

        let centre = space.frechet_mean(&cluster, None).unwrap();

        assert!(centre[0] > 0.95);
    }
}