Skip to main content

ruvector_math/spherical/
mod.rs

1//! Spherical Geometry
2//!
3//! Operations on the n-sphere S^n = {x ∈ R^{n+1} : ||x|| = 1}
4//!
5//! ## Use Cases in Vector Search
6//!
7//! - **Cyclical patterns**: Time-of-day, day-of-week, seasonal data
8//! - **Directional data**: Wind directions, compass bearings
9//! - **Normalized embeddings**: Common in NLP (unit-normalized word vectors)
10//! - **Angular similarity**: Natural for cosine similarity
11//!
12//! ## Key Operations
13//!
14//! - Geodesic distance: d(x, y) = arccos(⟨x, y⟩)
15//! - Exponential map: Move from x in direction v
16//! - Logarithmic map: Find direction from x to y
17//! - Fréchet mean: Spherical centroid
18
19use crate::error::{MathError, Result};
20use crate::utils::{dot, norm, normalize, EPS};
21
/// Tuning knobs for spherical operations.
///
/// The values are read by the iterative routines of `SphericalSpace`
/// (for example the Fréchet mean).
#[derive(Debug, Clone)]
pub struct SphericalConfig {
    /// Upper bound on the number of iterations an iterative algorithm may run
    pub max_iterations: usize,
    /// Convergence threshold
    pub threshold: f64,
}

impl Default for SphericalConfig {
    /// Defaults: at most 100 iterations and a convergence threshold of `1e-8`.
    fn default() -> Self {
        let max_iterations = 100;
        let threshold = 1e-8;
        Self {
            max_iterations,
            threshold,
        }
    }
}
39
/// Spherical space operations
///
/// Models the unit sphere S^{n-1} embedded in R^n, where n is the ambient
/// dimension stored in `dim`.
#[derive(Debug, Clone)]
pub struct SphericalSpace {
    /// Ambient dimension n, i.e. the required length of every point/vector slice.
    /// The intrinsic dimension of the sphere is `dim - 1`.
    dim: usize,
    /// Configuration
    config: SphericalConfig,
}
48
49impl SphericalSpace {
50    /// Create a new spherical space S^{n-1} embedded in R^n
51    ///
52    /// # Arguments
53    /// * `ambient_dim` - Dimension of ambient Euclidean space
54    pub fn new(ambient_dim: usize) -> Self {
55        Self {
56            dim: ambient_dim.max(1),
57            config: SphericalConfig::default(),
58        }
59    }
60
61    /// Set configuration
62    pub fn with_config(mut self, config: SphericalConfig) -> Self {
63        self.config = config;
64        self
65    }
66
67    /// Get ambient dimension
68    pub fn ambient_dim(&self) -> usize {
69        self.dim
70    }
71
72    /// Get intrinsic dimension (ambient_dim - 1)
73    pub fn intrinsic_dim(&self) -> usize {
74        self.dim.saturating_sub(1)
75    }
76
77    /// Project a point onto the sphere
78    pub fn project(&self, point: &[f64]) -> Result<Vec<f64>> {
79        if point.len() != self.dim {
80            return Err(MathError::dimension_mismatch(self.dim, point.len()));
81        }
82
83        let n = norm(point);
84        if n < EPS {
85            // Return north pole for zero vector
86            let mut result = vec![0.0; self.dim];
87            result[0] = 1.0;
88            return Ok(result);
89        }
90
91        Ok(normalize(point))
92    }
93
94    /// Check if point is on the sphere
95    pub fn is_on_sphere(&self, point: &[f64]) -> bool {
96        if point.len() != self.dim {
97            return false;
98        }
99        let n = norm(point);
100        (n - 1.0).abs() < 1e-6
101    }
102
103    /// Geodesic distance on the sphere: d(x, y) = arccos(⟨x, y⟩)
104    ///
105    /// This is the great-circle distance.
106    pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
107        if x.len() != self.dim || y.len() != self.dim {
108            return Err(MathError::dimension_mismatch(self.dim, x.len()));
109        }
110
111        let cos_angle = dot(x, y).clamp(-1.0, 1.0);
112        Ok(cos_angle.acos())
113    }
114
115    /// Squared geodesic distance (useful for optimization)
116    pub fn squared_distance(&self, x: &[f64], y: &[f64]) -> Result<f64> {
117        let d = self.distance(x, y)?;
118        Ok(d * d)
119    }
120
121    /// Exponential map: exp_x(v) - move from x in direction v
122    ///
123    /// exp_x(v) = cos(||v||) x + sin(||v||) (v / ||v||)
124    pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>> {
125        if x.len() != self.dim || v.len() != self.dim {
126            return Err(MathError::dimension_mismatch(self.dim, x.len()));
127        }
128
129        let v_norm = norm(v);
130
131        if v_norm < EPS {
132            return Ok(x.to_vec());
133        }
134
135        let cos_t = v_norm.cos();
136        let sin_t = v_norm.sin();
137
138        let result: Vec<f64> = x
139            .iter()
140            .zip(v.iter())
141            .map(|(&xi, &vi)| cos_t * xi + sin_t * vi / v_norm)
142            .collect();
143
144        // Ensure on sphere
145        Ok(normalize(&result))
146    }
147
148    /// Logarithmic map: log_x(y) - tangent vector at x pointing toward y
149    ///
150    /// log_x(y) = (θ / sin(θ)) (y - cos(θ) x)
151    /// where θ = d(x, y) = arccos(⟨x, y⟩)
152    pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
153        if x.len() != self.dim || y.len() != self.dim {
154            return Err(MathError::dimension_mismatch(self.dim, x.len()));
155        }
156
157        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
158        let theta = cos_theta.acos();
159
160        if theta < EPS {
161            // Points are the same
162            return Ok(vec![0.0; self.dim]);
163        }
164
165        if (theta - std::f64::consts::PI).abs() < EPS {
166            // Points are antipodal - log map is not well-defined
167            return Err(MathError::numerical_instability(
168                "Antipodal points have undefined log map",
169            ));
170        }
171
172        let scale = theta / theta.sin();
173
174        let result: Vec<f64> = x
175            .iter()
176            .zip(y.iter())
177            .map(|(&xi, &yi)| scale * (yi - cos_theta * xi))
178            .collect();
179
180        Ok(result)
181    }
182
183    /// Parallel transport vector v from x to y
184    ///
185    /// Transports tangent vector at x along geodesic to y
186    pub fn parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
187        if x.len() != self.dim || y.len() != self.dim || v.len() != self.dim {
188            return Err(MathError::dimension_mismatch(self.dim, x.len()));
189        }
190
191        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
192
193        if (cos_theta - 1.0).abs() < EPS {
194            // Same point, no transport needed
195            return Ok(v.to_vec());
196        }
197
198        let theta = cos_theta.acos();
199
200        // Direction from x to y (unit tangent)
201        let u: Vec<f64> = x
202            .iter()
203            .zip(y.iter())
204            .map(|(&xi, &yi)| yi - cos_theta * xi)
205            .collect();
206        let u = normalize(&u);
207
208        // Component of v along u
209        let v_u = dot(v, &u);
210
211        // Transport formula
212        let result: Vec<f64> = (0..self.dim)
213            .map(|i| {
214                let v_perp = v[i] - v_u * u[i] - dot(v, x) * x[i];
215                v_perp + v_u * (-theta.sin() * x[i] + theta.cos() * u[i])
216                    - dot(v, x) * (theta.cos() * x[i] + theta.sin() * u[i])
217            })
218            .collect();
219
220        Ok(result)
221    }
222
223    /// Fréchet mean on the sphere (spherical centroid)
224    ///
225    /// Minimizes: Σᵢ wᵢ d(m, xᵢ)²
226    pub fn frechet_mean(&self, points: &[Vec<f64>], weights: Option<&[f64]>) -> Result<Vec<f64>> {
227        if points.is_empty() {
228            return Err(MathError::empty_input("points"));
229        }
230
231        let n = points.len();
232        let uniform_weight = 1.0 / n as f64;
233        let weights: Vec<f64> = match weights {
234            Some(w) => {
235                let sum: f64 = w.iter().sum();
236                w.iter().map(|&wi| wi / sum).collect()
237            }
238            None => vec![uniform_weight; n],
239        };
240
241        // Initialize with weighted Euclidean mean, then project
242        let mut mean: Vec<f64> = vec![0.0; self.dim];
243        for (p, &w) in points.iter().zip(weights.iter()) {
244            for (mi, &pi) in mean.iter_mut().zip(p.iter()) {
245                *mi += w * pi;
246            }
247        }
248        mean = self.project(&mean)?;
249
250        // Iterative refinement (Riemannian gradient descent)
251        for _ in 0..self.config.max_iterations {
252            // Compute Riemannian gradient: Σ wᵢ log_{mean}(xᵢ)
253            let mut gradient = vec![0.0; self.dim];
254
255            for (p, &w) in points.iter().zip(weights.iter()) {
256                if let Ok(log_v) = self.log_map(&mean, p) {
257                    for (gi, &li) in gradient.iter_mut().zip(log_v.iter()) {
258                        *gi += w * li;
259                    }
260                }
261            }
262
263            let grad_norm = norm(&gradient);
264            if grad_norm < self.config.threshold {
265                break;
266            }
267
268            // Step along geodesic
269            mean = self.exp_map(&mean, &gradient)?;
270        }
271
272        Ok(mean)
273    }
274
275    /// Geodesic interpolation: point at fraction t along geodesic from x to y
276    ///
277    /// γ(t) = sin((1-t)θ)/sin(θ) x + sin(tθ)/sin(θ) y
278    pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>> {
279        if x.len() != self.dim || y.len() != self.dim {
280            return Err(MathError::dimension_mismatch(self.dim, x.len()));
281        }
282
283        let t = t.clamp(0.0, 1.0);
284
285        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
286        let theta = cos_theta.acos();
287
288        if theta < EPS {
289            return Ok(x.to_vec());
290        }
291
292        let sin_theta = theta.sin();
293        let a = ((1.0 - t) * theta).sin() / sin_theta;
294        let b = (t * theta).sin() / sin_theta;
295
296        let result: Vec<f64> = x
297            .iter()
298            .zip(y.iter())
299            .map(|(&xi, &yi)| a * xi + b * yi)
300            .collect();
301
302        // Ensure on sphere
303        Ok(normalize(&result))
304    }
305
306    /// Sample uniformly from the sphere
307    pub fn sample_uniform(&self, rng: &mut impl rand::Rng) -> Vec<f64> {
308        use rand_distr::{Distribution, StandardNormal};
309
310        let point: Vec<f64> = (0..self.dim).map(|_| StandardNormal.sample(rng)).collect();
311
312        normalize(&point)
313    }
314
315    /// Von Mises-Fisher mean direction MLE
316    ///
317    /// Computes the mean direction (mode of vMF distribution)
318    pub fn mean_direction(&self, points: &[Vec<f64>]) -> Result<Vec<f64>> {
319        if points.is_empty() {
320            return Err(MathError::empty_input("points"));
321        }
322
323        let mut sum = vec![0.0; self.dim];
324        for p in points {
325            if p.len() != self.dim {
326                return Err(MathError::dimension_mismatch(self.dim, p.len()));
327            }
328            for (si, &pi) in sum.iter_mut().zip(p.iter()) {
329                *si += pi;
330            }
331        }
332
333        Ok(normalize(&sum))
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length computed independently of the library helpers, so the
    /// tests do not shadow (or depend on) the imported `norm`.
    fn euclidean_len(v: &[f64]) -> f64 {
        v.iter().map(|&c| c * c).sum::<f64>().sqrt()
    }

    #[test]
    fn test_project_onto_sphere() {
        let sphere = SphericalSpace::new(3);

        let point = [3.0, 4.0, 0.0];
        let projected = sphere.project(&point).unwrap();

        assert!((euclidean_len(&projected) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_geodesic_distance() {
        let sphere = SphericalSpace::new(3);

        // Orthogonal unit vectors
        let x = [1.0, 0.0, 0.0];
        let y = [0.0, 1.0, 0.0];

        let dist = sphere.distance(&x, &y).unwrap();
        let expected = std::f64::consts::PI / 2.0;

        assert!((dist - expected).abs() < 1e-10);
    }

    #[test]
    fn test_exp_log_inverse() {
        let sphere = SphericalSpace::new(3);

        let x = [1.0, 0.0, 0.0];
        let y = sphere.project(&[1.0, 1.0, 0.0]).unwrap();

        // log then exp should return to y
        let v = sphere.log_map(&x, &y).unwrap();
        let y_recovered = sphere.exp_map(&x, &v).unwrap();

        for (yi, &yr) in y.iter().zip(y_recovered.iter()) {
            assert!((yi - yr).abs() < 1e-6, "Exp-log inverse failed");
        }
    }

    #[test]
    fn test_geodesic_interpolation() {
        let sphere = SphericalSpace::new(3);

        let x = [1.0, 0.0, 0.0];
        let y = [0.0, 1.0, 0.0];

        // Midpoint
        let mid = sphere.geodesic(&x, &y, 0.5).unwrap();

        // Should be on sphere
        assert!((euclidean_len(&mid) - 1.0).abs() < 1e-10);

        // Should be equidistant
        let d_x = sphere.distance(&x, &mid).unwrap();
        let d_y = sphere.distance(&mid, &y).unwrap();
        assert!((d_x - d_y).abs() < 1e-10);
    }

    #[test]
    fn test_frechet_mean() {
        let sphere = SphericalSpace::new(3);

        // Points clustered symmetrically around (1, 0, 0)
        let raw = [
            [0.9, 0.1, 0.0],
            [0.9, -0.1, 0.0],
            [0.9, 0.0, 0.1],
            [0.9, 0.0, -0.1],
        ];

        let points: Vec<Vec<f64>> = raw.iter().map(|p| sphere.project(p).unwrap()).collect();

        let mean = sphere.frechet_mean(&points, None).unwrap();

        // Mean should be close to (1, 0, 0)
        assert!(mean[0] > 0.95);
        // By symmetry the off-axis components cancel, and the result stays on the sphere.
        assert!(mean[1].abs() < 1e-6);
        assert!(mean[2].abs() < 1e-6);
        assert!(sphere.is_on_sphere(&mean));
    }
}