Skip to main content

ruvector_math/product_manifold/
operations.rs

1//! Additional product manifold operations
2
3use super::ProductManifold;
4use crate::error::{MathError, Result};
5use crate::utils::{norm, EPS};
6
7#[cfg(feature = "parallel")]
8use rayon::prelude::*;
9
10/// Batch operations on product manifolds
11impl ProductManifold {
12    /// Compute pairwise distances between all points
13    /// Uses parallel computation when 'parallel' feature is enabled
14    pub fn pairwise_distances(&self, points: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
15        let n = points.len();
16
17        #[cfg(feature = "parallel")]
18        {
19            self.pairwise_distances_parallel(points, n)
20        }
21
22        #[cfg(not(feature = "parallel"))]
23        {
24            self.pairwise_distances_sequential(points, n)
25        }
26    }
27
28    /// Sequential pairwise distance computation
29    #[inline]
30    fn pairwise_distances_sequential(
31        &self,
32        points: &[Vec<f64>],
33        n: usize,
34    ) -> Result<Vec<Vec<f64>>> {
35        let mut distances = vec![vec![0.0; n]; n];
36
37        for i in 0..n {
38            for j in (i + 1)..n {
39                let d = self.distance(&points[i], &points[j])?;
40                distances[i][j] = d;
41                distances[j][i] = d;
42            }
43        }
44
45        Ok(distances)
46    }
47
48    /// Parallel pairwise distance computation using rayon
49    #[cfg(feature = "parallel")]
50    fn pairwise_distances_parallel(&self, points: &[Vec<f64>], n: usize) -> Result<Vec<Vec<f64>>> {
51        // Compute upper triangle in parallel
52        let pairs: Vec<_> = (0..n)
53            .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
54            .collect();
55
56        let results: Vec<(usize, usize, f64)> = pairs
57            .par_iter()
58            .filter_map(|&(i, j)| {
59                self.distance(&points[i], &points[j])
60                    .ok()
61                    .map(|d| (i, j, d))
62            })
63            .collect();
64
65        let mut distances = vec![vec![0.0; n]; n];
66        for (i, j, d) in results {
67            distances[i][j] = d;
68            distances[j][i] = d;
69        }
70
71        Ok(distances)
72    }
73
74    /// Find k-nearest neighbors
75    /// Uses parallel computation when 'parallel' feature is enabled
76    pub fn knn(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
77        #[cfg(feature = "parallel")]
78        {
79            self.knn_parallel(query, points, k)
80        }
81
82        #[cfg(not(feature = "parallel"))]
83        {
84            self.knn_sequential(query, points, k)
85        }
86    }
87
88    /// Sequential k-nearest neighbors
89    #[inline]
90    fn knn_sequential(
91        &self,
92        query: &[f64],
93        points: &[Vec<f64>],
94        k: usize,
95    ) -> Result<Vec<(usize, f64)>> {
96        let mut distances: Vec<(usize, f64)> = points
97            .iter()
98            .enumerate()
99            .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
100            .collect();
101
102        // Use sort_unstable_by for better performance
103        distances
104            .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
105        distances.truncate(k);
106
107        Ok(distances)
108    }
109
110    /// Parallel k-nearest neighbors using rayon
111    #[cfg(feature = "parallel")]
112    fn knn_parallel(
113        &self,
114        query: &[f64],
115        points: &[Vec<f64>],
116        k: usize,
117    ) -> Result<Vec<(usize, f64)>> {
118        let mut distances: Vec<(usize, f64)> = points
119            .par_iter()
120            .enumerate()
121            .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
122            .collect();
123
124        // Use sort_unstable_by for better performance
125        distances
126            .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
127        distances.truncate(k);
128
129        Ok(distances)
130    }
131
132    /// Geodesic interpolation between two points
133    ///
134    /// Returns point at fraction t along geodesic from x to y
135    pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>> {
136        let t = t.clamp(0.0, 1.0);
137
138        // log_x(y) gives direction
139        let v = self.log_map(x, y)?;
140
141        // Scale by t
142        let tv: Vec<f64> = v.iter().map(|&vi| t * vi).collect();
143
144        // exp_x(t * v)
145        self.exp_map(x, &tv)
146    }
147
148    /// Sample points along geodesic
149    pub fn geodesic_path(&self, x: &[f64], y: &[f64], num_points: usize) -> Result<Vec<Vec<f64>>> {
150        let mut path = Vec::with_capacity(num_points);
151
152        for i in 0..num_points {
153            let t = i as f64 / (num_points - 1).max(1) as f64;
154            path.push(self.geodesic(x, y, t)?);
155        }
156
157        Ok(path)
158    }
159
160    /// Parallel transport vector v from x to y
161    pub fn parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
162        if x.len() != self.dim() || y.len() != self.dim() || v.len() != self.dim() {
163            return Err(MathError::dimension_mismatch(self.dim(), x.len()));
164        }
165
166        let mut result = vec![0.0; self.dim()];
167        let (e_range, h_range, s_range) = self.config().component_ranges();
168
169        // Euclidean: parallel transport is identity
170        for i in e_range.clone() {
171            result[i] = v[i];
172        }
173
174        // Hyperbolic parallel transport
175        if !h_range.is_empty() {
176            let x_h = &x[h_range.clone()];
177            let y_h = &y[h_range.clone()];
178            let v_h = &v[h_range.clone()];
179            let pt_h = self.poincare_parallel_transport(x_h, y_h, v_h)?;
180            for (i, val) in h_range.clone().zip(pt_h.iter()) {
181                result[i] = *val;
182            }
183        }
184
185        // Spherical parallel transport
186        if !s_range.is_empty() {
187            let x_s = &x[s_range.clone()];
188            let y_s = &y[s_range.clone()];
189            let v_s = &v[s_range.clone()];
190            let pt_s = self.spherical_parallel_transport(x_s, y_s, v_s)?;
191            for (i, val) in s_range.clone().zip(pt_s.iter()) {
192                result[i] = *val;
193            }
194        }
195
196        Ok(result)
197    }
198
199    /// Poincaré ball parallel transport
200    fn poincare_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
201        let c = -self.config().hyperbolic_curvature;
202
203        let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
204        let y_norm_sq: f64 = y.iter().map(|&yi| yi * yi).sum();
205
206        let lambda_x = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
207        let lambda_y = 2.0 / (1.0 - c * y_norm_sq).max(EPS);
208
209        let scale = lambda_x / lambda_y;
210
211        // Gyration correction
212        let xy_dot: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
213        let _gyration_factor = 1.0 + c * xy_dot;
214
215        // Simplified parallel transport (good approximation for small distances)
216        Ok(v.iter().map(|&vi| scale * vi).collect())
217    }
218
219    /// Spherical parallel transport
220    fn spherical_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
221        use crate::utils::dot;
222
223        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
224
225        if (cos_theta - 1.0).abs() < EPS {
226            return Ok(v.to_vec());
227        }
228
229        let theta = cos_theta.acos();
230
231        // Direction from x to y
232        let u: Vec<f64> = x
233            .iter()
234            .zip(y.iter())
235            .map(|(&xi, &yi)| yi - cos_theta * xi)
236            .collect();
237        let u_norm = norm(&u);
238
239        if u_norm < EPS {
240            return Ok(v.to_vec());
241        }
242
243        let u: Vec<f64> = u.iter().map(|&ui| ui / u_norm).collect();
244
245        // Components of v
246        let v_u = dot(v, &u);
247        let v_x = dot(v, x);
248
249        // Parallel transport formula
250        let result: Vec<f64> = (0..x.len())
251            .map(|i| {
252                let v_perp = v[i] - v_u * u[i] - v_x * x[i];
253                v_perp + v_u * (-theta.sin() * x[i] + theta.cos() * u[i])
254                    - v_x * (theta.cos() * x[i] + theta.sin() * u[i])
255            })
256            .collect();
257
258        Ok(result)
259    }
260
261    /// Compute variance of points on manifold
262    pub fn variance(&self, points: &[Vec<f64>], mean: Option<&[f64]>) -> Result<f64> {
263        if points.is_empty() {
264            return Ok(0.0);
265        }
266
267        let mean = match mean {
268            Some(m) => m.to_vec(),
269            None => self.frechet_mean(points, None)?,
270        };
271
272        let mut total_sq_dist = 0.0;
273        for p in points {
274            let d = self.distance(&mean, p)?;
275            total_sq_dist += d * d;
276        }
277
278        Ok(total_sq_dist / points.len() as f64)
279    }
280
281    /// Project gradient to tangent space at point
282    ///
283    /// For product manifolds, this projects each component appropriately
284    pub fn project_gradient(&self, point: &[f64], gradient: &[f64]) -> Result<Vec<f64>> {
285        if point.len() != self.dim() || gradient.len() != self.dim() {
286            return Err(MathError::dimension_mismatch(self.dim(), point.len()));
287        }
288
289        let mut result = gradient.to_vec();
290        let (_e_range, h_range, s_range) = self.config().component_ranges();
291
292        // Euclidean: gradient is already in tangent space (no modification needed)
293
294        // Hyperbolic: scale by (1 - ||x||²)² / 4
295        if !h_range.is_empty() {
296            let x_h = &point[h_range.clone()];
297            let x_norm_sq: f64 = x_h.iter().map(|&xi| xi * xi).sum();
298            let c = -self.config().hyperbolic_curvature;
299            let lambda = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
300            let scale = 1.0 / (lambda * lambda);
301
302            for i in h_range.clone() {
303                result[i] *= scale;
304            }
305        }
306
307        // Spherical: project out normal component
308        if !s_range.is_empty() {
309            let x_s = &point[s_range.clone()];
310            let g_s = &gradient[s_range.clone()];
311
312            // Normal component: (g · x) x
313            let normal_component: f64 = g_s.iter().zip(x_s.iter()).map(|(&gi, &xi)| gi * xi).sum();
314
315            for (i, &xi) in s_range.clone().zip(x_s.iter()) {
316                result[i] -= normal_component * xi;
317            }
318        }
319
320        Ok(result)
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Distances from the origin to itself and to two points one unit away.
    #[test]
    fn test_pairwise_distances() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];

        let matrix = manifold.pairwise_distances(&points).unwrap();
        let from_origin = &matrix[0];

        assert!(close(from_origin[0], 0.0, 1e-10));
        assert!(close(from_origin[1], 1.0, 1e-10));
        assert!(close(from_origin[2], 1.0, 1e-10));
    }

    /// The query sits halfway between the first two points, so either may rank first.
    #[test]
    fn test_knn() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points: Vec<Vec<f64>> = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![2.0, 0.0],
            vec![3.0, 0.0],
        ];
        let query = vec![0.5, 0.0];

        let neighbors = manifold.knn(&query, &points, 2).unwrap();

        assert_eq!(neighbors.len(), 2);
        let (nearest_index, _) = neighbors[0];
        assert!(nearest_index == 0 || nearest_index == 1);
    }

    /// Five samples from (0, 0) to (2, 2): the middle sample should be (1, 1).
    #[test]
    fn test_geodesic_path() {
        let manifold = ProductManifold::new(2, 0, 0);
        let start = vec![0.0, 0.0];
        let end = vec![2.0, 2.0];

        let path = manifold.geodesic_path(&start, &end, 5).unwrap();

        assert_eq!(path.len(), 5);
        let midpoint = &path[2];
        assert!(close(midpoint[0], 1.0, 1e-6));
        assert!(close(midpoint[1], 1.0, 1e-6));
    }

    /// Four points at unit distance from the supplied mean give variance 1.
    #[test]
    fn test_variance() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points = vec![
            vec![1.0, 0.0],
            vec![-1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.0, -1.0],
        ];
        let origin = vec![0.0, 0.0];

        let variance = manifold
            .variance(&points, Some(origin.as_slice()))
            .unwrap();

        assert!(close(variance, 1.0, 1e-10));
    }
}
391}