// ruvector_math/optimal_transport/sinkhorn.rs
use crate::error::{MathError, Result};
use crate::utils::{log_sum_exp, EPS, LOG_MIN};

/// Result of an entropic optimal-transport solve.
#[derive(Debug, Clone)]
pub struct TransportPlan {
    // Transport plan `gamma`: `plan[i][j]` is the mass moved from source
    // point `i` to target point `j`.
    pub plan: Vec<Vec<f64>>,
    // Total transport cost `<gamma, C>` under the supplied cost matrix.
    pub cost: f64,
    // Number of Sinkhorn iterations actually performed.
    pub iterations: usize,
    // Sum of absolute row + column marginal violations when iteration stopped.
    pub marginal_error: f64,
    // True when the update change and marginal error both fell below threshold.
    pub converged: bool,
}
38
/// Log-domain Sinkhorn solver for entropically regularized optimal transport.
#[derive(Debug, Clone)]
pub struct SinkhornSolver {
    // Entropic regularization strength (epsilon); clamped to >= 1e-6 by `new`.
    regularization: f64,
    // Upper bound on Sinkhorn iterations; clamped to >= 1 by `new`.
    max_iterations: usize,
    // Convergence threshold on scaling-vector updates; defaults to 1e-6.
    threshold: f64,
}
49
50impl SinkhornSolver {
51 pub fn new(regularization: f64, max_iterations: usize) -> Self {
57 Self {
58 regularization: regularization.max(1e-6),
59 max_iterations: max_iterations.max(1),
60 threshold: 1e-6,
61 }
62 }
63
64 pub fn with_threshold(mut self, threshold: f64) -> Self {
66 self.threshold = threshold.max(1e-12);
67 self
68 }
69
70 #[inline]
73 pub fn compute_cost_matrix(source: &[Vec<f64>], target: &[Vec<f64>]) -> Vec<Vec<f64>> {
74 source
75 .iter()
76 .map(|s| {
77 target
78 .iter()
79 .map(|t| Self::squared_euclidean(s, t))
80 .collect()
81 })
82 .collect()
83 }
84
85 #[inline(always)]
87 fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
88 let len = a.len();
89 let chunks = len / 4;
90 let remainder = len % 4;
91
92 let mut sum0 = 0.0f64;
93 let mut sum1 = 0.0f64;
94 let mut sum2 = 0.0f64;
95 let mut sum3 = 0.0f64;
96
97 for i in 0..chunks {
98 let base = i * 4;
99 let d0 = a[base] - b[base];
100 let d1 = a[base + 1] - b[base + 1];
101 let d2 = a[base + 2] - b[base + 2];
102 let d3 = a[base + 3] - b[base + 3];
103 sum0 += d0 * d0;
104 sum1 += d1 * d1;
105 sum2 += d2 * d2;
106 sum3 += d3 * d3;
107 }
108
109 let base = chunks * 4;
110 for i in 0..remainder {
111 let d = a[base + i] - b[base + i];
112 sum0 += d * d;
113 }
114
115 sum0 + sum1 + sum2 + sum3
116 }
117
118 pub fn solve(
125 &self,
126 cost_matrix: &[Vec<f64>],
127 source_weights: &[f64],
128 target_weights: &[f64],
129 ) -> Result<TransportPlan> {
130 let n = source_weights.len();
131 let m = target_weights.len();
132
133 if n == 0 || m == 0 {
134 return Err(MathError::empty_input("weights"));
135 }
136
137 if cost_matrix.len() != n || cost_matrix.iter().any(|row| row.len() != m) {
138 return Err(MathError::dimension_mismatch(n, cost_matrix.len()));
139 }
140
141 let sum_a: f64 = source_weights.iter().sum();
143 let sum_b: f64 = target_weights.iter().sum();
144 let a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
145 let b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();
146
147 let log_k: Vec<Vec<f64>> = cost_matrix
150 .iter()
151 .map(|row| row.iter().map(|&c| -c / self.regularization).collect())
152 .collect();
153
154 let mut log_u = vec![0.0; n];
156 let mut log_v = vec![0.0; m];
157
158 let log_a: Vec<f64> = a.iter().map(|&ai| ai.ln().max(LOG_MIN)).collect();
159 let log_b: Vec<f64> = b.iter().map(|&bi| bi.ln().max(LOG_MIN)).collect();
160
161 let mut converged = false;
162 let mut iterations = 0;
163 let mut marginal_error = f64::INFINITY;
164
165 let mut log_terms_row = vec![0.0; m];
167 let mut log_terms_col = vec![0.0; n];
168
169 for iter in 0..self.max_iterations {
171 iterations = iter + 1;
172
173 let mut max_u_change: f64 = 0.0;
175 for i in 0..n {
176 let old_log_u = log_u[i];
177 for j in 0..m {
179 log_terms_row[j] = log_v[j] + log_k[i][j];
180 }
181 let lse = log_sum_exp(&log_terms_row);
182 log_u[i] = log_a[i] - lse;
183 max_u_change = max_u_change.max((log_u[i] - old_log_u).abs());
184 }
185
186 let mut max_v_change: f64 = 0.0;
188 for j in 0..m {
189 let old_log_v = log_v[j];
190 for i in 0..n {
192 log_terms_col[i] = log_u[i] + log_k[i][j];
193 }
194 let lse = log_sum_exp(&log_terms_col);
195 log_v[j] = log_b[j] - lse;
196 max_v_change = max_v_change.max((log_v[j] - old_log_v).abs());
197 }
198
199 let max_change = max_u_change.max(max_v_change);
201
202 if iter % 10 == 0 || max_change < self.threshold {
204 marginal_error = self.compute_marginal_error(&log_u, &log_v, &log_k, &a, &b);
205
206 if max_change < self.threshold && marginal_error < self.threshold * 10.0 {
207 converged = true;
208 break;
209 }
210 }
211 }
212
213 let plan: Vec<Vec<f64>> = (0..n)
215 .map(|i| {
216 (0..m)
217 .map(|j| {
218 let log_gamma = log_u[i] + log_k[i][j] + log_v[j];
219 log_gamma.exp().max(0.0)
220 })
221 .collect()
222 })
223 .collect();
224
225 let cost = plan
227 .iter()
228 .zip(cost_matrix.iter())
229 .map(|(gamma_row, cost_row)| {
230 gamma_row
231 .iter()
232 .zip(cost_row.iter())
233 .map(|(&g, &c)| g * c)
234 .sum::<f64>()
235 })
236 .sum();
237
238 Ok(TransportPlan {
239 plan,
240 cost,
241 iterations,
242 marginal_error,
243 converged,
244 })
245 }
246
247 fn compute_marginal_error(
249 &self,
250 log_u: &[f64],
251 log_v: &[f64],
252 log_k: &[Vec<f64>],
253 a: &[f64],
254 b: &[f64],
255 ) -> f64 {
256 let n = log_u.len();
257 let m = log_v.len();
258
259 let mut row_error = 0.0;
261 for i in 0..n {
262 let log_row_sum = log_sum_exp(
263 &(0..m)
264 .map(|j| log_u[i] + log_k[i][j] + log_v[j])
265 .collect::<Vec<_>>(),
266 );
267 row_error += (log_row_sum.exp() - a[i]).abs();
268 }
269
270 let mut col_error = 0.0;
272 for j in 0..m {
273 let log_col_sum = log_sum_exp(
274 &(0..n)
275 .map(|i| log_u[i] + log_k[i][j] + log_v[j])
276 .collect::<Vec<_>>(),
277 );
278 col_error += (log_col_sum.exp() - b[j]).abs();
279 }
280
281 row_error + col_error
282 }
283
284 pub fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
286 let cost_matrix = Self::compute_cost_matrix(source, target);
287
288 let n = source.len();
290 let m = target.len();
291 let source_weights = vec![1.0 / n as f64; n];
292 let target_weights = vec![1.0 / m as f64; m];
293
294 let result = self.solve(&cost_matrix, &source_weights, &target_weights)?;
295 Ok(result.cost)
296 }
297
298 pub fn barycenter(
302 &self,
303 distributions: &[&[Vec<f64>]],
304 weights: Option<&[f64]>,
305 support_size: usize,
306 dim: usize,
307 ) -> Result<Vec<Vec<f64>>> {
308 if distributions.is_empty() {
309 return Err(MathError::empty_input("distributions"));
310 }
311
312 let k = distributions.len();
313 let barycenter_weights = match weights {
314 Some(w) => {
315 let sum: f64 = w.iter().sum();
316 w.iter().map(|&wi| wi / sum).collect()
317 }
318 None => vec![1.0 / k as f64; k],
319 };
320
321 let mut barycenter: Vec<Vec<f64>> = (0..support_size)
323 .map(|i| {
324 let t = i as f64 / (support_size - 1).max(1) as f64;
325 vec![t; dim]
326 })
327 .collect();
328
329 for _outer in 0..20 {
331 let mut displacements = vec![vec![0.0; dim]; support_size];
333
334 for (dist_idx, &distribution) in distributions.iter().enumerate() {
335 let cost_matrix = Self::compute_cost_matrix(distribution, &barycenter);
336
337 let n = distribution.len();
338 let source_w = vec![1.0 / n as f64; n];
339 let target_w = vec![1.0 / support_size as f64; support_size];
340
341 if let Ok(plan) = self.solve(&cost_matrix, &source_w, &target_w) {
342 for j in 0..support_size {
344 for i in 0..n {
345 let weight = plan.plan[i][j] * support_size as f64;
346 for d in 0..dim {
347 displacements[j][d] +=
348 barycenter_weights[dist_idx] * weight * (distribution[i][d] - barycenter[j][d]);
349 }
350 }
351 }
352 }
353 }
354
355 let mut max_update: f64 = 0.0;
357 for j in 0..support_size {
358 for d in 0..dim {
359 let delta = displacements[j][d] * 0.5; barycenter[j][d] += delta;
361 max_update = max_update.max(delta.abs());
362 }
363 }
364
365 if max_update < EPS {
366 break;
367 }
368 }
369
370 Ok(barycenter)
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Transporting a point cloud onto itself should cost (almost) nothing.
    #[test]
    fn test_sinkhorn_identity() {
        let solver = SinkhornSolver::new(0.1, 100);

        let points = vec![vec![0.0, 0.0], vec![1.0, 1.0]];

        let cost = solver.distance(&points, &points).unwrap();
        assert!(cost < 0.1, "Identity should have near-zero cost: {}", cost);
    }

    /// Shifting a unit square by one unit along x costs ~1 (squared distance).
    #[test]
    fn test_sinkhorn_translation() {
        let solver = SinkhornSolver::new(0.05, 200);

        let source: Vec<Vec<f64>> = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
            .iter()
            .map(|p| p.to_vec())
            .collect();
        let target: Vec<Vec<f64>> = source.iter().map(|p| vec![p[0] + 1.0, p[1]]).collect();

        let cost = solver.distance(&source, &target).unwrap();

        assert!(
            cost > 0.5 && cost < 2.0,
            "Translation cost should be ~1.0: {}",
            cost
        );
    }

    /// Uniform marginals on a |i - j| cost must converge within the budget.
    #[test]
    fn test_sinkhorn_convergence() {
        let solver = SinkhornSolver::new(0.1, 100).with_threshold(1e-6);

        let cost_matrix: Vec<Vec<f64>> = (0..3)
            .map(|i| (0..3).map(|j| (i as f64 - j as f64).abs()).collect())
            .collect();

        let uniform = vec![1.0 / 3.0; 3];

        let result = solver.solve(&cost_matrix, &uniform, &uniform).unwrap();

        assert!(result.converged, "Should converge");
        assert!(
            result.marginal_error < 0.01,
            "Marginal error too high: {}",
            result.marginal_error
        );
    }

    /// Row and column sums of the returned plan must match the marginals.
    #[test]
    fn test_transport_plan_marginals() {
        let solver = SinkhornSolver::new(0.1, 100);

        let cost_matrix = vec![vec![0.0, 1.0], vec![1.0, 0.0]];

        let a = vec![0.3, 0.7];
        let b = vec![0.6, 0.4];

        let plan = solver.solve(&cost_matrix, &a, &b).unwrap().plan;

        for (i, &ai) in a.iter().enumerate() {
            let row_sum: f64 = plan[i].iter().sum();
            assert!(
                (row_sum - ai).abs() < 0.05,
                "Row {} sum {} != {}",
                i,
                row_sum,
                ai
            );
        }

        for (j, &bj) in b.iter().enumerate() {
            let col_sum: f64 = plan.iter().map(|row| row[j]).sum();
            assert!(
                (col_sum - bj).abs() < 0.05,
                "Col {} sum {} != {}",
                j,
                col_sum,
                bj
            );
        }
    }
}