1use scirs2_core::ndarray::{Array1, Array2};
42
43use crate::error::{Result, TransformError};
44
/// Choice of marginal-penalty / solver variant used by [`unbalanced_sinkhorn`].
///
/// The enum is `#[non_exhaustive]`, so downstream `match` expressions need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum UnbalancedRegularization {
    /// Selects the KL-type scaling iteration. Depending on
    /// [`UnbalancedOtConfig::log_domain`] this runs either the log-domain or
    /// the standard-domain (multiplicative) solver.
    KLDivergence,
    /// Selects the `sinkhorn_l2` solver, whose scaling updates are damped by
    /// the additive constant `epsilon / tau`.
    L2,
}
61
/// Configuration for [`unbalanced_sinkhorn`].
#[derive(Debug, Clone)]
pub struct UnbalancedOtConfig {
    /// Entropic regularisation strength; the Gibbs kernel is `exp(-cost / epsilon)`.
    /// Must be positive (checked by `unbalanced_sinkhorn`).
    pub epsilon: f64,
    /// Marginal relaxation weight. The KL solvers use the exponent
    /// `tau / (tau + epsilon)`, the L2 solver uses the damping `epsilon / tau`.
    /// Must be positive (checked by `unbalanced_sinkhorn`).
    pub tau: f64,
    /// Which solver variant to run.
    pub regularization: UnbalancedRegularization,
    /// Upper bound on the number of scaling iterations.
    pub max_iter: usize,
    /// Convergence threshold: iteration stops once the summed mean absolute
    /// change of the two iterates drops below this value.
    pub tol: f64,
    /// When `true` and `regularization` is `KLDivergence`, the log-domain
    /// solver is used. The flag is not consulted for `L2`.
    pub log_domain: bool,
}
86
87impl Default for UnbalancedOtConfig {
88 fn default() -> Self {
89 Self {
90 epsilon: 0.1,
91 tau: 1.0,
92 regularization: UnbalancedRegularization::KLDivergence,
93 max_iter: 1000,
94 tol: 1e-6,
95 log_domain: true,
96 }
97 }
98}
99
/// Output of [`unbalanced_sinkhorn`].
#[derive(Debug, Clone)]
pub struct UnbalancedOtResult {
    /// Transport plan with one row per source bin and one column per target bin.
    pub transport_plan: Array2<f64>,
    /// Transport cost: the sum over all entries of `cost[i, j] * transport_plan[i, j]`.
    pub cost: f64,
    /// L1 distance between the row sums of the plan and the source histogram.
    pub marginal_violation_source: f64,
    /// L1 distance between the column sums of the plan and the target histogram.
    pub marginal_violation_target: f64,
    /// Number of iterations that were actually executed.
    pub n_iter: usize,
    /// `true` if the tolerance was reached before `max_iter` iterations.
    pub converged: bool,
}
120
121pub fn unbalanced_sinkhorn(
153 a: &[f64],
154 b: &[f64],
155 cost: &Array2<f64>,
156 config: &UnbalancedOtConfig,
157) -> Result<UnbalancedOtResult> {
158 let n = a.len();
162 let m = b.len();
163
164 if n == 0 {
165 return Err(TransformError::InvalidInput(
166 "Source histogram 'a' must be non-empty".to_string(),
167 ));
168 }
169 if m == 0 {
170 return Err(TransformError::InvalidInput(
171 "Target histogram 'b' must be non-empty".to_string(),
172 ));
173 }
174 if cost.dim() != (n, m) {
175 return Err(TransformError::InvalidInput(format!(
176 "Cost matrix shape ({},{}) does not match histogram lengths ({n},{m})",
177 cost.nrows(),
178 cost.ncols()
179 )));
180 }
181 if config.epsilon <= 0.0 {
182 return Err(TransformError::InvalidInput(
183 "epsilon must be positive".to_string(),
184 ));
185 }
186 if config.tau <= 0.0 {
187 return Err(TransformError::InvalidInput(
188 "tau must be positive".to_string(),
189 ));
190 }
191 for &ai in a {
192 if ai < 0.0 {
193 return Err(TransformError::InvalidInput(
194 "Source histogram contains negative entries".to_string(),
195 ));
196 }
197 }
198 for &bi in b {
199 if bi < 0.0 {
200 return Err(TransformError::InvalidInput(
201 "Target histogram contains negative entries".to_string(),
202 ));
203 }
204 }
205 let sum_a: f64 = a.iter().sum();
206 let sum_b: f64 = b.iter().sum();
207 if sum_a < f64::EPSILON {
208 return Err(TransformError::InvalidInput(
209 "Source histogram has zero total mass".to_string(),
210 ));
211 }
212 if sum_b < f64::EPSILON {
213 return Err(TransformError::InvalidInput(
214 "Target histogram has zero total mass".to_string(),
215 ));
216 }
217
218 for ci in cost.iter() {
222 if *ci < 0.0 {
223 return Err(TransformError::InvalidInput(
224 "Cost matrix contains negative entries".to_string(),
225 ));
226 }
227 }
228
229 match config.regularization {
230 UnbalancedRegularization::KLDivergence => {
231 if config.log_domain {
232 sinkhorn_kl_log_domain(a, b, cost, config)
233 } else {
234 sinkhorn_kl(a, b, cost, config)
235 }
236 }
237 UnbalancedRegularization::L2 => sinkhorn_l2(a, b, cost, config),
238 }
239}
240
241fn sinkhorn_kl(
249 a: &[f64],
250 b: &[f64],
251 cost: &Array2<f64>,
252 config: &UnbalancedOtConfig,
253) -> Result<UnbalancedOtResult> {
254 let n = a.len();
255 let m = b.len();
256 let rho = config.tau / (config.tau + config.epsilon);
257
258 let k: Array2<f64> = cost.mapv(|c| (-c / config.epsilon).exp());
260
261 let mut u = Array1::from_elem(n, 1.0_f64);
263 let mut v = Array1::from_elem(m, 1.0_f64);
264
265 let a_arr = Array1::from_vec(a.to_vec());
266 let b_arr = Array1::from_vec(b.to_vec());
267
268 let mut converged = false;
269 let mut n_iter = 0usize;
270
271 for _iter in 0..config.max_iter {
272 n_iter += 1;
273
274 let kv: Array1<f64> = k.dot(&v);
276 let u_new: Array1<f64> = a_arr
278 .iter()
279 .zip(kv.iter())
280 .map(|(&ai, &kvi)| {
281 if kvi < f64::EPSILON {
282 0.0
283 } else {
284 (ai / kvi).powf(rho)
285 }
286 })
287 .collect::<Vec<f64>>()
288 .into();
289
290 let ktu: Array1<f64> = k.t().dot(&u_new);
292 let v_new: Array1<f64> = b_arr
294 .iter()
295 .zip(ktu.iter())
296 .map(|(&bi, &ktui)| {
297 if ktui < f64::EPSILON {
298 0.0
299 } else {
300 (bi / ktui).powf(rho)
301 }
302 })
303 .collect::<Vec<f64>>()
304 .into();
305
306 let du: f64 = u_new
308 .iter()
309 .zip(u.iter())
310 .map(|(&a, &b)| (a - b).abs())
311 .sum::<f64>()
312 / (n as f64);
313 let dv: f64 = v_new
314 .iter()
315 .zip(v.iter())
316 .map(|(&a, &b)| (a - b).abs())
317 .sum::<f64>()
318 / (m as f64);
319
320 u = u_new;
321 v = v_new;
322
323 if du + dv < config.tol {
324 converged = true;
325 break;
326 }
327 }
328
329 let transport_plan = build_transport_plan(&u, &k, &v);
331 let result = compute_result(transport_plan, cost, a, b, n_iter, converged);
332 Ok(result)
333}
334
335fn sinkhorn_kl_log_domain(
343 a: &[f64],
344 b: &[f64],
345 cost: &Array2<f64>,
346 config: &UnbalancedOtConfig,
347) -> Result<UnbalancedOtResult> {
348 let n = a.len();
349 let m = b.len();
350 let rho = config.tau / (config.tau + config.epsilon);
351 let eps = config.epsilon;
352
353 let mut f: Array1<f64> = Array1::zeros(n);
355 let mut g: Array1<f64> = Array1::zeros(m);
356
357 let log_a: Vec<f64> = a
358 .iter()
359 .map(|&ai| if ai > 0.0 { ai.ln() } else { f64::NEG_INFINITY })
360 .collect();
361 let log_b: Vec<f64> = b
362 .iter()
363 .map(|&bi| if bi > 0.0 { bi.ln() } else { f64::NEG_INFINITY })
364 .collect();
365
366 let mut converged = false;
367 let mut n_iter = 0usize;
368
369 for _iter in 0..config.max_iter {
370 n_iter += 1;
371
372 let f_prev = f.clone();
377 let g_prev = g.clone();
378
379 for i in 0..n {
384 let lse_j = log_sum_exp_row(i, &g, cost, eps, m);
385 let new_fi = rho * (eps * log_a[i] - lse_j);
386 f[i] = new_fi;
387 }
388
389 for j in 0..m {
391 let lse_i = log_sum_exp_col(j, &f, cost, eps, n);
392 let new_gj = rho * (eps * log_b[j] - lse_i);
393 g[j] = new_gj;
394 }
395
396 let df: f64 = f
398 .iter()
399 .zip(f_prev.iter())
400 .map(|(&a, &b)| (a - b).abs())
401 .sum::<f64>()
402 / n as f64;
403 let dg: f64 = g
404 .iter()
405 .zip(g_prev.iter())
406 .map(|(&a, &b)| (a - b).abs())
407 .sum::<f64>()
408 / m as f64;
409
410 if df + dg < config.tol {
411 converged = true;
412 break;
413 }
414 }
415
416 let mut transport_plan = Array2::<f64>::zeros((n, m));
418 for i in 0..n {
419 for j in 0..m {
420 transport_plan[[i, j]] = ((f[i] + g[j] - cost[[i, j]]) / eps).exp();
421 }
422 }
423
424 let result = compute_result(transport_plan, cost, a, b, n_iter, converged);
425 Ok(result)
426}
427
428#[inline]
430fn log_sum_exp_row(i: usize, g: &Array1<f64>, cost: &Array2<f64>, eps: f64, m: usize) -> f64 {
431 let vals: Vec<f64> = (0..m).map(|j| g[j] - cost[[i, j]] / eps).collect();
432 log_sum_exp_vec(&vals)
433}
434
435#[inline]
437fn log_sum_exp_col(j: usize, f: &Array1<f64>, cost: &Array2<f64>, eps: f64, n: usize) -> f64 {
438 let vals: Vec<f64> = (0..n).map(|i| f[i] - cost[[i, j]] / eps).collect();
439 log_sum_exp_vec(&vals)
440}
441
/// Numerically stable `ln(sum(exp(v)))` over the finite entries of `vals`.
///
/// Non-finite entries (`-inf`, `+inf`, NaN) are skipped. Returns
/// `f64::NEG_INFINITY` when `vals` is empty or contains no finite entry.
/// The largest finite value is factored out before exponentiating so large
/// inputs do not overflow.
fn log_sum_exp_vec(vals: &[f64]) -> f64 {
    // First pass: largest finite entry (stays -inf if there is none).
    let mut peak = f64::NEG_INFINITY;
    for &x in vals {
        if x.is_finite() {
            peak = f64::max(peak, x);
        }
    }
    if !peak.is_finite() {
        return f64::NEG_INFINITY;
    }

    // Second pass: accumulate the shifted exponentials in input order.
    let mut shifted_sum = 0.0_f64;
    for &x in vals {
        if x.is_finite() {
            shifted_sum += (x - peak).exp();
        }
    }
    peak + shifted_sum.ln()
}
462
463fn sinkhorn_l2(
475 a: &[f64],
476 b: &[f64],
477 cost: &Array2<f64>,
478 config: &UnbalancedOtConfig,
479) -> Result<UnbalancedOtResult> {
480 let n = a.len();
481 let m = b.len();
482
483 let k: Array2<f64> = cost.mapv(|c| (-c / config.epsilon).exp());
484 let mut u = Array1::from_elem(n, 1.0_f64);
485 let mut v = Array1::from_elem(m, 1.0_f64);
486
487 let a_arr = Array1::from_vec(a.to_vec());
488 let b_arr = Array1::from_vec(b.to_vec());
489
490 let lambda = config.epsilon / config.tau;
493
494 let mut converged = false;
495 let mut n_iter = 0usize;
496
497 for _iter in 0..config.max_iter {
498 n_iter += 1;
499 let kv: Array1<f64> = k.dot(&v);
500 let u_new: Array1<f64> = a_arr
501 .iter()
502 .zip(kv.iter())
503 .map(|(&ai, &kvi)| ai / (kvi + lambda).max(f64::EPSILON))
504 .collect::<Vec<f64>>()
505 .into();
506
507 let ktu: Array1<f64> = k.t().dot(&u_new);
508 let v_new: Array1<f64> = b_arr
509 .iter()
510 .zip(ktu.iter())
511 .map(|(&bi, &ktui)| bi / (ktui + lambda).max(f64::EPSILON))
512 .collect::<Vec<f64>>()
513 .into();
514
515 let du: f64 = u_new
516 .iter()
517 .zip(u.iter())
518 .map(|(&a, &b)| (a - b).abs())
519 .sum::<f64>()
520 / n as f64;
521 let dv: f64 = v_new
522 .iter()
523 .zip(v.iter())
524 .map(|(&a, &b)| (a - b).abs())
525 .sum::<f64>()
526 / m as f64;
527
528 u = u_new;
529 v = v_new;
530
531 if du + dv < config.tol {
532 converged = true;
533 break;
534 }
535 }
536
537 let transport_plan = build_transport_plan(&u, &k, &v);
538 let result = compute_result(transport_plan, cost, a, b, n_iter, converged);
539 Ok(result)
540}
541
542fn build_transport_plan(u: &Array1<f64>, k: &Array2<f64>, v: &Array1<f64>) -> Array2<f64> {
548 let n = u.len();
549 let m = v.len();
550 let mut t = Array2::zeros((n, m));
551 for i in 0..n {
552 for j in 0..m {
553 t[[i, j]] = u[i] * k[[i, j]] * v[j];
554 }
555 }
556 t
557}
558
559fn compute_result(
561 transport_plan: Array2<f64>,
562 cost: &Array2<f64>,
563 a: &[f64],
564 b: &[f64],
565 n_iter: usize,
566 converged: bool,
567) -> UnbalancedOtResult {
568 let n = a.len();
569 let m = b.len();
570
571 let ot_cost: f64 = cost
573 .iter()
574 .zip(transport_plan.iter())
575 .map(|(&c, &t)| c * t)
576 .sum();
577
578 let source_marg: Vec<f64> = (0..n).map(|i| transport_plan.row(i).sum()).collect();
580
581 let target_marg: Vec<f64> = (0..m).map(|j| transport_plan.column(j).sum()).collect();
583
584 let mv_src: f64 = source_marg
586 .iter()
587 .zip(a.iter())
588 .map(|(&sm, &ai)| (sm - ai).abs())
589 .sum();
590 let mv_tgt: f64 = target_marg
591 .iter()
592 .zip(b.iter())
593 .map(|(&tm, &bi)| (tm - bi).abs())
594 .sum();
595
596 UnbalancedOtResult {
597 transport_plan,
598 cost: ot_cost,
599 marginal_violation_source: mv_src,
600 marginal_violation_target: mv_tgt,
601 n_iter,
602 converged,
603 }
604}
605
/// Unit tests for the unbalanced Sinkhorn solvers and their input validation.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Equal uniform masses with a large `tau`: the source marginal should be
    /// matched closely.
    #[test]
    fn test_unbalanced_ot_equal_mass_kl() {
        let n = 4usize;
        let a: Vec<f64> = vec![0.25; n];
        let b: Vec<f64> = vec![0.25; n];
        let mut cost_arr = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                cost_arr[[i, j]] = (i as f64 - j as f64).abs() / n as f64;
            }
        }

        let config = UnbalancedOtConfig {
            epsilon: 0.01,
            tau: 100.0,
            log_domain: true,
            max_iter: 2000,
            tol: 1e-8,
            ..Default::default()
        };

        let result = unbalanced_sinkhorn(&a, &b, &cost_arr, &config).expect("UOT ok");
        assert!(result.cost >= 0.0, "cost must be non-negative");
        assert!(
            result.marginal_violation_source < 0.1,
            "source marginal violation should be small, got {}",
            result.marginal_violation_source
        );
    }

    /// The L2 variant must produce a non-negative plan and cost.
    #[test]
    fn test_unbalanced_ot_equal_mass_l2() {
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
        let config = UnbalancedOtConfig {
            regularization: UnbalancedRegularization::L2,
            epsilon: 0.1,
            tau: 10.0,
            max_iter: 500,
            tol: 1e-6,
            log_domain: false,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT L2 ok");
        assert!(result.cost >= 0.0);
        for &t in result.transport_plan.iter() {
            assert!(t >= -1e-10, "transport plan entries must be non-negative");
        }
    }

    /// Histograms with different total mass cannot both be matched: by the
    /// triangle inequality the summed L1 marginal violation of any plan is at
    /// least the difference of the total masses.
    #[test]
    fn test_unbalanced_ot_unequal_mass() {
        let a = vec![0.5, 0.5];
        let b = vec![0.25, 0.25];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];

        let config = UnbalancedOtConfig {
            epsilon: 0.05,
            tau: 0.5,
            max_iter: 1000,
            tol: 1e-6,
            log_domain: true,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT unequal ok");
        assert!(result.cost >= 0.0);
        let total_mv = result.marginal_violation_source + result.marginal_violation_target;
        let mass_gap = (a.iter().sum::<f64>() - b.iter().sum::<f64>()).abs();
        assert!(
            total_mv.is_finite(),
            "marginal violations must be finite, got {total_mv}"
        );
        assert!(
            total_mv >= mass_gap - 1e-9,
            "total marginal violation {total_mv} cannot be below the mass gap {mass_gap}"
        );
    }

    /// Zero cost on the diagonal and a high cost elsewhere: the plan should
    /// concentrate on the diagonal, giving a small total cost.
    #[test]
    fn test_unbalanced_ot_diagonal_cost() {
        let n = 3usize;
        let a = vec![1.0 / n as f64; n];
        let b = vec![1.0 / n as f64; n];
        let mut cost_arr = Array2::<f64>::ones((n, n)) * 10.0;
        for i in 0..n {
            cost_arr[[i, i]] = 0.0;
        }

        let config = UnbalancedOtConfig {
            epsilon: 0.01,
            tau: 100.0,
            max_iter: 2000,
            tol: 1e-9,
            log_domain: true,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost_arr, &config).expect("UOT diagonal ok");
        assert!(
            result.cost < 0.5,
            "diagonal-concentrated plan should have small cost, got {}",
            result.cost
        );
    }

    /// Exercises the standard-domain (non-log) KL solver.
    #[test]
    fn test_unbalanced_ot_kl_standard_domain() {
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
        let config = UnbalancedOtConfig {
            epsilon: 0.1,
            tau: 1.0,
            log_domain: false,
            max_iter: 500,
            tol: 1e-6,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT KL std ok");
        assert!(result.cost >= 0.0);
    }

    /// An empty source histogram is rejected.
    #[test]
    fn test_empty_source_error() {
        let a: Vec<f64> = vec![];
        let b = vec![0.5, 0.5];
        let cost = Array2::<f64>::zeros((0, 2));
        let config = UnbalancedOtConfig::default();
        assert!(unbalanced_sinkhorn(&a, &b, &cost, &config).is_err());
    }

    /// A cost matrix whose shape disagrees with the histograms is rejected.
    #[test]
    fn test_shape_mismatch_error() {
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = Array2::<f64>::zeros((3, 2));
        let config = UnbalancedOtConfig::default();
        assert!(unbalanced_sinkhorn(&a, &b, &cost, &config).is_err());
    }

    /// A negative `epsilon` is rejected.
    #[test]
    fn test_negative_epsilon_error() {
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
        let config = UnbalancedOtConfig {
            epsilon: -0.1,
            ..Default::default()
        };
        assert!(unbalanced_sinkhorn(&a, &b, &cost, &config).is_err());
    }

    /// A histogram with zero total mass is rejected.
    #[test]
    fn test_zero_mass_error() {
        let a = vec![0.0, 0.0];
        let b = vec![0.5, 0.5];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
        let config = UnbalancedOtConfig::default();
        assert!(unbalanced_sinkhorn(&a, &b, &cost, &config).is_err());
    }

    /// With default settings every plan entry must be non-negative.
    #[test]
    fn test_transport_plan_non_negative() {
        let a = vec![0.3, 0.7];
        let b = vec![0.6, 0.4];
        let cost = array![[0.1_f64, 0.9], [0.8, 0.2]];
        let config = UnbalancedOtConfig::default();
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT ok");
        for &t in result.transport_plan.iter() {
            assert!(t >= -1e-12, "transport plan entry {t} is negative");
        }
    }

    /// Single bin on each side with zero cost: the plan should be close to 1
    /// and the cost close to 0.
    #[test]
    fn test_1x1_trivial() {
        let a = vec![1.0];
        let b = vec![1.0];
        let cost = array![[0.0_f64]];
        let config = UnbalancedOtConfig {
            epsilon: 0.01,
            tau: 100.0,
            max_iter: 2000,
            tol: 1e-8,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("1x1 ok");
        assert!(
            (result.transport_plan[[0, 0]] - 1.0).abs() < 0.2,
            "1x1 transport plan should be close to 1, got {}",
            result.transport_plan[[0, 0]]
        );
        assert!(
            result.cost < 0.5,
            "1x1 cost with zero cost matrix should be small, got {}",
            result.cost
        );
    }

    /// The stable log-sum-exp agrees with the direct formula on small inputs.
    #[test]
    fn test_log_sum_exp_vec() {
        let vals = vec![1.0_f64, 2.0, 3.0];
        let lse = log_sum_exp_vec(&vals);
        let expected = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((lse - expected).abs() < 1e-10, "lse mismatch");
    }
}