ruvector_math/spherical/
mod.rs

1//! Spherical Geometry
2//!
3//! Operations on the n-sphere S^n = {x ∈ R^{n+1} : ||x|| = 1}
4//!
5//! ## Use Cases in Vector Search
6//!
7//! - **Cyclical patterns**: Time-of-day, day-of-week, seasonal data
8//! - **Directional data**: Wind directions, compass bearings
9//! - **Normalized embeddings**: Common in NLP (unit-normalized word vectors)
10//! - **Angular similarity**: Natural for cosine similarity
11//!
12//! ## Key Operations
13//!
14//! - Geodesic distance: d(x, y) = arccos(⟨x, y⟩)
15//! - Exponential map: Move from x in direction v
16//! - Logarithmic map: Find direction from x to y
17//! - Fréchet mean: Spherical centroid
18
19use crate::error::{MathError, Result};
20use crate::utils::{dot, normalize, norm, EPS};
21
/// Tuning parameters for the iterative spherical algorithms.
#[derive(Debug, Clone)]
pub struct SphericalConfig {
    /// Upper bound on the number of iterations an iterative algorithm may run
    pub max_iterations: usize,
    /// Convergence threshold (`frechet_mean` stops once its gradient norm
    /// drops below this value)
    pub threshold: f64,
}

impl Default for SphericalConfig {
    /// Defaults: at most 100 iterations and a threshold of `1e-8`.
    fn default() -> Self {
        const DEFAULT_MAX_ITERATIONS: usize = 100;
        const DEFAULT_THRESHOLD: f64 = 1e-8;

        SphericalConfig {
            max_iterations: DEFAULT_MAX_ITERATIONS,
            threshold: DEFAULT_THRESHOLD,
        }
    }
}
39
40/// Spherical space operations
41#[derive(Debug, Clone)]
42pub struct SphericalSpace {
43    /// Dimension of the sphere (ambient dimension - 1)
44    dim: usize,
45    /// Configuration
46    config: SphericalConfig,
47}
48
49impl SphericalSpace {
50    /// Create a new spherical space S^{n-1} embedded in R^n
51    ///
52    /// # Arguments
53    /// * `ambient_dim` - Dimension of ambient Euclidean space
54    pub fn new(ambient_dim: usize) -> Self {
55        Self {
56            dim: ambient_dim.max(1),
57            config: SphericalConfig::default(),
58        }
59    }
60
61    /// Set configuration
62    pub fn with_config(mut self, config: SphericalConfig) -> Self {
63        self.config = config;
64        self
65    }
66
67    /// Get ambient dimension
68    pub fn ambient_dim(&self) -> usize {
69        self.dim
70    }
71
72    /// Get intrinsic dimension (ambient_dim - 1)
73    pub fn intrinsic_dim(&self) -> usize {
74        self.dim.saturating_sub(1)
75    }
76
77    /// Project a point onto the sphere
78    pub fn project(&self, point: &[f64]) -> Result<Vec<f64>> {
79        if point.len() != self.dim {
80            return Err(MathError::dimension_mismatch(self.dim, point.len()));
81        }
82
83        let n = norm(point);
84        if n < EPS {
85            // Return north pole for zero vector
86            let mut result = vec![0.0; self.dim];
87            result[0] = 1.0;
88            return Ok(result);
89        }
90
91        Ok(normalize(point))
92    }
93
94    /// Check if point is on the sphere
95    pub fn is_on_sphere(&self, point: &[f64]) -> bool {
96        if point.len() != self.dim {
97            return false;
98        }
99        let n = norm(point);
100        (n - 1.0).abs() < 1e-6
101    }
102
103    /// Geodesic distance on the sphere: d(x, y) = arccos(⟨x, y⟩)
104    ///
105    /// This is the great-circle distance.
106    pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
107        if x.len() != self.dim || y.len() != self.dim {
108            return Err(MathError::dimension_mismatch(self.dim, x.len()));
109        }
110
111        let cos_angle = dot(x, y).clamp(-1.0, 1.0);
112        Ok(cos_angle.acos())
113    }
114
115    /// Squared geodesic distance (useful for optimization)
116    pub fn squared_distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
117        let d = self.distance(x, y)?;
118        Ok(d * d)
119    }
120
121    /// Exponential map: exp_x(v) - move from x in direction v
122    ///
123    /// exp_x(v) = cos(||v||) x + sin(||v||) (v / ||v||)
124    pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
125        if x.len() != self.dim || v.len() != self.dim {
126            return Err(MathError::dimension_mismatch(self.dim, x.len()));
127        }
128
129        let v_norm = norm(v);
130
131        if v_norm < EPS {
132            return Ok(x.to_vec());
133        }
134
135        let cos_t = v_norm.cos();
136        let sin_t = v_norm.sin();
137
138        let result: Vec<f64> = x
139            .iter()
140            .zip(v.iter())
141            .map(|(&xi, &vi)| cos_t * xi + sin_t * vi / v_norm)
142            .collect();
143
144        // Ensure on sphere
145        Ok(normalize(&result))
146    }
147
148    /// Logarithmic map: log_x(y) - tangent vector at x pointing toward y
149    ///
150    /// log_x(y) = (θ / sin(θ)) (y - cos(θ) x)
151    /// where θ = d(x, y) = arccos(⟨x, y⟩)
152    pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
153        if x.len() != self.dim || y.len() != self.dim {
154            return Err(MathError::dimension_mismatch(self.dim, x.len()));
155        }
156
157        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
158        let theta = cos_theta.acos();
159
160        if theta < EPS {
161            // Points are the same
162            return Ok(vec![0.0; self.dim]);
163        }
164
165        if (theta - std::f64::consts::PI).abs() < EPS {
166            // Points are antipodal - log map is not well-defined
167            return Err(MathError::numerical_instability(
168                "Antipodal points have undefined log map",
169            ));
170        }
171
172        let scale = theta / theta.sin();
173
174        let result: Vec<f64> = x
175            .iter()
176            .zip(y.iter())
177            .map(|(&xi, &yi)| scale * (yi - cos_theta * xi))
178            .collect();
179
180        Ok(result)
181    }
182
183    /// Parallel transport vector v from x to y
184    ///
185    /// Transports tangent vector at x along geodesic to y
186    pub fn parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
187        if x.len() != self.dim || y.len() != self.dim || v.len() != self.dim {
188            return Err(MathError::dimension_mismatch(self.dim, x.len()));
189        }
190
191        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
192
193        if (cos_theta - 1.0).abs() < EPS {
194            // Same point, no transport needed
195            return Ok(v.to_vec());
196        }
197
198        let theta = cos_theta.acos();
199
200        // Direction from x to y (unit tangent)
201        let u: Vec<f64> = x
202            .iter()
203            .zip(y.iter())
204            .map(|(&xi, &yi)| yi - cos_theta * xi)
205            .collect();
206        let u = normalize(&u);
207
208        // Component of v along u
209        let v_u = dot(v, &u);
210
211        // Transport formula
212        let result: Vec<f64> = (0..self.dim)
213            .map(|i| {
214                let v_perp = v[i] - v_u * u[i] - dot(v, x) * x[i];
215                v_perp
216                    + v_u * (-theta.sin() * x[i] + theta.cos() * u[i])
217                    - dot(v, x) * (theta.cos() * x[i] + theta.sin() * u[i])
218            })
219            .collect();
220
221        Ok(result)
222    }
223
224    /// Fréchet mean on the sphere (spherical centroid)
225    ///
226    /// Minimizes: Σᵢ wᵢ d(m, xᵢ)²
227    pub fn frechet_mean(&self, points: &[Vec<f64>], weights: Option<&[f64]>) -> Result<Vec<f64>> {
228        if points.is_empty() {
229            return Err(MathError::empty_input("points"));
230        }
231
232        let n = points.len();
233        let uniform_weight = 1.0 / n as f64;
234        let weights: Vec<f64> = match weights {
235            Some(w) => {
236                let sum: f64 = w.iter().sum();
237                w.iter().map(|&wi| wi / sum).collect()
238            }
239            None => vec![uniform_weight; n],
240        };
241
242        // Initialize with weighted Euclidean mean, then project
243        let mut mean: Vec<f64> = vec![0.0; self.dim];
244        for (p, &w) in points.iter().zip(weights.iter()) {
245            for (mi, &pi) in mean.iter_mut().zip(p.iter()) {
246                *mi += w * pi;
247            }
248        }
249        mean = self.project(&mean)?;
250
251        // Iterative refinement (Riemannian gradient descent)
252        for _ in 0..self.config.max_iterations {
253            // Compute Riemannian gradient: Σ wᵢ log_{mean}(xᵢ)
254            let mut gradient = vec![0.0; self.dim];
255
256            for (p, &w) in points.iter().zip(weights.iter()) {
257                if let Ok(log_v) = self.log_map(&mean, p) {
258                    for (gi, &li) in gradient.iter_mut().zip(log_v.iter()) {
259                        *gi += w * li;
260                    }
261                }
262            }
263
264            let grad_norm = norm(&gradient);
265            if grad_norm < self.config.threshold {
266                break;
267            }
268
269            // Step along geodesic
270            mean = self.exp_map(&mean, &gradient)?;
271        }
272
273        Ok(mean)
274    }
275
276    /// Geodesic interpolation: point at fraction t along geodesic from x to y
277    ///
278    /// γ(t) = sin((1-t)θ)/sin(θ) x + sin(tθ)/sin(θ) y
279    pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>> {
280        if x.len() != self.dim || y.len() != self.dim {
281            return Err(MathError::dimension_mismatch(self.dim, x.len()));
282        }
283
284        let t = t.clamp(0.0, 1.0);
285
286        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
287        let theta = cos_theta.acos();
288
289        if theta < EPS {
290            return Ok(x.to_vec());
291        }
292
293        let sin_theta = theta.sin();
294        let a = ((1.0 - t) * theta).sin() / sin_theta;
295        let b = (t * theta).sin() / sin_theta;
296
297        let result: Vec<f64> = x
298            .iter()
299            .zip(y.iter())
300            .map(|(&xi, &yi)| a * xi + b * yi)
301            .collect();
302
303        // Ensure on sphere
304        Ok(normalize(&result))
305    }
306
307    /// Sample uniformly from the sphere
308    pub fn sample_uniform(&self, rng: &mut impl rand::Rng) -> Vec<f64> {
309        use rand_distr::{Distribution, StandardNormal};
310
311        let point: Vec<f64> = (0..self.dim)
312            .map(|_| StandardNormal.sample(rng))
313            .collect();
314
315        normalize(&point)
316    }
317
318    /// Von Mises-Fisher mean direction MLE
319    ///
320    /// Computes the mean direction (mode of vMF distribution)
321    pub fn mean_direction(&self, points: &[Vec<f64>]) -> Result<Vec<f64>> {
322        if points.is_empty() {
323            return Err(MathError::empty_input("points"));
324        }
325
326        let mut sum = vec![0.0; self.dim];
327        for p in points {
328            if p.len() != self.dim {
329                return Err(MathError::dimension_mismatch(self.dim, p.len()));
330            }
331            for (si, &pi) in sum.iter_mut().zip(p.iter()) {
332                *si += pi;
333            }
334        }
335
336        Ok(normalize(&sum))
337    }
338}
339
340#[cfg(test)]
341mod tests {
342    use super::*;
343
344    #[test]
345    fn test_project_onto_sphere() {
346        let sphere = SphericalSpace::new(3);
347
348        let point = vec![3.0, 4.0, 0.0];
349        let projected = sphere.project(&point).unwrap();
350
351        let norm: f64 = projected.iter().map(|&x| x * x).sum::<f64>().sqrt();
352        assert!((norm - 1.0).abs() < 1e-10);
353    }
354
355    #[test]
356    fn test_geodesic_distance() {
357        let sphere = SphericalSpace::new(3);
358
359        // Orthogonal unit vectors
360        let x = vec![1.0, 0.0, 0.0];
361        let y = vec![0.0, 1.0, 0.0];
362
363        let dist = sphere.distance(&x, &y).unwrap();
364        let expected = std::f64::consts::PI / 2.0;
365
366        assert!((dist - expected).abs() < 1e-10);
367    }
368
369    #[test]
370    fn test_exp_log_inverse() {
371        let sphere = SphericalSpace::new(3);
372
373        let x = vec![1.0, 0.0, 0.0];
374        let y = sphere.project(&vec![1.0, 1.0, 0.0]).unwrap();
375
376        // log then exp should return to y
377        let v = sphere.log_map(&x, &y).unwrap();
378        let y_recovered = sphere.exp_map(&x, &v).unwrap();
379
380        for (yi, &yr) in y.iter().zip(y_recovered.iter()) {
381            assert!((yi - yr).abs() < 1e-6, "Exp-log inverse failed");
382        }
383    }
384
385    #[test]
386    fn test_geodesic_interpolation() {
387        let sphere = SphericalSpace::new(3);
388
389        let x = vec![1.0, 0.0, 0.0];
390        let y = vec![0.0, 1.0, 0.0];
391
392        // Midpoint
393        let mid = sphere.geodesic(&x, &y, 0.5).unwrap();
394
395        // Should be on sphere
396        let norm: f64 = mid.iter().map(|&m| m * m).sum::<f64>().sqrt();
397        assert!((norm - 1.0).abs() < 1e-10);
398
399        // Should be equidistant
400        let d_x = sphere.distance(&x, &mid).unwrap();
401        let d_y = sphere.distance(&mid, &y).unwrap();
402        assert!((d_x - d_y).abs() < 1e-10);
403    }
404
405    #[test]
406    fn test_frechet_mean() {
407        let sphere = SphericalSpace::new(3);
408
409        // Points near north pole
410        let points = vec![
411            vec![0.9, 0.1, 0.0],
412            vec![0.9, -0.1, 0.0],
413            vec![0.9, 0.0, 0.1],
414            vec![0.9, 0.0, -0.1],
415        ];
416
417        let points: Vec<Vec<f64>> = points
418            .into_iter()
419            .map(|p| sphere.project(&p).unwrap())
420            .collect();
421
422        let mean = sphere.frechet_mean(&points, None).unwrap();
423
424        // Mean should be close to (1, 0, 0)
425        assert!(mean[0] > 0.95);
426    }
427}