ruvector_math/product_manifold/
operations.rs1use super::ProductManifold;
4use crate::error::{MathError, Result};
5use crate::utils::{norm, EPS};
6
7#[cfg(feature = "parallel")]
8use rayon::prelude::*;
9
10impl ProductManifold {
12 pub fn pairwise_distances(&self, points: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
15 let n = points.len();
16
17 #[cfg(feature = "parallel")]
18 {
19 self.pairwise_distances_parallel(points, n)
20 }
21
22 #[cfg(not(feature = "parallel"))]
23 {
24 self.pairwise_distances_sequential(points, n)
25 }
26 }
27
28 #[inline]
30 fn pairwise_distances_sequential(
31 &self,
32 points: &[Vec<f64>],
33 n: usize,
34 ) -> Result<Vec<Vec<f64>>> {
35 let mut distances = vec![vec![0.0; n]; n];
36
37 for i in 0..n {
38 for j in (i + 1)..n {
39 let d = self.distance(&points[i], &points[j])?;
40 distances[i][j] = d;
41 distances[j][i] = d;
42 }
43 }
44
45 Ok(distances)
46 }
47
48 #[cfg(feature = "parallel")]
50 fn pairwise_distances_parallel(&self, points: &[Vec<f64>], n: usize) -> Result<Vec<Vec<f64>>> {
51 let pairs: Vec<_> = (0..n)
53 .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
54 .collect();
55
56 let results: Vec<(usize, usize, f64)> = pairs
57 .par_iter()
58 .filter_map(|&(i, j)| {
59 self.distance(&points[i], &points[j])
60 .ok()
61 .map(|d| (i, j, d))
62 })
63 .collect();
64
65 let mut distances = vec![vec![0.0; n]; n];
66 for (i, j, d) in results {
67 distances[i][j] = d;
68 distances[j][i] = d;
69 }
70
71 Ok(distances)
72 }
73
74 pub fn knn(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
77 #[cfg(feature = "parallel")]
78 {
79 self.knn_parallel(query, points, k)
80 }
81
82 #[cfg(not(feature = "parallel"))]
83 {
84 self.knn_sequential(query, points, k)
85 }
86 }
87
88 #[inline]
90 fn knn_sequential(
91 &self,
92 query: &[f64],
93 points: &[Vec<f64>],
94 k: usize,
95 ) -> Result<Vec<(usize, f64)>> {
96 let mut distances: Vec<(usize, f64)> = points
97 .iter()
98 .enumerate()
99 .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
100 .collect();
101
102 distances
104 .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
105 distances.truncate(k);
106
107 Ok(distances)
108 }
109
110 #[cfg(feature = "parallel")]
112 fn knn_parallel(
113 &self,
114 query: &[f64],
115 points: &[Vec<f64>],
116 k: usize,
117 ) -> Result<Vec<(usize, f64)>> {
118 let mut distances: Vec<(usize, f64)> = points
119 .par_iter()
120 .enumerate()
121 .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
122 .collect();
123
124 distances
126 .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
127 distances.truncate(k);
128
129 Ok(distances)
130 }
131
132 pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>> {
136 let t = t.clamp(0.0, 1.0);
137
138 let v = self.log_map(x, y)?;
140
141 let tv: Vec<f64> = v.iter().map(|&vi| t * vi).collect();
143
144 self.exp_map(x, &tv)
146 }
147
148 pub fn geodesic_path(&self, x: &[f64], y: &[f64], num_points: usize) -> Result<Vec<Vec<f64>>> {
150 let mut path = Vec::with_capacity(num_points);
151
152 for i in 0..num_points {
153 let t = i as f64 / (num_points - 1).max(1) as f64;
154 path.push(self.geodesic(x, y, t)?);
155 }
156
157 Ok(path)
158 }
159
160 pub fn parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
162 if x.len() != self.dim() || y.len() != self.dim() || v.len() != self.dim() {
163 return Err(MathError::dimension_mismatch(self.dim(), x.len()));
164 }
165
166 let mut result = vec![0.0; self.dim()];
167 let (e_range, h_range, s_range) = self.config().component_ranges();
168
169 for i in e_range.clone() {
171 result[i] = v[i];
172 }
173
174 if !h_range.is_empty() {
176 let x_h = &x[h_range.clone()];
177 let y_h = &y[h_range.clone()];
178 let v_h = &v[h_range.clone()];
179 let pt_h = self.poincare_parallel_transport(x_h, y_h, v_h)?;
180 for (i, val) in h_range.clone().zip(pt_h.iter()) {
181 result[i] = *val;
182 }
183 }
184
185 if !s_range.is_empty() {
187 let x_s = &x[s_range.clone()];
188 let y_s = &y[s_range.clone()];
189 let v_s = &v[s_range.clone()];
190 let pt_s = self.spherical_parallel_transport(x_s, y_s, v_s)?;
191 for (i, val) in s_range.clone().zip(pt_s.iter()) {
192 result[i] = *val;
193 }
194 }
195
196 Ok(result)
197 }
198
199 fn poincare_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
201 let c = -self.config().hyperbolic_curvature;
202
203 let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
204 let y_norm_sq: f64 = y.iter().map(|&yi| yi * yi).sum();
205
206 let lambda_x = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
207 let lambda_y = 2.0 / (1.0 - c * y_norm_sq).max(EPS);
208
209 let scale = lambda_x / lambda_y;
210
211 let xy_dot: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
213 let _gyration_factor = 1.0 + c * xy_dot;
214
215 Ok(v.iter().map(|&vi| scale * vi).collect())
217 }
218
219 fn spherical_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
221 use crate::utils::dot;
222
223 let cos_theta = dot(x, y).clamp(-1.0, 1.0);
224
225 if (cos_theta - 1.0).abs() < EPS {
226 return Ok(v.to_vec());
227 }
228
229 let theta = cos_theta.acos();
230
231 let u: Vec<f64> = x
233 .iter()
234 .zip(y.iter())
235 .map(|(&xi, &yi)| yi - cos_theta * xi)
236 .collect();
237 let u_norm = norm(&u);
238
239 if u_norm < EPS {
240 return Ok(v.to_vec());
241 }
242
243 let u: Vec<f64> = u.iter().map(|&ui| ui / u_norm).collect();
244
245 let v_u = dot(v, &u);
247 let v_x = dot(v, x);
248
249 let result: Vec<f64> = (0..x.len())
251 .map(|i| {
252 let v_perp = v[i] - v_u * u[i] - v_x * x[i];
253 v_perp + v_u * (-theta.sin() * x[i] + theta.cos() * u[i])
254 - v_x * (theta.cos() * x[i] + theta.sin() * u[i])
255 })
256 .collect();
257
258 Ok(result)
259 }
260
261 pub fn variance(&self, points: &[Vec<f64>], mean: Option<&[f64]>) -> Result<f64> {
263 if points.is_empty() {
264 return Ok(0.0);
265 }
266
267 let mean = match mean {
268 Some(m) => m.to_vec(),
269 None => self.frechet_mean(points, None)?,
270 };
271
272 let mut total_sq_dist = 0.0;
273 for p in points {
274 let d = self.distance(&mean, p)?;
275 total_sq_dist += d * d;
276 }
277
278 Ok(total_sq_dist / points.len() as f64)
279 }
280
281 pub fn project_gradient(&self, point: &[f64], gradient: &[f64]) -> Result<Vec<f64>> {
285 if point.len() != self.dim() || gradient.len() != self.dim() {
286 return Err(MathError::dimension_mismatch(self.dim(), point.len()));
287 }
288
289 let mut result = gradient.to_vec();
290 let (_e_range, h_range, s_range) = self.config().component_ranges();
291
292 if !h_range.is_empty() {
296 let x_h = &point[h_range.clone()];
297 let x_norm_sq: f64 = x_h.iter().map(|&xi| xi * xi).sum();
298 let c = -self.config().hyperbolic_curvature;
299 let lambda = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
300 let scale = 1.0 / (lambda * lambda);
301
302 for i in h_range.clone() {
303 result[i] *= scale;
304 }
305 }
306
307 if !s_range.is_empty() {
309 let x_s = &point[s_range.clone()];
310 let g_s = &gradient[s_range.clone()];
311
312 let normal_component: f64 = g_s.iter().zip(x_s.iter()).map(|(&gi, &xi)| gi * xi).sum();
314
315 for (i, &xi) in s_range.clone().zip(x_s.iter()) {
316 result[i] -= normal_component * xi;
317 }
318 }
319
320 Ok(result)
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Approximate equality helper for coordinate vectors.
    fn assert_close(a: &[f64], b: &[f64], tol: f64) {
        assert_eq!(a.len(), b.len());
        for (x, y) in a.iter().zip(b) {
            assert!((x - y).abs() < tol, "{a:?} != {b:?}");
        }
    }

    #[test]
    fn test_pairwise_distances() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];

        let dists = manifold.pairwise_distances(&points).unwrap();

        assert!(dists[0][0].abs() < 1e-10);
        assert!((dists[0][1] - 1.0).abs() < 1e-10);
        assert!((dists[0][2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pairwise_distances_symmetric_and_complete() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];

        let dists = manifold.pairwise_distances(&points).unwrap();

        assert_eq!(dists.len(), 3);
        assert!((dists[1][2] - 2.0_f64.sqrt()).abs() < 1e-10);
        for (i, row) in dists.iter().enumerate() {
            assert_eq!(row.len(), 3);
            for (j, &d) in row.iter().enumerate() {
                assert_eq!(d, dists[j][i]);
            }
        }
    }

    #[test]
    fn test_pairwise_distances_empty() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points: Vec<Vec<f64>> = Vec::new();

        assert!(manifold.pairwise_distances(&points).unwrap().is_empty());
    }

    #[test]
    fn test_knn() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![2.0, 0.0],
            vec![3.0, 0.0],
        ];

        let query = vec![0.5, 0.0];
        let neighbors = manifold.knn(&query, &points, 2).unwrap();

        assert_eq!(neighbors.len(), 2);
        assert!(neighbors[0].0 == 0 || neighbors[0].0 == 1);
    }

    #[test]
    fn test_knn_strict_ordering() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points = vec![
            vec![3.0, 0.0],
            vec![0.0, 0.0],
            vec![2.0, 0.0],
            vec![1.0, 0.0],
        ];

        // 0.4 is strictly closer to 0.0 than to 1.0, so the order is unambiguous.
        let neighbors = manifold.knn(&[0.4, 0.0], &points, 2).unwrap();

        assert_eq!(neighbors.len(), 2);
        assert_eq!(neighbors[0].0, 1);
        assert_eq!(neighbors[1].0, 3);
        assert!((neighbors[0].1 - 0.4).abs() < 1e-10);
        assert!((neighbors[1].1 - 0.6).abs() < 1e-10);
    }

    #[test]
    fn test_knn_k_exceeds_len() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points = vec![vec![2.0, 0.0], vec![0.0, 0.0], vec![1.0, 0.0]];

        let neighbors = manifold.knn(&[0.0, 0.0], &points, 10).unwrap();

        let order: Vec<usize> = neighbors.iter().map(|&(i, _)| i).collect();
        assert_eq!(order, vec![1, 2, 0]);
        assert!(neighbors.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn test_geodesic_path() {
        let manifold = ProductManifold::new(2, 0, 0);

        let x = vec![0.0, 0.0];
        let y = vec![2.0, 2.0];

        let path = manifold.geodesic_path(&x, &y, 5).unwrap();

        assert_eq!(path.len(), 5);

        assert!((path[2][0] - 1.0).abs() < 1e-6);
        assert!((path[2][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_geodesic_path_endpoints_and_degenerate_counts() {
        let manifold = ProductManifold::new(2, 0, 0);
        let x = vec![0.0, 0.0];
        let y = vec![2.0, 2.0];

        let path = manifold.geodesic_path(&x, &y, 5).unwrap();
        assert_close(&path[0], &x, 1e-6);
        assert_close(&path[4], &y, 1e-6);

        assert!(manifold.geodesic_path(&x, &y, 0).unwrap().is_empty());

        let single = manifold.geodesic_path(&x, &y, 1).unwrap();
        assert_eq!(single.len(), 1);
        assert_close(&single[0], &x, 1e-6);
    }

    #[test]
    fn test_parallel_transport_euclidean_is_identity() {
        let manifold = ProductManifold::new(2, 0, 0);

        let v = manifold
            .parallel_transport(&[0.0, 0.0], &[1.0, 2.0], &[3.0, -4.0])
            .unwrap();
        assert_eq!(v, vec![3.0, -4.0]);

        assert!(manifold
            .parallel_transport(&[0.0], &[1.0, 2.0], &[3.0, -4.0])
            .is_err());
        assert!(manifold
            .parallel_transport(&[0.0, 0.0], &[1.0, 2.0], &[3.0])
            .is_err());
    }

    #[test]
    fn test_project_gradient_euclidean_is_identity() {
        let manifold = ProductManifold::new(2, 0, 0);

        let g = manifold.project_gradient(&[1.0, 2.0], &[0.5, -0.5]).unwrap();
        assert_eq!(g, vec![0.5, -0.5]);

        assert!(manifold.project_gradient(&[1.0, 2.0], &[0.5]).is_err());
    }

    #[test]
    fn test_variance() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![
            vec![1.0, 0.0],
            vec![-1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.0, -1.0],
        ];

        let variance = manifold.variance(&points, Some(&vec![0.0, 0.0])).unwrap();
        assert!((variance - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_variance_empty_is_zero() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points: Vec<Vec<f64>> = Vec::new();

        let variance = manifold.variance(&points, Some(&[0.0, 0.0])).unwrap();
        assert_eq!(variance, 0.0);
    }
}