ruvector_math/optimal_transport/
sinkhorn.rs

1//! Log-Stabilized Sinkhorn Algorithm
2//!
3//! The Sinkhorn algorithm computes the entropic-regularized optimal transport:
4//!
5//! min_{γ ∈ Π(a,b)} ⟨γ, C⟩ - ε H(γ)
6//!
7//! where H(γ) = -Σ γ_ij log(γ_ij) is the entropy and ε is the regularization.
8//!
9//! ## Log-Stabilization
10//!
11//! We work in log-domain to prevent numerical overflow/underflow:
12//! - Store log(u) and log(v) instead of u, v
13//! - Use log-sum-exp for stable normalization
14//!
15//! ## Complexity
16//!
17//! - O(n² × iterations) for dense cost matrix
18//! - Typically converges in 50-200 iterations
19//! - ~1000x faster than linear programming for exact OT
20
21use crate::error::{MathError, Result};
22use crate::utils::{log_sum_exp, EPS, LOG_MIN};
23
/// Outcome of a Sinkhorn solve: the transport plan plus convergence diagnostics.
///
/// `PartialEq` is derived so results can be compared directly (e.g. with
/// `assert_eq!`); the comparison is exact, field by field.
#[derive(Debug, Clone, PartialEq)]
pub struct TransportPlan {
    /// Transport plan matrix γ[i,j] (n × m): mass moved from source `i` to target `j`
    pub plan: Vec<Vec<f64>>,
    /// Total transport cost ⟨γ, C⟩
    pub cost: f64,
    /// Number of Sinkhorn iterations that were executed
    pub iterations: usize,
    /// Marginal error (||γ1 - a||₁ + ||γᵀ1 - b||₁) as last measured by the solver
    pub marginal_error: f64,
    /// Whether the algorithm converged
    pub converged: bool,
}
38
/// Log-stabilized Sinkhorn solver for entropic optimal transport.
///
/// The solver is a small bundle of configuration values; `PartialEq` is
/// derived so two configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct SinkhornSolver {
    /// Regularization parameter ε (the Gibbs kernel is exp(-C/ε))
    regularization: f64,
    /// Maximum number of Sinkhorn iterations per solve
    max_iterations: usize,
    /// Convergence threshold
    threshold: f64,
}
49
50impl SinkhornSolver {
51    /// Create a new Sinkhorn solver
52    ///
53    /// # Arguments
54    /// * `regularization` - Entropy regularization ε (0.01-0.1 typical)
55    /// * `max_iterations` - Maximum Sinkhorn iterations (100-1000 typical)
56    pub fn new(regularization: f64, max_iterations: usize) -> Self {
57        Self {
58            regularization: regularization.max(1e-6),
59            max_iterations: max_iterations.max(1),
60            threshold: 1e-6,
61        }
62    }
63
64    /// Set convergence threshold
65    pub fn with_threshold(mut self, threshold: f64) -> Self {
66        self.threshold = threshold.max(1e-12);
67        self
68    }
69
70    /// Compute the cost matrix for squared Euclidean distance
71    /// Uses SIMD-friendly 4-way unrolled accumulator for better performance
72    #[inline]
73    pub fn compute_cost_matrix(source: &[Vec<f64>], target: &[Vec<f64>]) -> Vec<Vec<f64>> {
74        source
75            .iter()
76            .map(|s| {
77                target
78                    .iter()
79                    .map(|t| Self::squared_euclidean(s, t))
80                    .collect()
81            })
82            .collect()
83    }
84
85    /// SIMD-friendly squared Euclidean distance
86    #[inline(always)]
87    fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
88        let len = a.len();
89        let chunks = len / 4;
90        let remainder = len % 4;
91
92        let mut sum0 = 0.0f64;
93        let mut sum1 = 0.0f64;
94        let mut sum2 = 0.0f64;
95        let mut sum3 = 0.0f64;
96
97        for i in 0..chunks {
98            let base = i * 4;
99            let d0 = a[base] - b[base];
100            let d1 = a[base + 1] - b[base + 1];
101            let d2 = a[base + 2] - b[base + 2];
102            let d3 = a[base + 3] - b[base + 3];
103            sum0 += d0 * d0;
104            sum1 += d1 * d1;
105            sum2 += d2 * d2;
106            sum3 += d3 * d3;
107        }
108
109        let base = chunks * 4;
110        for i in 0..remainder {
111            let d = a[base + i] - b[base + i];
112            sum0 += d * d;
113        }
114
115        sum0 + sum1 + sum2 + sum3
116    }
117
118    /// Solve optimal transport using log-stabilized Sinkhorn
119    ///
120    /// # Arguments
121    /// * `cost_matrix` - C[i,j] = cost to move from source[i] to target[j]
122    /// * `source_weights` - Marginal distribution a (sum to 1)
123    /// * `target_weights` - Marginal distribution b (sum to 1)
124    pub fn solve(
125        &self,
126        cost_matrix: &[Vec<f64>],
127        source_weights: &[f64],
128        target_weights: &[f64],
129    ) -> Result<TransportPlan> {
130        let n = source_weights.len();
131        let m = target_weights.len();
132
133        if n == 0 || m == 0 {
134            return Err(MathError::empty_input("weights"));
135        }
136
137        if cost_matrix.len() != n || cost_matrix.iter().any(|row| row.len() != m) {
138            return Err(MathError::dimension_mismatch(n, cost_matrix.len()));
139        }
140
141        // Normalize weights
142        let sum_a: f64 = source_weights.iter().sum();
143        let sum_b: f64 = target_weights.iter().sum();
144        let a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
145        let b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();
146
147        // Initialize log-domain Gibbs kernel: K = exp(-C/ε)
148        // Store log(K) = -C/ε
149        let log_k: Vec<Vec<f64>> = cost_matrix
150            .iter()
151            .map(|row| row.iter().map(|&c| -c / self.regularization).collect())
152            .collect();
153
154        // Initialize log scaling vectors
155        let mut log_u = vec![0.0; n];
156        let mut log_v = vec![0.0; m];
157
158        let log_a: Vec<f64> = a.iter().map(|&ai| ai.ln().max(LOG_MIN)).collect();
159        let log_b: Vec<f64> = b.iter().map(|&bi| bi.ln().max(LOG_MIN)).collect();
160
161        let mut converged = false;
162        let mut iterations = 0;
163        let mut marginal_error = f64::INFINITY;
164
165        // Pre-allocate buffers for log-sum-exp computation (reduces allocations per iteration)
166        let mut log_terms_row = vec![0.0; m];
167        let mut log_terms_col = vec![0.0; n];
168
169        // Sinkhorn iterations in log domain
170        for iter in 0..self.max_iterations {
171            iterations = iter + 1;
172
173            // Update log_u: log_u = log_a - log_sum_exp_j(log_v[j] + log_K[i,j])
174            let mut max_u_change: f64 = 0.0;
175            for i in 0..n {
176                let old_log_u = log_u[i];
177                // Compute into pre-allocated buffer
178                for j in 0..m {
179                    log_terms_row[j] = log_v[j] + log_k[i][j];
180                }
181                let lse = log_sum_exp(&log_terms_row);
182                log_u[i] = log_a[i] - lse;
183                max_u_change = max_u_change.max((log_u[i] - old_log_u).abs());
184            }
185
186            // Update log_v: log_v = log_b - log_sum_exp_i(log_u[i] + log_K[i,j])
187            let mut max_v_change: f64 = 0.0;
188            for j in 0..m {
189                let old_log_v = log_v[j];
190                // Compute into pre-allocated buffer
191                for i in 0..n {
192                    log_terms_col[i] = log_u[i] + log_k[i][j];
193                }
194                let lse = log_sum_exp(&log_terms_col);
195                log_v[j] = log_b[j] - lse;
196                max_v_change = max_v_change.max((log_v[j] - old_log_v).abs());
197            }
198
199            // Check convergence
200            let max_change = max_u_change.max(max_v_change);
201
202            // Compute marginal error every 10 iterations
203            if iter % 10 == 0 || max_change < self.threshold {
204                marginal_error = self.compute_marginal_error(&log_u, &log_v, &log_k, &a, &b);
205
206                if max_change < self.threshold && marginal_error < self.threshold * 10.0 {
207                    converged = true;
208                    break;
209                }
210            }
211        }
212
213        // Compute transport plan: γ[i,j] = exp(log_u[i] + log_K[i,j] + log_v[j])
214        let plan: Vec<Vec<f64>> = (0..n)
215            .map(|i| {
216                (0..m)
217                    .map(|j| {
218                        let log_gamma = log_u[i] + log_k[i][j] + log_v[j];
219                        log_gamma.exp().max(0.0)
220                    })
221                    .collect()
222            })
223            .collect();
224
225        // Compute transport cost: ⟨γ, C⟩
226        let cost = plan
227            .iter()
228            .zip(cost_matrix.iter())
229            .map(|(gamma_row, cost_row)| {
230                gamma_row
231                    .iter()
232                    .zip(cost_row.iter())
233                    .map(|(&g, &c)| g * c)
234                    .sum::<f64>()
235            })
236            .sum();
237
238        Ok(TransportPlan {
239            plan,
240            cost,
241            iterations,
242            marginal_error,
243            converged,
244        })
245    }
246
247    /// Compute marginal constraint error
248    fn compute_marginal_error(
249        &self,
250        log_u: &[f64],
251        log_v: &[f64],
252        log_k: &[Vec<f64>],
253        a: &[f64],
254        b: &[f64],
255    ) -> f64 {
256        let n = log_u.len();
257        let m = log_v.len();
258
259        // Compute row sums (γ1 should equal a)
260        let mut row_error = 0.0;
261        for i in 0..n {
262            let log_row_sum = log_sum_exp(
263                &(0..m)
264                    .map(|j| log_u[i] + log_k[i][j] + log_v[j])
265                    .collect::<Vec<_>>(),
266            );
267            row_error += (log_row_sum.exp() - a[i]).abs();
268        }
269
270        // Compute column sums (γᵀ1 should equal b)
271        let mut col_error = 0.0;
272        for j in 0..m {
273            let log_col_sum = log_sum_exp(
274                &(0..n)
275                    .map(|i| log_u[i] + log_k[i][j] + log_v[j])
276                    .collect::<Vec<_>>(),
277            );
278            col_error += (log_col_sum.exp() - b[j]).abs();
279        }
280
281        row_error + col_error
282    }
283
284    /// Compute Sinkhorn distance (optimal transport cost) between point clouds
285    pub fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
286        let cost_matrix = Self::compute_cost_matrix(source, target);
287
288        // Uniform weights
289        let n = source.len();
290        let m = target.len();
291        let source_weights = vec![1.0 / n as f64; n];
292        let target_weights = vec![1.0 / m as f64; m];
293
294        let result = self.solve(&cost_matrix, &source_weights, &target_weights)?;
295        Ok(result.cost)
296    }
297
298    /// Compute Wasserstein barycenter of multiple distributions
299    ///
300    /// Returns the barycenter (mean distribution) in transport space
301    pub fn barycenter(
302        &self,
303        distributions: &[&[Vec<f64>]],
304        weights: Option<&[f64]>,
305        support_size: usize,
306        dim: usize,
307    ) -> Result<Vec<Vec<f64>>> {
308        if distributions.is_empty() {
309            return Err(MathError::empty_input("distributions"));
310        }
311
312        let k = distributions.len();
313        let barycenter_weights = match weights {
314            Some(w) => {
315                let sum: f64 = w.iter().sum();
316                w.iter().map(|&wi| wi / sum).collect()
317            }
318            None => vec![1.0 / k as f64; k],
319        };
320
321        // Initialize barycenter as mean of first distribution
322        let mut barycenter: Vec<Vec<f64>> = (0..support_size)
323            .map(|i| {
324                let t = i as f64 / (support_size - 1).max(1) as f64;
325                vec![t; dim]
326            })
327            .collect();
328
329        // Fixed-point iteration to find barycenter
330        for _outer in 0..20 {
331            // For each input distribution, compute transport to barycenter
332            let mut displacements = vec![vec![0.0; dim]; support_size];
333
334            for (dist_idx, &distribution) in distributions.iter().enumerate() {
335                let cost_matrix = Self::compute_cost_matrix(distribution, &barycenter);
336
337                let n = distribution.len();
338                let source_w = vec![1.0 / n as f64; n];
339                let target_w = vec![1.0 / support_size as f64; support_size];
340
341                if let Ok(plan) = self.solve(&cost_matrix, &source_w, &target_w) {
342                    // Compute displacement from plan
343                    for j in 0..support_size {
344                        for i in 0..n {
345                            let weight = plan.plan[i][j] * support_size as f64;
346                            for d in 0..dim {
347                                displacements[j][d] +=
348                                    barycenter_weights[dist_idx] * weight * (distribution[i][d] - barycenter[j][d]);
349                            }
350                        }
351                    }
352                }
353            }
354
355            // Update barycenter
356            let mut max_update: f64 = 0.0;
357            for j in 0..support_size {
358                for d in 0..dim {
359                    let delta = displacements[j][d] * 0.5; // Step size
360                    barycenter[j][d] += delta;
361                    max_update = max_update.max(delta.abs());
362                }
363            }
364
365            if max_update < EPS {
366                break;
367            }
368        }
369
370        Ok(barycenter)
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Transporting a cloud onto an identical cloud should cost almost nothing.
    #[test]
    fn test_sinkhorn_identity() {
        let source = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let target = source.clone();

        let cost = SinkhornSolver::new(0.1, 100)
            .distance(&source, &target)
            .unwrap();

        assert!(cost < 0.1, "Identity should have near-zero cost: {}", cost);
    }

    /// Shifting the unit square by (1, 0) moves every point a distance of 1,
    /// so the squared-Euclidean transport cost should be close to 1.
    #[test]
    fn test_sinkhorn_translation() {
        let corners = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)];
        let source: Vec<Vec<f64>> = corners.iter().map(|&(x, y)| vec![x, y]).collect();
        let target: Vec<Vec<f64>> = corners.iter().map(|&(x, y)| vec![x + 1.0, y]).collect();

        let solver = SinkhornSolver::new(0.05, 200);
        let cost = solver.distance(&source, &target).unwrap();

        let in_expected_band = cost > 0.5 && cost < 2.0;
        assert!(in_expected_band, "Translation cost should be ~1.0: {}", cost);
    }

    /// A small symmetric problem with uniform marginals must converge.
    #[test]
    fn test_sinkhorn_convergence() {
        let cost_matrix = vec![
            vec![0.0, 1.0, 2.0],
            vec![1.0, 0.0, 1.0],
            vec![2.0, 1.0, 0.0],
        ];
        let a = vec![1.0 / 3.0; 3];
        let b = vec![1.0 / 3.0; 3];

        let solver = SinkhornSolver::new(0.1, 100).with_threshold(1e-6);
        let result = solver.solve(&cost_matrix, &a, &b).unwrap();

        assert!(result.converged, "Should converge");
        assert!(
            result.marginal_error < 0.01,
            "Marginal error too high: {}",
            result.marginal_error
        );
    }

    /// Row and column sums of the plan must reproduce the requested marginals.
    #[test]
    fn test_transport_plan_marginals() {
        let cost_matrix = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let a = vec![0.3, 0.7];
        let b = vec![0.6, 0.4];

        let result = SinkhornSolver::new(0.1, 100)
            .solve(&cost_matrix, &a, &b)
            .unwrap();

        // Row marginals
        for i in 0..a.len() {
            let ai = a[i];
            let row_sum: f64 = result.plan[i].iter().sum();
            assert!(
                (row_sum - ai).abs() < 0.05,
                "Row {} sum {} != {}",
                i,
                row_sum,
                ai
            );
        }

        // Column marginals
        for j in 0..b.len() {
            let bj = b[j];
            let col_sum: f64 = result.plan.iter().map(|row| row[j]).sum();
            assert!(
                (col_sum - bj).abs() < 0.05,
                "Col {} sum {} != {}",
                j,
                col_sum,
                bj
            );
        }
    }
}