ruvector_math/optimal_transport/
gromov_wasserstein.rs1use super::SinkhornSolver;
29use crate::error::{MathError, Result};
30use crate::utils::EPS;
31
/// Solver for the entropically regularized Gromov–Wasserstein problem between two
/// point clouds, which are compared only through their intra-set distance matrices.
#[derive(Clone, Debug)]
pub struct GromovWasserstein {
    /// Entropic regularization handed to the inner Sinkhorn solver.
    regularization: f64,
    /// Upper bound on the number of outer (linearize + line-search) iterations.
    max_iterations: usize,
    /// Relative loss-change threshold used as the convergence criterion.
    threshold: f64,
    /// Iteration budget given to each inner Sinkhorn solve.
    inner_iterations: usize,
}
44
45impl GromovWasserstein {
46 pub fn new(regularization: f64) -> Self {
51 Self {
52 regularization: regularization.max(1e-6),
53 max_iterations: 100,
54 threshold: 1e-5,
55 inner_iterations: 50,
56 }
57 }
58
59 pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
61 self.max_iterations = max_iter.max(1);
62 self
63 }
64
65 pub fn with_threshold(mut self, threshold: f64) -> Self {
67 self.threshold = threshold.max(1e-12);
68 self
69 }
70
71 fn distance_matrix(points: &[Vec<f64>]) -> Vec<Vec<f64>> {
73 let n = points.len();
74 let mut dist = vec![vec![0.0; n]; n];
75
76 for i in 0..n {
77 for j in (i + 1)..n {
78 let d: f64 = points[i]
79 .iter()
80 .zip(points[j].iter())
81 .map(|(&a, &b)| (a - b).powi(2))
82 .sum::<f64>()
83 .sqrt();
84 dist[i][j] = d;
85 dist[j][i] = d;
86 }
87 }
88
89 dist
90 }
91
92 fn compute_gw_loss(dist_x: &[Vec<f64>], dist_y: &[Vec<f64>], gamma: &[Vec<f64>]) -> f64 {
98 let n = dist_x.len();
99 let m = dist_y.len();
100
101 let p: Vec<f64> = gamma.iter().map(|row| row.iter().sum()).collect();
103 let term1: f64 = (0..n)
104 .map(|i| {
105 (0..n)
106 .map(|k| dist_x[i][k].powi(2) * p[i] * p[k])
107 .sum::<f64>()
108 })
109 .sum();
110
111 let q: Vec<f64> = (0..m)
113 .map(|j| gamma.iter().map(|row| row[j]).sum())
114 .collect();
115 let term2: f64 = (0..m)
116 .map(|j| {
117 (0..m)
118 .map(|l| dist_y[j][l].powi(2) * q[j] * q[l])
119 .sum::<f64>()
120 })
121 .sum();
122
123 let dx_gamma: Vec<Vec<f64>> = (0..n)
126 .map(|i| {
127 (0..m)
128 .map(|j| (0..n).map(|k| dist_x[i][k] * gamma[k][j]).sum())
129 .collect()
130 })
131 .collect();
132
133 let gamma_dy: Vec<Vec<f64>> = (0..n)
134 .map(|i| {
135 (0..m)
136 .map(|j| (0..m).map(|l| gamma[i][l] * dist_y[l][j]).sum())
137 .collect()
138 })
139 .collect();
140
141 let term3: f64 = 2.0
142 * (0..n)
143 .map(|i| (0..m).map(|j| dx_gamma[i][j] * gamma_dy[i][j]).sum::<f64>())
144 .sum::<f64>();
145
146 term1 + term2 - term3
147 }
148
149 fn compute_gradient(
152 dist_x: &[Vec<f64>],
153 dist_y: &[Vec<f64>],
154 gamma: &[Vec<f64>],
155 ) -> Vec<Vec<f64>> {
156 let n = dist_x.len();
157 let m = dist_y.len();
158
159 let p: Vec<f64> = gamma.iter().map(|row| row.iter().sum()).collect();
161 let q: Vec<f64> = (0..m)
162 .map(|j| gamma.iter().map(|row| row[j]).sum())
163 .collect();
164
165 let dx2_p: Vec<f64> = (0..n)
167 .map(|i| (0..n).map(|k| dist_x[i][k].powi(2) * p[k]).sum())
168 .collect();
169
170 let dy2_q: Vec<f64> = (0..m)
172 .map(|j| (0..m).map(|l| dist_y[j][l].powi(2) * q[l]).sum())
173 .collect();
174
175 let dx_gamma_dy: Vec<Vec<f64>> = (0..n)
177 .map(|i| {
178 (0..m)
179 .map(|j| {
180 (0..n)
181 .map(|k| {
182 (0..m)
183 .map(|l| dist_x[i][k] * gamma[k][l] * dist_y[l][j])
184 .sum::<f64>()
185 })
186 .sum()
187 })
188 .collect()
189 })
190 .collect();
191
192 (0..n)
194 .map(|i| {
195 (0..m)
196 .map(|j| 2.0 * (dx2_p[i] + dy2_q[j] - 2.0 * dx_gamma_dy[i][j]))
197 .collect()
198 })
199 .collect()
200 }
201
202 pub fn solve(
204 &self,
205 source: &[Vec<f64>],
206 target: &[Vec<f64>],
207 ) -> Result<GromovWassersteinResult> {
208 if source.is_empty() || target.is_empty() {
209 return Err(MathError::empty_input("points"));
210 }
211
212 let n = source.len();
213 let m = target.len();
214
215 let dist_x = Self::distance_matrix(source);
217 let dist_y = Self::distance_matrix(target);
218
219 let mut gamma: Vec<Vec<f64>> = (0..n).map(|_| vec![1.0 / (n * m) as f64; m]).collect();
221
222 let sinkhorn = SinkhornSolver::new(self.regularization, self.inner_iterations);
223 let source_weights = vec![1.0 / n as f64; n];
224 let target_weights = vec![1.0 / m as f64; m];
225
226 let mut loss = Self::compute_gw_loss(&dist_x, &dist_y, &gamma);
227 let mut converged = false;
228
229 for _iter in 0..self.max_iterations {
230 let gradient = Self::compute_gradient(&dist_x, &dist_y, &gamma);
232
233 let linear_result = sinkhorn.solve(&gradient, &source_weights, &target_weights)?;
235 let direction = linear_result.plan;
236
237 let mut best_alpha = 0.0;
239 let mut best_loss = loss;
240
241 for k in 1..=10 {
242 let alpha = k as f64 / 10.0;
243
244 let gamma_new: Vec<Vec<f64>> = (0..n)
246 .map(|i| {
247 (0..m)
248 .map(|j| (1.0 - alpha) * gamma[i][j] + alpha * direction[i][j])
249 .collect()
250 })
251 .collect();
252
253 let new_loss = Self::compute_gw_loss(&dist_x, &dist_y, &gamma_new);
254
255 if new_loss < best_loss {
256 best_alpha = alpha;
257 best_loss = new_loss;
258 }
259 }
260
261 if best_alpha > 0.0 {
263 for i in 0..n {
264 for j in 0..m {
265 gamma[i][j] =
266 (1.0 - best_alpha) * gamma[i][j] + best_alpha * direction[i][j];
267 }
268 }
269 }
270
271 let loss_change = (loss - best_loss).abs() / (loss.abs() + EPS);
273 loss = best_loss;
274
275 if loss_change < self.threshold {
276 converged = true;
277 break;
278 }
279 }
280
281 Ok(GromovWassersteinResult {
282 transport_plan: gamma,
283 loss,
284 converged,
285 })
286 }
287
288 pub fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
290 let result = self.solve(source, target)?;
291 Ok(result.loss.sqrt())
292 }
293}
294
/// Output of [`GromovWasserstein::solve`].
#[derive(Clone, Debug)]
pub struct GromovWassersteinResult {
    /// Coupling matrix: one row per source point, one column per target point.
    pub transport_plan: Vec<Vec<f64>>,
    /// Gromov–Wasserstein loss of `transport_plan`.
    /// [`GromovWasserstein::distance`] reports its square root.
    pub loss: f64,
    /// Whether the relative loss change fell below the solver's threshold before
    /// the iteration limit was reached.
    pub converged: bool,
}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Right triangle with unit legs along the two axes.
    fn unit_right_triangle() -> Vec<Vec<f64>> {
        vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]]
    }

    #[test]
    fn test_gw_identical() {
        let solver = GromovWasserstein::new(0.1);
        let shape = unit_right_triangle();

        let gw = solver.distance(&shape, &shape).unwrap();
        assert!(gw < 1.0, "Identical structures should have low GW: {}", gw);
    }

    #[test]
    fn test_gw_scaled() {
        let solver = GromovWasserstein::new(0.1);
        let source = unit_right_triangle();
        // Same shape, every coordinate doubled.
        let target: Vec<Vec<f64>> = source
            .iter()
            .map(|pt| pt.iter().map(|c| c * 2.0).collect())
            .collect();

        let gw = solver.distance(&source, &target).unwrap();
        assert!(gw > 0.0, "Scaled structure should have some GW distance");
    }

    #[test]
    fn test_gw_different_structures() {
        let solver = GromovWasserstein::new(0.1);
        let equilateral = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.866]];
        let collinear = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]];

        let gw = solver.distance(&equilateral, &collinear).unwrap();
        assert!(gw > 0.1, "Different structures should have high GW: {}", gw);
    }

    #[test]
    fn test_distance_matrix() {
        // 3-4-5 right triangle hypotenuse.
        let dist = GromovWasserstein::distance_matrix(&[vec![0.0, 0.0], vec![3.0, 4.0]]);

        assert!((dist[0][1] - 5.0).abs() < 1e-10);
        assert!((dist[1][0] - 5.0).abs() < 1e-10);
        assert!(dist[0][0].abs() < 1e-10);
    }
}