Skip to main content

ruvector_math/optimal_transport/
gromov_wasserstein.rs

1//! Gromov-Wasserstein Distance
2//!
3//! Gromov-Wasserstein (GW) distance compares the *structure* of two metric spaces,
4//! not requiring them to share a common embedding space.
5//!
6//! ## Definition
7//!
8//! GW(X, Y) = min_{γ ∈ Π(μ,ν)} Σᵢⱼₖₗ |d_X(xᵢ, xₖ) - d_Y(yⱼ, yₗ)|² γᵢⱼ γₖₗ
9//!
10//! This measures how well the pairwise distances in X match those in Y.
11//!
12//! ## Use Cases
13//!
14//! - Cross-lingual word embeddings (different embedding spaces)
15//! - Graph matching (comparing graph structures)
16//! - Shape matching (comparing point cloud structures)
17//! - Multi-modal alignment (different feature spaces)
18//!
19//! ## Algorithm
20//!
21//! Uses Frank-Wolfe (conditional gradient) with entropic regularization:
22//! 1. Initialize transport plan with the independent coupling (uniform product of the marginals)
23//! 2. Compute gradient of GW objective
24//! 3. Solve linearized problem via Sinkhorn
25//! 4. Line search and update
26//! 5. Repeat until convergence
27
28use super::SinkhornSolver;
29use crate::error::{MathError, Result};
30use crate::utils::EPS;
31
/// Solver configuration for the entropic Gromov-Wasserstein distance.
///
/// Holds the tuning knobs used by the outer Frank-Wolfe loop and by the
/// inner Sinkhorn solver that handles each linearized sub-problem.
#[derive(Clone, Debug)]
pub struct GromovWasserstein {
    /// Entropic regularization strength handed to the inner Sinkhorn solver.
    regularization: f64,
    /// Upper bound on the number of outer (Frank-Wolfe) iterations.
    max_iterations: usize,
    /// Relative loss change below which the outer loop stops.
    threshold: f64,
    /// Iteration budget handed to the inner Sinkhorn solver.
    inner_iterations: usize,
}
44
45impl GromovWasserstein {
46    /// Create a new Gromov-Wasserstein calculator
47    ///
48    /// # Arguments
49    /// * `regularization` - Entropy regularization (0.01-0.1 typical)
50    pub fn new(regularization: f64) -> Self {
51        Self {
52            regularization: regularization.max(1e-6),
53            max_iterations: 100,
54            threshold: 1e-5,
55            inner_iterations: 50,
56        }
57    }
58
59    /// Set maximum iterations
60    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
61        self.max_iterations = max_iter.max(1);
62        self
63    }
64
65    /// Set convergence threshold
66    pub fn with_threshold(mut self, threshold: f64) -> Self {
67        self.threshold = threshold.max(1e-12);
68        self
69    }
70
71    /// Compute pairwise distance matrix
72    fn distance_matrix(points: &[Vec<f64>]) -> Vec<Vec<f64>> {
73        let n = points.len();
74        let mut dist = vec![vec![0.0; n]; n];
75
76        for i in 0..n {
77            for j in (i + 1)..n {
78                let d: f64 = points[i]
79                    .iter()
80                    .zip(points[j].iter())
81                    .map(|(&a, &b)| (a - b).powi(2))
82                    .sum::<f64>()
83                    .sqrt();
84                dist[i][j] = d;
85                dist[j][i] = d;
86            }
87        }
88
89        dist
90    }
91
92    /// Compute squared distance loss tensor contraction
93    /// L(γ) = Σᵢⱼₖₗ (D_X[i,k] - D_Y[j,l])² γᵢⱼ γₖₗ
94    ///      = ⟨h₁(D_X) ⊗ h₂(D_Y), γ ⊗ γ⟩ - 2⟨D_X γ D_Y^T, γ⟩
95    ///
96    /// where h₁(a) = a², h₂(b) = b², for squared loss
97    fn compute_gw_loss(dist_x: &[Vec<f64>], dist_y: &[Vec<f64>], gamma: &[Vec<f64>]) -> f64 {
98        let n = dist_x.len();
99        let m = dist_y.len();
100
101        // Term 1: Σᵢₖ D_X[i,k]² (Σⱼ γᵢⱼ)(Σₗ γₖₗ) = Σᵢₖ D_X[i,k]² pᵢ pₖ
102        let p: Vec<f64> = gamma.iter().map(|row| row.iter().sum()).collect();
103        let term1: f64 = (0..n)
104            .map(|i| {
105                (0..n)
106                    .map(|k| dist_x[i][k].powi(2) * p[i] * p[k])
107                    .sum::<f64>()
108            })
109            .sum();
110
111        // Term 2: Σⱼₗ D_Y[j,l]² (Σᵢ γᵢⱼ)(Σₖ γₖₗ) = Σⱼₗ D_Y[j,l]² qⱼ qₗ
112        let q: Vec<f64> = (0..m)
113            .map(|j| gamma.iter().map(|row| row[j]).sum())
114            .collect();
115        let term2: f64 = (0..m)
116            .map(|j| {
117                (0..m)
118                    .map(|l| dist_y[j][l].powi(2) * q[j] * q[l])
119                    .sum::<f64>()
120            })
121            .sum();
122
123        // Term 3: 2 * Σᵢⱼₖₗ D_X[i,k] D_Y[j,l] γᵢⱼ γₖₗ = 2 * trace(D_X γ D_Y^T γ^T)
124        // = 2 * Σᵢⱼ (D_X γ)ᵢⱼ (γ D_Y^T)ᵢⱼ
125        let dx_gamma: Vec<Vec<f64>> = (0..n)
126            .map(|i| {
127                (0..m)
128                    .map(|j| (0..n).map(|k| dist_x[i][k] * gamma[k][j]).sum())
129                    .collect()
130            })
131            .collect();
132
133        let gamma_dy: Vec<Vec<f64>> = (0..n)
134            .map(|i| {
135                (0..m)
136                    .map(|j| (0..m).map(|l| gamma[i][l] * dist_y[l][j]).sum())
137                    .collect()
138            })
139            .collect();
140
141        let term3: f64 = 2.0
142            * (0..n)
143                .map(|i| (0..m).map(|j| dx_gamma[i][j] * gamma_dy[i][j]).sum::<f64>())
144                .sum::<f64>();
145
146        term1 + term2 - term3
147    }
148
149    /// Compute gradient of GW loss w.r.t. gamma
150    /// ∇_γ L = 2 * (h₁(D_X) p 1^T + 1 q^T h₂(D_Y) - 2 D_X γ D_Y^T)
151    fn compute_gradient(
152        dist_x: &[Vec<f64>],
153        dist_y: &[Vec<f64>],
154        gamma: &[Vec<f64>],
155    ) -> Vec<Vec<f64>> {
156        let n = dist_x.len();
157        let m = dist_y.len();
158
159        // Marginals
160        let p: Vec<f64> = gamma.iter().map(|row| row.iter().sum()).collect();
161        let q: Vec<f64> = (0..m)
162            .map(|j| gamma.iter().map(|row| row[j]).sum())
163            .collect();
164
165        // D_X² p 1^T term
166        let dx2_p: Vec<f64> = (0..n)
167            .map(|i| (0..n).map(|k| dist_x[i][k].powi(2) * p[k]).sum())
168            .collect();
169
170        // 1 q^T D_Y² term
171        let dy2_q: Vec<f64> = (0..m)
172            .map(|j| (0..m).map(|l| dist_y[j][l].powi(2) * q[l]).sum())
173            .collect();
174
175        // D_X γ D_Y^T
176        let dx_gamma_dy: Vec<Vec<f64>> = (0..n)
177            .map(|i| {
178                (0..m)
179                    .map(|j| {
180                        (0..n)
181                            .map(|k| {
182                                (0..m)
183                                    .map(|l| dist_x[i][k] * gamma[k][l] * dist_y[l][j])
184                                    .sum::<f64>()
185                            })
186                            .sum()
187                    })
188                    .collect()
189            })
190            .collect();
191
192        // Gradient = 2 * (dx2_p 1^T + 1 dy2_q^T - 2 * D_X γ D_Y^T)
193        (0..n)
194            .map(|i| {
195                (0..m)
196                    .map(|j| 2.0 * (dx2_p[i] + dy2_q[j] - 2.0 * dx_gamma_dy[i][j]))
197                    .collect()
198            })
199            .collect()
200    }
201
202    /// Solve Gromov-Wasserstein using Frank-Wolfe
203    pub fn solve(
204        &self,
205        source: &[Vec<f64>],
206        target: &[Vec<f64>],
207    ) -> Result<GromovWassersteinResult> {
208        if source.is_empty() || target.is_empty() {
209            return Err(MathError::empty_input("points"));
210        }
211
212        let n = source.len();
213        let m = target.len();
214
215        // Compute distance matrices
216        let dist_x = Self::distance_matrix(source);
217        let dist_y = Self::distance_matrix(target);
218
219        // Initialize with independent coupling
220        let mut gamma: Vec<Vec<f64>> = (0..n).map(|_| vec![1.0 / (n * m) as f64; m]).collect();
221
222        let sinkhorn = SinkhornSolver::new(self.regularization, self.inner_iterations);
223        let source_weights = vec![1.0 / n as f64; n];
224        let target_weights = vec![1.0 / m as f64; m];
225
226        let mut loss = Self::compute_gw_loss(&dist_x, &dist_y, &gamma);
227        let mut converged = false;
228
229        for _iter in 0..self.max_iterations {
230            // Compute gradient (cost matrix for linearized problem)
231            let gradient = Self::compute_gradient(&dist_x, &dist_y, &gamma);
232
233            // Solve linearized problem with Sinkhorn
234            let linear_result = sinkhorn.solve(&gradient, &source_weights, &target_weights)?;
235            let direction = linear_result.plan;
236
237            // Line search
238            let mut best_alpha = 0.0;
239            let mut best_loss = loss;
240
241            for k in 1..=10 {
242                let alpha = k as f64 / 10.0;
243
244                // gamma_new = (1 - alpha) * gamma + alpha * direction
245                let gamma_new: Vec<Vec<f64>> = (0..n)
246                    .map(|i| {
247                        (0..m)
248                            .map(|j| (1.0 - alpha) * gamma[i][j] + alpha * direction[i][j])
249                            .collect()
250                    })
251                    .collect();
252
253                let new_loss = Self::compute_gw_loss(&dist_x, &dist_y, &gamma_new);
254
255                if new_loss < best_loss {
256                    best_alpha = alpha;
257                    best_loss = new_loss;
258                }
259            }
260
261            // Update gamma
262            if best_alpha > 0.0 {
263                for i in 0..n {
264                    for j in 0..m {
265                        gamma[i][j] =
266                            (1.0 - best_alpha) * gamma[i][j] + best_alpha * direction[i][j];
267                    }
268                }
269            }
270
271            // Check convergence
272            let loss_change = (loss - best_loss).abs() / (loss.abs() + EPS);
273            loss = best_loss;
274
275            if loss_change < self.threshold {
276                converged = true;
277                break;
278            }
279        }
280
281        Ok(GromovWassersteinResult {
282            transport_plan: gamma,
283            loss,
284            converged,
285        })
286    }
287
288    /// Compute GW distance between two point clouds
289    pub fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
290        let result = self.solve(source, target)?;
291        Ok(result.loss.sqrt())
292    }
293}
294
/// Outcome of a Gromov-Wasserstein solve.
#[derive(Clone, Debug)]
pub struct GromovWassersteinResult {
    /// Coupling between source points (rows) and target points (columns)
    /// at the end of the last iteration.
    pub transport_plan: Vec<Vec<f64>>,
    /// Value of the GW objective at `transport_plan`.
    pub loss: f64,
    /// `true` if the relative loss change dropped below the threshold
    /// before the iteration limit was reached.
    pub converged: bool,
}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Right triangle with unit legs, used as the reference shape.
    fn unit_right_triangle() -> Vec<Vec<f64>> {
        vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]]
    }

    /// Comparing a shape with itself must give a small distance.
    #[test]
    fn test_gw_identical() {
        let points = unit_right_triangle();
        let solver = GromovWasserstein::new(0.1);

        // Entropic regularization keeps the result from being exactly zero.
        let dist = solver.distance(&points, &points).unwrap();
        assert!(
            dist < 1.0,
            "Identical structures should have low GW: {}",
            dist
        );
    }

    /// Uniform scaling changes all pairwise distances, so GW is non-zero.
    #[test]
    fn test_gw_scaled() {
        let source = unit_right_triangle();
        let target: Vec<Vec<f64>> = source
            .iter()
            .map(|p| p.iter().map(|c| c * 2.0).collect())
            .collect();

        let dist = GromovWasserstein::new(0.1)
            .distance(&source, &target)
            .unwrap();

        // GW is not scale-invariant even though the shape is the same.
        assert!(dist > 0.0, "Scaled structure should have some GW distance");
    }

    /// An (approximately) equilateral triangle versus three collinear points.
    #[test]
    fn test_gw_different_structures() {
        let triangle = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.866]];
        let line = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]];

        let dist = GromovWasserstein::new(0.1)
            .distance(&triangle, &line)
            .unwrap();

        assert!(
            dist > 0.1,
            "Different structures should have high GW: {}",
            dist
        );
    }

    /// 3-4-5 triangle check for the pairwise distance helper.
    #[test]
    fn test_distance_matrix() {
        let dist = GromovWasserstein::distance_matrix(&[vec![0.0, 0.0], vec![3.0, 4.0]]);

        assert!((dist[0][1] - 5.0).abs() < 1e-10);
        assert!((dist[1][0] - 5.0).abs() < 1e-10);
        assert!(dist[0][0].abs() < 1e-10);
    }
}