ruvector_math/product_manifold/
operations.rs

1//! Additional product manifold operations
2
3use crate::error::{MathError, Result};
4use crate::utils::{norm, EPS};
5use super::ProductManifold;
6
7#[cfg(feature = "parallel")]
8use rayon::prelude::*;
9
10/// Batch operations on product manifolds
11impl ProductManifold {
12    /// Compute pairwise distances between all points
13    /// Uses parallel computation when 'parallel' feature is enabled
14    pub fn pairwise_distances(&self, points: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
15        let n = points.len();
16
17        #[cfg(feature = "parallel")]
18        {
19            self.pairwise_distances_parallel(points, n)
20        }
21
22        #[cfg(not(feature = "parallel"))]
23        {
24            self.pairwise_distances_sequential(points, n)
25        }
26    }
27
28    /// Sequential pairwise distance computation
29    #[inline]
30    fn pairwise_distances_sequential(&self, points: &[Vec<f64>], n: usize) -> Result<Vec<Vec<f64>>> {
31        let mut distances = vec![vec![0.0; n]; n];
32
33        for i in 0..n {
34            for j in (i + 1)..n {
35                let d = self.distance(&points[i], &points[j])?;
36                distances[i][j] = d;
37                distances[j][i] = d;
38            }
39        }
40
41        Ok(distances)
42    }
43
44    /// Parallel pairwise distance computation using rayon
45    #[cfg(feature = "parallel")]
46    fn pairwise_distances_parallel(&self, points: &[Vec<f64>], n: usize) -> Result<Vec<Vec<f64>>> {
47        // Compute upper triangle in parallel
48        let pairs: Vec<_> = (0..n)
49            .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
50            .collect();
51
52        let results: Vec<(usize, usize, f64)> = pairs
53            .par_iter()
54            .filter_map(|&(i, j)| {
55                self.distance(&points[i], &points[j])
56                    .ok()
57                    .map(|d| (i, j, d))
58            })
59            .collect();
60
61        let mut distances = vec![vec![0.0; n]; n];
62        for (i, j, d) in results {
63            distances[i][j] = d;
64            distances[j][i] = d;
65        }
66
67        Ok(distances)
68    }
69
70    /// Find k-nearest neighbors
71    /// Uses parallel computation when 'parallel' feature is enabled
72    pub fn knn(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
73        #[cfg(feature = "parallel")]
74        {
75            self.knn_parallel(query, points, k)
76        }
77
78        #[cfg(not(feature = "parallel"))]
79        {
80            self.knn_sequential(query, points, k)
81        }
82    }
83
84    /// Sequential k-nearest neighbors
85    #[inline]
86    fn knn_sequential(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
87        let mut distances: Vec<(usize, f64)> = points
88            .iter()
89            .enumerate()
90            .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
91            .collect();
92
93        // Use sort_unstable_by for better performance
94        distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
95        distances.truncate(k);
96
97        Ok(distances)
98    }
99
100    /// Parallel k-nearest neighbors using rayon
101    #[cfg(feature = "parallel")]
102    fn knn_parallel(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
103        let mut distances: Vec<(usize, f64)> = points
104            .par_iter()
105            .enumerate()
106            .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
107            .collect();
108
109        // Use sort_unstable_by for better performance
110        distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
111        distances.truncate(k);
112
113        Ok(distances)
114    }
115
116    /// Geodesic interpolation between two points
117    ///
118    /// Returns point at fraction t along geodesic from x to y
119    pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>> {
120        let t = t.clamp(0.0, 1.0);
121
122        // log_x(y) gives direction
123        let v = self.log_map(x, y)?;
124
125        // Scale by t
126        let tv: Vec<f64> = v.iter().map(|&vi| t * vi).collect();
127
128        // exp_x(t * v)
129        self.exp_map(x, &tv)
130    }
131
132    /// Sample points along geodesic
133    pub fn geodesic_path(
134        &self,
135        x: &[f64],
136        y: &[f64],
137        num_points: usize,
138    ) -> Result<Vec<Vec<f64>>> {
139        let mut path = Vec::with_capacity(num_points);
140
141        for i in 0..num_points {
142            let t = i as f64 / (num_points - 1).max(1) as f64;
143            path.push(self.geodesic(x, y, t)?);
144        }
145
146        Ok(path)
147    }
148
149    /// Parallel transport vector v from x to y
150    pub fn parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
151        if x.len() != self.dim() || y.len() != self.dim() || v.len() != self.dim() {
152            return Err(MathError::dimension_mismatch(self.dim(), x.len()));
153        }
154
155        let mut result = vec![0.0; self.dim()];
156        let (e_range, h_range, s_range) = self.config().component_ranges();
157
158        // Euclidean: parallel transport is identity
159        for i in e_range.clone() {
160            result[i] = v[i];
161        }
162
163        // Hyperbolic parallel transport
164        if !h_range.is_empty() {
165            let x_h = &x[h_range.clone()];
166            let y_h = &y[h_range.clone()];
167            let v_h = &v[h_range.clone()];
168            let pt_h = self.poincare_parallel_transport(x_h, y_h, v_h)?;
169            for (i, val) in h_range.clone().zip(pt_h.iter()) {
170                result[i] = *val;
171            }
172        }
173
174        // Spherical parallel transport
175        if !s_range.is_empty() {
176            let x_s = &x[s_range.clone()];
177            let y_s = &y[s_range.clone()];
178            let v_s = &v[s_range.clone()];
179            let pt_s = self.spherical_parallel_transport(x_s, y_s, v_s)?;
180            for (i, val) in s_range.clone().zip(pt_s.iter()) {
181                result[i] = *val;
182            }
183        }
184
185        Ok(result)
186    }
187
188    /// Poincaré ball parallel transport
189    fn poincare_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
190        let c = -self.config().hyperbolic_curvature;
191
192        let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
193        let y_norm_sq: f64 = y.iter().map(|&yi| yi * yi).sum();
194
195        let lambda_x = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
196        let lambda_y = 2.0 / (1.0 - c * y_norm_sq).max(EPS);
197
198        let scale = lambda_x / lambda_y;
199
200        // Gyration correction
201        let xy_dot: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
202        let _gyration_factor = 1.0 + c * xy_dot;
203
204        // Simplified parallel transport (good approximation for small distances)
205        Ok(v.iter().map(|&vi| scale * vi).collect())
206    }
207
208    /// Spherical parallel transport
209    fn spherical_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
210        use crate::utils::dot;
211
212        let cos_theta = dot(x, y).clamp(-1.0, 1.0);
213
214        if (cos_theta - 1.0).abs() < EPS {
215            return Ok(v.to_vec());
216        }
217
218        let theta = cos_theta.acos();
219
220        // Direction from x to y
221        let u: Vec<f64> = x
222            .iter()
223            .zip(y.iter())
224            .map(|(&xi, &yi)| yi - cos_theta * xi)
225            .collect();
226        let u_norm = norm(&u);
227
228        if u_norm < EPS {
229            return Ok(v.to_vec());
230        }
231
232        let u: Vec<f64> = u.iter().map(|&ui| ui / u_norm).collect();
233
234        // Components of v
235        let v_u = dot(v, &u);
236        let v_x = dot(v, x);
237
238        // Parallel transport formula
239        let result: Vec<f64> = (0..x.len())
240            .map(|i| {
241                let v_perp = v[i] - v_u * u[i] - v_x * x[i];
242                v_perp
243                    + v_u * (-theta.sin() * x[i] + theta.cos() * u[i])
244                    - v_x * (theta.cos() * x[i] + theta.sin() * u[i])
245            })
246            .collect();
247
248        Ok(result)
249    }
250
251    /// Compute variance of points on manifold
252    pub fn variance(&self, points: &[Vec<f64>], mean: Option<&[f64]>) -> Result<f64> {
253        if points.is_empty() {
254            return Ok(0.0);
255        }
256
257        let mean = match mean {
258            Some(m) => m.to_vec(),
259            None => self.frechet_mean(points, None)?,
260        };
261
262        let mut total_sq_dist = 0.0;
263        for p in points {
264            let d = self.distance(&mean, p)?;
265            total_sq_dist += d * d;
266        }
267
268        Ok(total_sq_dist / points.len() as f64)
269    }
270
271    /// Project gradient to tangent space at point
272    ///
273    /// For product manifolds, this projects each component appropriately
274    pub fn project_gradient(&self, point: &[f64], gradient: &[f64]) -> Result<Vec<f64>> {
275        if point.len() != self.dim() || gradient.len() != self.dim() {
276            return Err(MathError::dimension_mismatch(self.dim(), point.len()));
277        }
278
279        let mut result = gradient.to_vec();
280        let (_e_range, h_range, s_range) = self.config().component_ranges();
281
282        // Euclidean: gradient is already in tangent space (no modification needed)
283
284        // Hyperbolic: scale by (1 - ||x||²)² / 4
285        if !h_range.is_empty() {
286            let x_h = &point[h_range.clone()];
287            let x_norm_sq: f64 = x_h.iter().map(|&xi| xi * xi).sum();
288            let c = -self.config().hyperbolic_curvature;
289            let lambda = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
290            let scale = 1.0 / (lambda * lambda);
291
292            for i in h_range.clone() {
293                result[i] *= scale;
294            }
295        }
296
297        // Spherical: project out normal component
298        if !s_range.is_empty() {
299            let x_s = &point[s_range.clone()];
300            let g_s = &gradient[s_range.clone()];
301
302            // Normal component: (g · x) x
303            let normal_component: f64 = g_s.iter().zip(x_s.iter()).map(|(&gi, &xi)| gi * xi).sum();
304
305            for (i, &xi) in s_range.clone().zip(x_s.iter()) {
306                result[i] -= normal_component * xi;
307            }
308        }
309
310        Ok(result)
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pairwise_distances() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];

        let dists = manifold.pairwise_distances(&points).unwrap();

        assert!(dists[0][0].abs() < 1e-10);
        assert!((dists[0][1] - 1.0).abs() < 1e-10);
        assert!((dists[0][2] - 1.0).abs() < 1e-10);
        assert!((dists[1][2] - 2.0_f64.sqrt()).abs() < 1e-10);

        // Square, symmetric, and exactly zero on the diagonal.
        assert_eq!(dists.len(), points.len());
        for (i, row) in dists.iter().enumerate() {
            assert_eq!(row.len(), points.len());
            assert_eq!(row[i], 0.0);
            for (j, &d) in row.iter().enumerate() {
                assert_eq!(d, dists[j][i]);
            }
        }
    }

    #[test]
    fn test_pairwise_distances_empty() {
        let manifold = ProductManifold::new(2, 0, 0);

        let dists = manifold.pairwise_distances(&[]).unwrap();

        assert!(dists.is_empty());
    }

    #[test]
    fn test_knn() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![2.0, 0.0],
            vec![3.0, 0.0],
        ];

        let query = vec![0.5, 0.0];
        let neighbors = manifold.knn(&query, &points, 2).unwrap();

        assert_eq!(neighbors.len(), 2);
        // Closest should be [0,0] or [1,0]
        assert!(neighbors[0].0 == 0 || neighbors[0].0 == 1);
    }

    #[test]
    fn test_knn_sorted_and_k_larger_than_len() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![
            vec![3.0, 0.0],
            vec![0.0, 0.0],
            vec![2.0, 0.0],
            vec![1.0, 0.0],
        ];

        // Distances from the query: 2.8, 0.2, 1.8, 0.8 (no ties).
        let query = vec![0.2, 0.0];
        let neighbors = manifold.knn(&query, &points, 10).unwrap();

        let order: Vec<usize> = neighbors.iter().map(|&(i, _)| i).collect();
        assert_eq!(order, vec![1, 3, 2, 0]);
        assert!(neighbors.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn test_knn_zero_k() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![vec![0.0, 0.0], vec![1.0, 0.0]];

        let neighbors = manifold.knn(&[0.0, 0.0], &points, 0).unwrap();

        assert!(neighbors.is_empty());
    }

    #[test]
    fn test_geodesic_path() {
        let manifold = ProductManifold::new(2, 0, 0);

        let x = vec![0.0, 0.0];
        let y = vec![2.0, 2.0];

        let path = manifold.geodesic_path(&x, &y, 5).unwrap();

        assert_eq!(path.len(), 5);

        // Endpoints are x and y.
        assert!((path[0][0] - 0.0).abs() < 1e-6);
        assert!((path[0][1] - 0.0).abs() < 1e-6);
        assert!((path[4][0] - 2.0).abs() < 1e-6);
        assert!((path[4][1] - 2.0).abs() < 1e-6);

        // Midpoint should be (1, 1)
        assert!((path[2][0] - 1.0).abs() < 1e-6);
        assert!((path[2][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_geodesic_path_degenerate_counts() {
        let manifold = ProductManifold::new(2, 0, 0);

        let x = vec![1.0, -1.0];
        let y = vec![3.0, 5.0];

        assert!(manifold.geodesic_path(&x, &y, 0).unwrap().is_empty());

        // A single sample is the start point.
        let single = manifold.geodesic_path(&x, &y, 1).unwrap();
        assert_eq!(single.len(), 1);
        assert!((single[0][0] - 1.0).abs() < 1e-10);
        assert!((single[0][1] + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_transport_euclidean_is_identity() {
        let manifold = ProductManifold::new(2, 0, 0);

        let v = vec![0.3, -0.7];
        let out = manifold
            .parallel_transport(&[0.0, 0.0], &[1.0, 2.0], &v)
            .unwrap();

        assert_eq!(out, v);
    }

    #[test]
    fn test_parallel_transport_dimension_mismatch() {
        let manifold = ProductManifold::new(2, 0, 0);

        assert!(manifold
            .parallel_transport(&[0.0, 0.0], &[1.0, 2.0, 3.0], &[0.1, 0.2])
            .is_err());
        assert!(manifold
            .parallel_transport(&[0.0, 0.0], &[1.0, 2.0], &[0.1])
            .is_err());
    }

    #[test]
    fn test_project_gradient_euclidean_is_identity() {
        let manifold = ProductManifold::new(2, 0, 0);

        let g = vec![1.5, -2.5];

        assert_eq!(manifold.project_gradient(&[0.4, 0.1], &g).unwrap(), g);
        assert!(manifold.project_gradient(&[0.4, 0.1], &[1.0]).is_err());
    }

    #[test]
    fn test_variance() {
        let manifold = ProductManifold::new(2, 0, 0);

        // Points at unit distance from origin
        let points = vec![
            vec![1.0, 0.0],
            vec![-1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.0, -1.0],
        ];

        let variance = manifold
            .variance(&points, Some(&[0.0, 0.0][..]))
            .unwrap();
        assert!((variance - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_variance_empty() {
        let manifold = ProductManifold::new(2, 0, 0);

        assert_eq!(manifold.variance(&[], Some(&[0.0, 0.0][..])).unwrap(), 0.0);
    }
}