ruvector_math/product_manifold/
operations.rs1use crate::error::{MathError, Result};
4use crate::utils::{norm, EPS};
5use super::ProductManifold;
6
7#[cfg(feature = "parallel")]
8use rayon::prelude::*;
9
10impl ProductManifold {
12 pub fn pairwise_distances(&self, points: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
15 let n = points.len();
16
17 #[cfg(feature = "parallel")]
18 {
19 self.pairwise_distances_parallel(points, n)
20 }
21
22 #[cfg(not(feature = "parallel"))]
23 {
24 self.pairwise_distances_sequential(points, n)
25 }
26 }
27
28 #[inline]
30 fn pairwise_distances_sequential(&self, points: &[Vec<f64>], n: usize) -> Result<Vec<Vec<f64>>> {
31 let mut distances = vec![vec![0.0; n]; n];
32
33 for i in 0..n {
34 for j in (i + 1)..n {
35 let d = self.distance(&points[i], &points[j])?;
36 distances[i][j] = d;
37 distances[j][i] = d;
38 }
39 }
40
41 Ok(distances)
42 }
43
44 #[cfg(feature = "parallel")]
46 fn pairwise_distances_parallel(&self, points: &[Vec<f64>], n: usize) -> Result<Vec<Vec<f64>>> {
47 let pairs: Vec<_> = (0..n)
49 .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
50 .collect();
51
52 let results: Vec<(usize, usize, f64)> = pairs
53 .par_iter()
54 .filter_map(|&(i, j)| {
55 self.distance(&points[i], &points[j])
56 .ok()
57 .map(|d| (i, j, d))
58 })
59 .collect();
60
61 let mut distances = vec![vec![0.0; n]; n];
62 for (i, j, d) in results {
63 distances[i][j] = d;
64 distances[j][i] = d;
65 }
66
67 Ok(distances)
68 }
69
70 pub fn knn(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
73 #[cfg(feature = "parallel")]
74 {
75 self.knn_parallel(query, points, k)
76 }
77
78 #[cfg(not(feature = "parallel"))]
79 {
80 self.knn_sequential(query, points, k)
81 }
82 }
83
84 #[inline]
86 fn knn_sequential(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
87 let mut distances: Vec<(usize, f64)> = points
88 .iter()
89 .enumerate()
90 .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
91 .collect();
92
93 distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
95 distances.truncate(k);
96
97 Ok(distances)
98 }
99
100 #[cfg(feature = "parallel")]
102 fn knn_parallel(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
103 let mut distances: Vec<(usize, f64)> = points
104 .par_iter()
105 .enumerate()
106 .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
107 .collect();
108
109 distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
111 distances.truncate(k);
112
113 Ok(distances)
114 }
115
116 pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>> {
120 let t = t.clamp(0.0, 1.0);
121
122 let v = self.log_map(x, y)?;
124
125 let tv: Vec<f64> = v.iter().map(|&vi| t * vi).collect();
127
128 self.exp_map(x, &tv)
130 }
131
132 pub fn geodesic_path(
134 &self,
135 x: &[f64],
136 y: &[f64],
137 num_points: usize,
138 ) -> Result<Vec<Vec<f64>>> {
139 let mut path = Vec::with_capacity(num_points);
140
141 for i in 0..num_points {
142 let t = i as f64 / (num_points - 1).max(1) as f64;
143 path.push(self.geodesic(x, y, t)?);
144 }
145
146 Ok(path)
147 }
148
149 pub fn parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
151 if x.len() != self.dim() || y.len() != self.dim() || v.len() != self.dim() {
152 return Err(MathError::dimension_mismatch(self.dim(), x.len()));
153 }
154
155 let mut result = vec![0.0; self.dim()];
156 let (e_range, h_range, s_range) = self.config().component_ranges();
157
158 for i in e_range.clone() {
160 result[i] = v[i];
161 }
162
163 if !h_range.is_empty() {
165 let x_h = &x[h_range.clone()];
166 let y_h = &y[h_range.clone()];
167 let v_h = &v[h_range.clone()];
168 let pt_h = self.poincare_parallel_transport(x_h, y_h, v_h)?;
169 for (i, val) in h_range.clone().zip(pt_h.iter()) {
170 result[i] = *val;
171 }
172 }
173
174 if !s_range.is_empty() {
176 let x_s = &x[s_range.clone()];
177 let y_s = &y[s_range.clone()];
178 let v_s = &v[s_range.clone()];
179 let pt_s = self.spherical_parallel_transport(x_s, y_s, v_s)?;
180 for (i, val) in s_range.clone().zip(pt_s.iter()) {
181 result[i] = *val;
182 }
183 }
184
185 Ok(result)
186 }
187
188 fn poincare_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
190 let c = -self.config().hyperbolic_curvature;
191
192 let x_norm_sq: f64 = x.iter().map(|&xi| xi * xi).sum();
193 let y_norm_sq: f64 = y.iter().map(|&yi| yi * yi).sum();
194
195 let lambda_x = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
196 let lambda_y = 2.0 / (1.0 - c * y_norm_sq).max(EPS);
197
198 let scale = lambda_x / lambda_y;
199
200 let xy_dot: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
202 let _gyration_factor = 1.0 + c * xy_dot;
203
204 Ok(v.iter().map(|&vi| scale * vi).collect())
206 }
207
208 fn spherical_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
210 use crate::utils::dot;
211
212 let cos_theta = dot(x, y).clamp(-1.0, 1.0);
213
214 if (cos_theta - 1.0).abs() < EPS {
215 return Ok(v.to_vec());
216 }
217
218 let theta = cos_theta.acos();
219
220 let u: Vec<f64> = x
222 .iter()
223 .zip(y.iter())
224 .map(|(&xi, &yi)| yi - cos_theta * xi)
225 .collect();
226 let u_norm = norm(&u);
227
228 if u_norm < EPS {
229 return Ok(v.to_vec());
230 }
231
232 let u: Vec<f64> = u.iter().map(|&ui| ui / u_norm).collect();
233
234 let v_u = dot(v, &u);
236 let v_x = dot(v, x);
237
238 let result: Vec<f64> = (0..x.len())
240 .map(|i| {
241 let v_perp = v[i] - v_u * u[i] - v_x * x[i];
242 v_perp
243 + v_u * (-theta.sin() * x[i] + theta.cos() * u[i])
244 - v_x * (theta.cos() * x[i] + theta.sin() * u[i])
245 })
246 .collect();
247
248 Ok(result)
249 }
250
251 pub fn variance(&self, points: &[Vec<f64>], mean: Option<&[f64]>) -> Result<f64> {
253 if points.is_empty() {
254 return Ok(0.0);
255 }
256
257 let mean = match mean {
258 Some(m) => m.to_vec(),
259 None => self.frechet_mean(points, None)?,
260 };
261
262 let mut total_sq_dist = 0.0;
263 for p in points {
264 let d = self.distance(&mean, p)?;
265 total_sq_dist += d * d;
266 }
267
268 Ok(total_sq_dist / points.len() as f64)
269 }
270
271 pub fn project_gradient(&self, point: &[f64], gradient: &[f64]) -> Result<Vec<f64>> {
275 if point.len() != self.dim() || gradient.len() != self.dim() {
276 return Err(MathError::dimension_mismatch(self.dim(), point.len()));
277 }
278
279 let mut result = gradient.to_vec();
280 let (_e_range, h_range, s_range) = self.config().component_ranges();
281
282 if !h_range.is_empty() {
286 let x_h = &point[h_range.clone()];
287 let x_norm_sq: f64 = x_h.iter().map(|&xi| xi * xi).sum();
288 let c = -self.config().hyperbolic_curvature;
289 let lambda = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
290 let scale = 1.0 / (lambda * lambda);
291
292 for i in h_range.clone() {
293 result[i] *= scale;
294 }
295 }
296
297 if !s_range.is_empty() {
299 let x_s = &point[s_range.clone()];
300 let g_s = &gradient[s_range.clone()];
301
302 let normal_component: f64 = g_s.iter().zip(x_s.iter()).map(|(&gi, &xi)| gi * xi).sum();
304
305 for (i, &xi) in s_range.clone().zip(x_s.iter()) {
306 result[i] -= normal_component * xi;
307 }
308 }
309
310 Ok(result)
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Pairwise distances on a flat 2-D manifold: zero diagonal, symmetric, Euclidean.
    #[test]
    fn test_pairwise_distances() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];

        let dists = manifold.pairwise_distances(&points).unwrap();

        assert_eq!(dists.len(), 3);
        for (i, row) in dists.iter().enumerate() {
            assert_eq!(row.len(), 3);
            assert!(row[i].abs() < 1e-10);
            for (j, &d) in row.iter().enumerate() {
                assert!((d - dists[j][i]).abs() < 1e-12, "matrix must be symmetric");
            }
        }
        assert!((dists[0][1] - 1.0).abs() < 1e-10);
        assert!((dists[0][2] - 1.0).abs() < 1e-10);
        assert!((dists[1][2] - 2f64.sqrt()).abs() < 1e-10);
    }

    /// An empty point set yields an empty matrix.
    #[test]
    fn test_pairwise_distances_empty() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points: Vec<Vec<f64>> = Vec::new();

        assert!(manifold.pairwise_distances(&points).unwrap().is_empty());
    }

    #[test]
    fn test_knn() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![2.0, 0.0],
            vec![3.0, 0.0],
        ];

        let query = vec![0.5, 0.0];
        let neighbors = manifold.knn(&query, &points, 2).unwrap();

        assert_eq!(neighbors.len(), 2);
        // Points 0 and 1 tie at distance 0.5, so only membership can be asserted here.
        assert!(neighbors[0].0 == 0 || neighbors[0].0 == 1);
    }

    /// Neighbours come back nearest-first regardless of input order.
    #[test]
    fn test_knn_orders_by_distance() {
        let manifold = ProductManifold::new(2, 0, 0);

        let points = vec![
            vec![3.0, 0.0],
            vec![0.0, 0.0],
            vec![2.0, 0.0],
            vec![1.0, 0.0],
        ];
        let query = vec![0.1, 0.0];

        let neighbors = manifold.knn(&query, &points, 2).unwrap();
        let indices: Vec<usize> = neighbors.iter().map(|&(i, _)| i).collect();
        assert_eq!(indices, vec![1, 3]);
        assert!((neighbors[0].1 - 0.1).abs() < 1e-10);
        assert!((neighbors[1].1 - 0.9).abs() < 1e-10);

        // Asking for more neighbours than points returns all of them, sorted.
        let all = manifold.knn(&query, &points, 10).unwrap();
        assert_eq!(all.len(), 4);
        assert!(all.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn test_geodesic_path() {
        let manifold = ProductManifold::new(2, 0, 0);

        let x = vec![0.0, 0.0];
        let y = vec![2.0, 2.0];

        let path = manifold.geodesic_path(&x, &y, 5).unwrap();

        assert_eq!(path.len(), 5);

        // Endpoints are the inputs; the midpoint lies halfway.
        assert!((path[0][0] - 0.0).abs() < 1e-6 && (path[0][1] - 0.0).abs() < 1e-6);
        assert!((path[2][0] - 1.0).abs() < 1e-6);
        assert!((path[2][1] - 1.0).abs() < 1e-6);
        assert!((path[4][0] - 2.0).abs() < 1e-6 && (path[4][1] - 2.0).abs() < 1e-6);
    }

    /// Degenerate sample counts: zero points is empty, one point is the start.
    #[test]
    fn test_geodesic_path_edge_counts() {
        let manifold = ProductManifold::new(2, 0, 0);
        let x = vec![1.0, -1.0];
        let y = vec![3.0, 5.0];

        assert!(manifold.geodesic_path(&x, &y, 0).unwrap().is_empty());

        let single = manifold.geodesic_path(&x, &y, 1).unwrap();
        assert_eq!(single.len(), 1);
        assert!((single[0][0] - 1.0).abs() < 1e-6);
        assert!((single[0][1] + 1.0).abs() < 1e-6);
    }

    /// On a flat manifold parallel transport is the identity.
    #[test]
    fn test_parallel_transport_euclidean_identity() {
        let manifold = ProductManifold::new(2, 0, 0);

        let x = vec![0.0, 0.0];
        let y = vec![4.0, -3.0];
        let v = vec![0.25, 1.5];

        let transported = manifold.parallel_transport(&x, &y, &v).unwrap();
        assert!((transported[0] - 0.25).abs() < 1e-12);
        assert!((transported[1] - 1.5).abs() < 1e-12);
    }

    /// A vector of the wrong length is rejected.
    #[test]
    fn test_parallel_transport_dimension_mismatch() {
        let manifold = ProductManifold::new(2, 0, 0);

        let x = vec![0.0, 0.0];
        let y = vec![1.0, 0.0];
        let v = vec![1.0, 0.0, 0.0];

        assert!(manifold.parallel_transport(&x, &y, &v).is_err());
    }

    /// On a flat manifold the Riemannian gradient equals the Euclidean gradient.
    #[test]
    fn test_project_gradient_euclidean_identity() {
        let manifold = ProductManifold::new(2, 0, 0);

        let point = vec![0.3, -0.7];
        let gradient = vec![2.0, -5.0];

        let projected = manifold.project_gradient(&point, &gradient).unwrap();
        assert!((projected[0] - 2.0).abs() < 1e-12);
        assert!((projected[1] + 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_variance() {
        let manifold = ProductManifold::new(2, 0, 0);

        // Four points at unit distance from the origin.
        let points = vec![
            vec![1.0, 0.0],
            vec![-1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.0, -1.0],
        ];

        let variance = manifold.variance(&points, Some(&vec![0.0, 0.0])).unwrap();
        assert!((variance - 1.0).abs() < 1e-10);
    }

    /// An empty point set has zero variance.
    #[test]
    fn test_variance_empty() {
        let manifold = ProductManifold::new(2, 0, 0);
        let points: Vec<Vec<f64>> = Vec::new();

        assert_eq!(manifold.variance(&points, None).unwrap(), 0.0);
    }
}