Skip to main content

scirs2_transform/ot/
unbalanced.rs

1//! Unbalanced Optimal Transport
2//!
3//! This module implements unbalanced optimal transport (UOT), which relaxes the
4//! hard marginal constraints of balanced OT so that distributions with different
5//! total mass can be compared.
6//!
7//! ## Theory
8//!
9//! Classical OT requires the source and target distributions to have equal mass.
10//! UOT replaces the hard marginal constraints with soft penalty terms:
11//!
12//! ```text
13//! UOT_{ε,τ}(a, b) = min_{T≥0}  ⟨C, T⟩
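//!                              + ε KL(T | 1_n 1_mᵀ)
//!                              + τ KL(T1 | a)
//!                              + τ KL(1ᵀT | b)
//! ```
//!
//! where KL(p|q) = Σ p_i log(p_i/q_i) − p_i + q_i is the generalised KL divergence.
//! The entropic term is measured against the all-ones matrix 1_n 1_mᵀ, which is what
//! the Gibbs kernel K_ij = exp(−C_ij / ε) used below corresponds to.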
20//!
21//! ### Unbalanced Sinkhorn Algorithm
22//!
23//! The solution is obtained via the scaling algorithm of Chizat et al. (2018):
24//!
25//! Initialise u = 1_n, v = 1_m, K_ij = exp(−C_ij / ε).
26//! Iterate until convergence:
27//! ```text
28//! u ← (a / (K v))^{τ/(τ+ε)}
29//! v ← (b / (Kᵀ u))^{τ/(τ+ε)}
30//! ```
31//! Optimal transport plan: T_ij = u_i K_ij v_j.
32//!
33//! ## References
34//!
35//! - Chizat, Peyré, Schmitzer, Vialard (2018):
36//!   "Scaling algorithms for unbalanced optimal transport problems."
37//!   Mathematics of Computation, 87(314), 2563-2609.
//! - Séjourné, Feydy, Vialard, Trouvé, Peyré (2019):
//!   "Sinkhorn Divergences for Unbalanced Optimal Transport." arXiv:1910.12958.
40
41use scirs2_core::ndarray::{Array1, Array2};
42
43use crate::error::{Result, TransformError};
44
45// ---------------------------------------------------------------------------
46// Regularization type
47// ---------------------------------------------------------------------------
48
/// Marginal relaxation type for unbalanced OT.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum UnbalancedRegularization {
    /// KL-divergence marginal penalty: τ KL(T1 | a) + τ KL(1ᵀT | b).
    ///
    /// This is the standard choice for UOT. Its proximal step has a closed form,
    /// which gives the scaling updates `u ← (a / Kv)^{τ/(τ+ε)}` and
    /// `v ← (b / Kᵀu)^{τ/(τ+ε)}`.
    KLDivergence,
    /// Quadratic (L2) marginal relaxation, intended to model
    /// (τ/2) ‖T1 − a‖² + (τ/2) ‖1ᵀT − b‖².
    ///
    /// The solver does not evaluate the exact L2 proximal operator. It uses the
    /// damped scaling heuristic `u ← a / (Kv + ε/τ)`, `v ← b / (Kᵀu + ε/τ)`.
    L2,
}
61
62// ---------------------------------------------------------------------------
63// Configuration
64// ---------------------------------------------------------------------------
65
66/// Configuration for unbalanced Sinkhorn OT.
67#[derive(Debug, Clone)]
68pub struct UnbalancedOtConfig {
69    /// Entropic regularization strength ε > 0. Default: 0.1.
70    pub epsilon: f64,
71    /// Marginal relaxation strength τ > 0. Default: 1.0.
72    ///
73    /// As τ → ∞ the problem approaches balanced OT.
74    /// Small τ allows large deviations from the input marginals.
75    pub tau: f64,
76    /// Marginal penalty type.
77    pub regularization: UnbalancedRegularization,
78    /// Maximum number of Sinkhorn iterations. Default: 1000.
79    pub max_iter: usize,
80    /// Convergence tolerance (on the marginal error). Default: 1e-6.
81    pub tol: f64,
82    /// Whether to apply log-domain stabilization (recommended for small ε).
83    /// Default: `true`.
84    pub log_domain: bool,
85}
86
87impl Default for UnbalancedOtConfig {
88    fn default() -> Self {
89        Self {
90            epsilon: 0.1,
91            tau: 1.0,
92            regularization: UnbalancedRegularization::KLDivergence,
93            max_iter: 1000,
94            tol: 1e-6,
95            log_domain: true,
96        }
97    }
98}
99
100// ---------------------------------------------------------------------------
101// Result
102// ---------------------------------------------------------------------------
103
104/// Result of an unbalanced OT computation.
105#[derive(Debug, Clone)]
106pub struct UnbalancedOtResult {
107    /// Optimal transport plan T (n × m), with potentially unequal row/column sums.
108    pub transport_plan: Array2<f64>,
109    /// Total transport cost ⟨C, T⟩.
110    pub cost: f64,
111    /// Marginal violation on the source side: ‖T 1_m − a‖₁.
112    pub marginal_violation_source: f64,
113    /// Marginal violation on the target side: ‖1_n^ᵀ T − b‖₁.
114    pub marginal_violation_target: f64,
115    /// Number of iterations performed.
116    pub n_iter: usize,
117    /// Whether convergence was achieved.
118    pub converged: bool,
119}
120
121// ---------------------------------------------------------------------------
122// Main entry point
123// ---------------------------------------------------------------------------
124
125/// Solve an unbalanced optimal transport problem via the Sinkhorn scaling algorithm.
126///
127/// # Arguments
128/// * `a`    – Source histogram (n,), must be non-negative (will be normalised internally).
129/// * `b`    – Target histogram (m,), must be non-negative.
130/// * `cost` – Ground cost matrix C (n × m), must be non-negative.
131/// * `config` – Algorithm parameters.
132///
133/// # Returns
134/// [`UnbalancedOtResult`] containing the transport plan and diagnostics.
135///
136/// # Errors
137/// Returns an error if inputs have incompatible shapes, contain negative entries,
138/// or if all weights are zero.
139///
140/// # Example
141/// ```rust
142/// use scirs2_transform::ot::unbalanced::{unbalanced_sinkhorn, UnbalancedOtConfig};
143/// use scirs2_core::ndarray::array;
144///
145/// let a = vec![0.5, 0.5];
146/// let b = vec![0.5, 0.5];
147/// let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
148/// let config = UnbalancedOtConfig::default();
149/// let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT should succeed");
150/// assert!(result.cost >= 0.0);
151/// ```
152pub fn unbalanced_sinkhorn(
153    a: &[f64],
154    b: &[f64],
155    cost: &Array2<f64>,
156    config: &UnbalancedOtConfig,
157) -> Result<UnbalancedOtResult> {
158    // ----------------------------------------------------------------
159    // Validate inputs
160    // ----------------------------------------------------------------
161    let n = a.len();
162    let m = b.len();
163
164    if n == 0 {
165        return Err(TransformError::InvalidInput(
166            "Source histogram 'a' must be non-empty".to_string(),
167        ));
168    }
169    if m == 0 {
170        return Err(TransformError::InvalidInput(
171            "Target histogram 'b' must be non-empty".to_string(),
172        ));
173    }
174    if cost.dim() != (n, m) {
175        return Err(TransformError::InvalidInput(format!(
176            "Cost matrix shape ({},{}) does not match histogram lengths ({n},{m})",
177            cost.nrows(),
178            cost.ncols()
179        )));
180    }
181    if config.epsilon <= 0.0 {
182        return Err(TransformError::InvalidInput(
183            "epsilon must be positive".to_string(),
184        ));
185    }
186    if config.tau <= 0.0 {
187        return Err(TransformError::InvalidInput(
188            "tau must be positive".to_string(),
189        ));
190    }
191    for &ai in a {
192        if ai < 0.0 {
193            return Err(TransformError::InvalidInput(
194                "Source histogram contains negative entries".to_string(),
195            ));
196        }
197    }
198    for &bi in b {
199        if bi < 0.0 {
200            return Err(TransformError::InvalidInput(
201                "Target histogram contains negative entries".to_string(),
202            ));
203        }
204    }
205    let sum_a: f64 = a.iter().sum();
206    let sum_b: f64 = b.iter().sum();
207    if sum_a < f64::EPSILON {
208        return Err(TransformError::InvalidInput(
209            "Source histogram has zero total mass".to_string(),
210        ));
211    }
212    if sum_b < f64::EPSILON {
213        return Err(TransformError::InvalidInput(
214            "Target histogram has zero total mass".to_string(),
215        ));
216    }
217
218    // ----------------------------------------------------------------
219    // Check for negative cost entries
220    // ----------------------------------------------------------------
221    for ci in cost.iter() {
222        if *ci < 0.0 {
223            return Err(TransformError::InvalidInput(
224                "Cost matrix contains negative entries".to_string(),
225            ));
226        }
227    }
228
229    match config.regularization {
230        UnbalancedRegularization::KLDivergence => {
231            if config.log_domain {
232                sinkhorn_kl_log_domain(a, b, cost, config)
233            } else {
234                sinkhorn_kl(a, b, cost, config)
235            }
236        }
237        UnbalancedRegularization::L2 => sinkhorn_l2(a, b, cost, config),
238    }
239}
240
241// ---------------------------------------------------------------------------
242// KL-divergence scaling algorithm (standard domain)
243// ---------------------------------------------------------------------------
244
245/// Unbalanced Sinkhorn scaling with KL marginal penalties.
246///
247/// Scaling exponent: ρ = τ / (τ + ε)
248fn sinkhorn_kl(
249    a: &[f64],
250    b: &[f64],
251    cost: &Array2<f64>,
252    config: &UnbalancedOtConfig,
253) -> Result<UnbalancedOtResult> {
254    let n = a.len();
255    let m = b.len();
256    let rho = config.tau / (config.tau + config.epsilon);
257
258    // Gibbs kernel K_ij = exp(-C_ij / epsilon)
259    let k: Array2<f64> = cost.mapv(|c| (-c / config.epsilon).exp());
260
261    // Scaling vectors (dual variables in the exponential domain)
262    let mut u = Array1::from_elem(n, 1.0_f64);
263    let mut v = Array1::from_elem(m, 1.0_f64);
264
265    let a_arr = Array1::from_vec(a.to_vec());
266    let b_arr = Array1::from_vec(b.to_vec());
267
268    let mut converged = false;
269    let mut n_iter = 0usize;
270
271    for _iter in 0..config.max_iter {
272        n_iter += 1;
273
274        // Kv[i] = Σ_j K[i,j] * v[j]
275        let kv: Array1<f64> = k.dot(&v);
276        // u ← (a / Kv)^ρ
277        let u_new: Array1<f64> = a_arr
278            .iter()
279            .zip(kv.iter())
280            .map(|(&ai, &kvi)| {
281                if kvi < f64::EPSILON {
282                    0.0
283                } else {
284                    (ai / kvi).powf(rho)
285                }
286            })
287            .collect::<Vec<f64>>()
288            .into();
289
290        // Ktu[j] = Σ_i K[i,j] * u[i]
291        let ktu: Array1<f64> = k.t().dot(&u_new);
292        // v ← (b / Kᵀu)^ρ
293        let v_new: Array1<f64> = b_arr
294            .iter()
295            .zip(ktu.iter())
296            .map(|(&bi, &ktui)| {
297                if ktui < f64::EPSILON {
298                    0.0
299                } else {
300                    (bi / ktui).powf(rho)
301                }
302            })
303            .collect::<Vec<f64>>()
304            .into();
305
306        // Convergence check: change in scaling vectors
307        let du: f64 = u_new
308            .iter()
309            .zip(u.iter())
310            .map(|(&a, &b)| (a - b).abs())
311            .sum::<f64>()
312            / (n as f64);
313        let dv: f64 = v_new
314            .iter()
315            .zip(v.iter())
316            .map(|(&a, &b)| (a - b).abs())
317            .sum::<f64>()
318            / (m as f64);
319
320        u = u_new;
321        v = v_new;
322
323        if du + dv < config.tol {
324            converged = true;
325            break;
326        }
327    }
328
329    // Build transport plan: T_ij = u_i K_ij v_j
330    let transport_plan = build_transport_plan(&u, &k, &v);
331    let result = compute_result(transport_plan, cost, a, b, n_iter, converged);
332    Ok(result)
333}
334
335// ---------------------------------------------------------------------------
336// KL-divergence scaling algorithm (log domain — numerically stable)
337// ---------------------------------------------------------------------------
338
339/// Log-domain stabilized unbalanced Sinkhorn (Chizat 2018, Algorithm 2).
340///
341/// Works in log-space to avoid numerical overflow/underflow for small ε.
342fn sinkhorn_kl_log_domain(
343    a: &[f64],
344    b: &[f64],
345    cost: &Array2<f64>,
346    config: &UnbalancedOtConfig,
347) -> Result<UnbalancedOtResult> {
348    let n = a.len();
349    let m = b.len();
350    let rho = config.tau / (config.tau + config.epsilon);
351    let eps = config.epsilon;
352
353    // Log potentials: f (n,), g (m,)
354    let mut f: Array1<f64> = Array1::zeros(n);
355    let mut g: Array1<f64> = Array1::zeros(m);
356
357    let log_a: Vec<f64> = a
358        .iter()
359        .map(|&ai| if ai > 0.0 { ai.ln() } else { f64::NEG_INFINITY })
360        .collect();
361    let log_b: Vec<f64> = b
362        .iter()
363        .map(|&bi| if bi > 0.0 { bi.ln() } else { f64::NEG_INFINITY })
364        .collect();
365
366    let mut converged = false;
367    let mut n_iter = 0usize;
368
369    for _iter in 0..config.max_iter {
370        n_iter += 1;
371
372        // Softmin_ε over j: h_i = −ε lse_j (g_j − C_ij / ε)
373        // Then f ← ρ (log_a − h_i / ε) * ε  [from KL prox update]
374        // Equivalently: f_i ← ρ (ε log_a_i − softmin_ε_j(g_j − C_ij/ε))
375
376        let f_prev = f.clone();
377        let g_prev = g.clone();
378
379        // Update f: f_i = ρ * (ε ln a_i − softmin_j(g_j − C_{ij}/ε) * ε ... )
380        // The proximal update is: f ← ρ/(ρ+1) * (ε ln a − ε lse_j((g - C/ε) / 1))
381        // But with the standard KL UOT: f_i ← rho * ε * (ln a_i − lse_j( (g_j - C_ij)/ε ) )
382        // where lse is log-sum-exp.
383        for i in 0..n {
384            let lse_j = log_sum_exp_row(i, &g, cost, eps, m);
385            let new_fi = rho * (eps * log_a[i] - lse_j);
386            f[i] = new_fi;
387        }
388
389        // Update g: g_j = ρ * (ε ln b_j − lse_i( (f_i - C_ij)/ε ))
390        for j in 0..m {
391            let lse_i = log_sum_exp_col(j, &f, cost, eps, n);
392            let new_gj = rho * (eps * log_b[j] - lse_i);
393            g[j] = new_gj;
394        }
395
396        // Convergence check
397        let df: f64 = f
398            .iter()
399            .zip(f_prev.iter())
400            .map(|(&a, &b)| (a - b).abs())
401            .sum::<f64>()
402            / n as f64;
403        let dg: f64 = g
404            .iter()
405            .zip(g_prev.iter())
406            .map(|(&a, &b)| (a - b).abs())
407            .sum::<f64>()
408            / m as f64;
409
410        if df + dg < config.tol {
411            converged = true;
412            break;
413        }
414    }
415
416    // Build transport plan from potentials: T_ij = exp((f_i + g_j - C_ij) / ε)
417    let mut transport_plan = Array2::<f64>::zeros((n, m));
418    for i in 0..n {
419        for j in 0..m {
420            transport_plan[[i, j]] = ((f[i] + g[j] - cost[[i, j]]) / eps).exp();
421        }
422    }
423
424    let result = compute_result(transport_plan, cost, a, b, n_iter, converged);
425    Ok(result)
426}
427
428/// Log-sum-exp of (g_j − C_ij / ε) over j (row i of cost matrix).
429#[inline]
430fn log_sum_exp_row(i: usize, g: &Array1<f64>, cost: &Array2<f64>, eps: f64, m: usize) -> f64 {
431    let vals: Vec<f64> = (0..m).map(|j| g[j] - cost[[i, j]] / eps).collect();
432    log_sum_exp_vec(&vals)
433}
434
435/// Log-sum-exp of (f_i − C_ij / ε) over i (column j of cost matrix).
436#[inline]
437fn log_sum_exp_col(j: usize, f: &Array1<f64>, cost: &Array2<f64>, eps: f64, n: usize) -> f64 {
438    let vals: Vec<f64> = (0..n).map(|i| f[i] - cost[[i, j]] / eps).collect();
439    log_sum_exp_vec(&vals)
440}
441
/// Numerically stable log-sum-exp: `log Σ exp(x_i) = max + log Σ exp(x_i − max)`.
///
/// Follows the limits of the exact expression:
/// * any NaN entry gives NaN;
/// * otherwise any `+∞` entry gives `+∞`;
/// * an empty slice, or one whose entries are all `−∞`, gives `−∞` (log of a zero sum).
///
/// Other `−∞` entries contribute nothing, since `exp(−∞ − max) = 0`.
fn log_sum_exp_vec(vals: &[f64]) -> f64 {
    let mut max_val = f64::NEG_INFINITY;
    for &v in vals {
        if v.is_nan() {
            return f64::NAN;
        }
        if v > max_val {
            max_val = v;
        }
    }
    // Both infinite maxima are already the exact answer; shifting by them below
    // would evaluate ∞ − ∞ = NaN.
    if max_val.is_infinite() {
        return max_val;
    }
    let sum_exp: f64 = vals.iter().map(|&v| (v - max_val).exp()).sum();
    max_val + sum_exp.ln()
}
462
463// ---------------------------------------------------------------------------
464// L2 marginal penalty (proximal step = clip-to-positive)
465// ---------------------------------------------------------------------------
466
467/// Unbalanced Sinkhorn with L2 marginal penalties.
468///
469/// The proximal operator for the L2 penalty is:
470/// u ← max(0, 1 − (Kv − a) / (τ K1_m))  ... (simplified form)
471///
472/// In practice we use the scaling form:
473/// u ← a / (Kv + ε/τ)
474fn sinkhorn_l2(
475    a: &[f64],
476    b: &[f64],
477    cost: &Array2<f64>,
478    config: &UnbalancedOtConfig,
479) -> Result<UnbalancedOtResult> {
480    let n = a.len();
481    let m = b.len();
482
483    let k: Array2<f64> = cost.mapv(|c| (-c / config.epsilon).exp());
484    let mut u = Array1::from_elem(n, 1.0_f64);
485    let mut v = Array1::from_elem(m, 1.0_f64);
486
487    let a_arr = Array1::from_vec(a.to_vec());
488    let b_arr = Array1::from_vec(b.to_vec());
489
490    // L2 proximal scaling: effectively a soft update
491    // u ← a / (Kv + ε/τ)  — comes from RKHS proximal step for squared norm
492    let lambda = config.epsilon / config.tau;
493
494    let mut converged = false;
495    let mut n_iter = 0usize;
496
497    for _iter in 0..config.max_iter {
498        n_iter += 1;
499        let kv: Array1<f64> = k.dot(&v);
500        let u_new: Array1<f64> = a_arr
501            .iter()
502            .zip(kv.iter())
503            .map(|(&ai, &kvi)| ai / (kvi + lambda).max(f64::EPSILON))
504            .collect::<Vec<f64>>()
505            .into();
506
507        let ktu: Array1<f64> = k.t().dot(&u_new);
508        let v_new: Array1<f64> = b_arr
509            .iter()
510            .zip(ktu.iter())
511            .map(|(&bi, &ktui)| bi / (ktui + lambda).max(f64::EPSILON))
512            .collect::<Vec<f64>>()
513            .into();
514
515        let du: f64 = u_new
516            .iter()
517            .zip(u.iter())
518            .map(|(&a, &b)| (a - b).abs())
519            .sum::<f64>()
520            / n as f64;
521        let dv: f64 = v_new
522            .iter()
523            .zip(v.iter())
524            .map(|(&a, &b)| (a - b).abs())
525            .sum::<f64>()
526            / m as f64;
527
528        u = u_new;
529        v = v_new;
530
531        if du + dv < config.tol {
532            converged = true;
533            break;
534        }
535    }
536
537    let transport_plan = build_transport_plan(&u, &k, &v);
538    let result = compute_result(transport_plan, cost, a, b, n_iter, converged);
539    Ok(result)
540}
541
542// ---------------------------------------------------------------------------
543// Internal helpers
544// ---------------------------------------------------------------------------
545
546/// Build the transport plan: T_ij = u_i K_ij v_j
547fn build_transport_plan(u: &Array1<f64>, k: &Array2<f64>, v: &Array1<f64>) -> Array2<f64> {
548    let n = u.len();
549    let m = v.len();
550    let mut t = Array2::zeros((n, m));
551    for i in 0..n {
552        for j in 0..m {
553            t[[i, j]] = u[i] * k[[i, j]] * v[j];
554        }
555    }
556    t
557}
558
559/// Compute diagnostics from the transport plan.
560fn compute_result(
561    transport_plan: Array2<f64>,
562    cost: &Array2<f64>,
563    a: &[f64],
564    b: &[f64],
565    n_iter: usize,
566    converged: bool,
567) -> UnbalancedOtResult {
568    let n = a.len();
569    let m = b.len();
570
571    // Transport cost: ⟨C, T⟩
572    let ot_cost: f64 = cost
573        .iter()
574        .zip(transport_plan.iter())
575        .map(|(&c, &t)| c * t)
576        .sum();
577
578    // Source marginal: T 1_m
579    let source_marg: Vec<f64> = (0..n).map(|i| transport_plan.row(i).sum()).collect();
580
581    // Target marginal: 1_n^ᵀ T
582    let target_marg: Vec<f64> = (0..m).map(|j| transport_plan.column(j).sum()).collect();
583
584    // Marginal violations (L1 distance from input histograms)
585    let mv_src: f64 = source_marg
586        .iter()
587        .zip(a.iter())
588        .map(|(&sm, &ai)| (sm - ai).abs())
589        .sum();
590    let mv_tgt: f64 = target_marg
591        .iter()
592        .zip(b.iter())
593        .map(|(&tm, &bi)| (tm - bi).abs())
594        .sum();
595
596    UnbalancedOtResult {
597        transport_plan,
598        cost: ot_cost,
599        marginal_violation_source: mv_src,
600        marginal_violation_target: mv_tgt,
601        n_iter,
602        converged,
603    }
604}
605
606// ---------------------------------------------------------------------------
607// Tests
608// ---------------------------------------------------------------------------
609
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // ------------------------------------------------------------------
    // Basic correctness
    // ------------------------------------------------------------------

    #[test]
    fn test_unbalanced_ot_equal_mass_kl() {
        // Equal-mass uniform histograms; cost = |i − j| / n
        let n = 4usize;
        let a: Vec<f64> = vec![0.25; n];
        let b: Vec<f64> = vec![0.25; n];
        let mut cost_arr = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                cost_arr[[i, j]] = (i as f64 - j as f64).abs() / n as f64;
            }
        }

        let config = UnbalancedOtConfig {
            epsilon: 0.01,
            tau: 100.0, // large tau → close to balanced OT
            log_domain: true,
            max_iter: 2000,
            tol: 1e-8,
            ..Default::default()
        };

        let result = unbalanced_sinkhorn(&a, &b, &cost_arr, &config).expect("UOT ok");
        assert!(result.cost >= 0.0, "cost must be non-negative");
        // a == b and the cost diagonal is zero, so the balanced optimum is the identity
        // coupling with cost 0. With large tau and small epsilon the UOT plan should
        // stay close to it.
        assert!(
            result.cost < 0.05,
            "cost should be close to the balanced optimum 0, got {}",
            result.cost
        );
        assert!(
            result.marginal_violation_source < 0.1,
            "source marginal violation should be small, got {}",
            result.marginal_violation_source
        );
    }

    #[test]
    fn test_unbalanced_ot_equal_mass_l2() {
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
        let config = UnbalancedOtConfig {
            regularization: UnbalancedRegularization::L2,
            epsilon: 0.1,
            tau: 10.0,
            max_iter: 500,
            tol: 1e-6,
            log_domain: false,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT L2 ok");
        assert!(result.cost >= 0.0);
        // Transport plan should be non-negative
        for &t in result.transport_plan.iter() {
            assert!(t >= -1e-10, "transport plan entries must be non-negative");
        }
    }

    #[test]
    fn test_unbalanced_ot_unequal_mass() {
        // Source has mass 1.0, target has mass 0.5
        let a = vec![0.5, 0.5]; // total mass = 1.0
        let b = vec![0.25, 0.25]; // total mass = 0.5
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];

        let config = UnbalancedOtConfig {
            epsilon: 0.05,
            tau: 0.5, // allow significant marginal deviation
            max_iter: 1000,
            tol: 1e-6,
            log_domain: true,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT unequal ok");
        assert!(result.cost >= 0.0);
        // By the triangle inequality ‖T1 − a‖₁ ≥ |ΣT − Σa| and ‖1ᵀT − b‖₁ ≥ |ΣT − Σb|,
        // so whatever the plan's mass is, the two violations sum to at least
        // |Σa − Σb| = 0.5.
        let total_mv = result.marginal_violation_source + result.marginal_violation_target;
        assert!(
            total_mv >= 0.5 - 1e-9,
            "marginal violations must sum to at least the mass gap 0.5, got {total_mv}"
        );
    }

    #[test]
    fn test_unbalanced_ot_diagonal_cost() {
        // Zero cost on diagonal: optimal plan should concentrate on diagonal
        let n = 3usize;
        let a = vec![1.0 / n as f64; n];
        let b = vec![1.0 / n as f64; n];
        let mut cost_arr = Array2::<f64>::ones((n, n)) * 10.0;
        for i in 0..n {
            cost_arr[[i, i]] = 0.0;
        }

        let config = UnbalancedOtConfig {
            epsilon: 0.01,
            tau: 100.0,
            max_iter: 2000,
            tol: 1e-9,
            log_domain: true,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost_arr, &config).expect("UOT diagonal ok");
        // Cost should be close to 0 (all mass on diagonal)
        assert!(
            result.cost < 0.5,
            "diagonal-concentrated plan should have small cost, got {}",
            result.cost
        );
    }

    #[test]
    fn test_unbalanced_ot_kl_standard_domain() {
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
        let config = UnbalancedOtConfig {
            epsilon: 0.1,
            tau: 1.0,
            log_domain: false, // test non-log-domain path
            max_iter: 500,
            tol: 1e-6,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT KL std ok");
        assert!(result.cost >= 0.0);
    }

    // ------------------------------------------------------------------
    // Error cases
    // ------------------------------------------------------------------

    #[test]
    fn test_empty_source_error() {
        let a: Vec<f64> = vec![];
        let b = vec![0.5, 0.5];
        let cost = Array2::<f64>::zeros((0, 2));
        let config = UnbalancedOtConfig::default();
        assert!(unbalanced_sinkhorn(&a, &b, &cost, &config).is_err());
    }

    #[test]
    fn test_shape_mismatch_error() {
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = Array2::<f64>::zeros((3, 2)); // wrong n
        let config = UnbalancedOtConfig::default();
        assert!(unbalanced_sinkhorn(&a, &b, &cost, &config).is_err());
    }

    #[test]
    fn test_negative_epsilon_error() {
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
        let config = UnbalancedOtConfig {
            epsilon: -0.1,
            ..Default::default()
        };
        assert!(unbalanced_sinkhorn(&a, &b, &cost, &config).is_err());
    }

    #[test]
    fn test_zero_mass_error() {
        let a = vec![0.0, 0.0];
        let b = vec![0.5, 0.5];
        let cost = array![[0.0_f64, 1.0], [1.0, 0.0]];
        let config = UnbalancedOtConfig::default();
        assert!(unbalanced_sinkhorn(&a, &b, &cost, &config).is_err());
    }

    #[test]
    fn test_transport_plan_non_negative() {
        // All transport plan entries should be non-negative
        let a = vec![0.3, 0.7];
        let b = vec![0.6, 0.4];
        let cost = array![[0.1_f64, 0.9], [0.8, 0.2]];
        let config = UnbalancedOtConfig::default();
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("UOT ok");
        for &t in result.transport_plan.iter() {
            assert!(t >= -1e-12, "transport plan entry {t} is negative");
        }
    }

    #[test]
    fn test_1x1_trivial() {
        // Single source, single target: with zero cost, transport plan entry should be
        // close to 1 (balanced OT recovered with large tau)
        let a = vec![1.0];
        let b = vec![1.0];
        // Zero cost: optimal T = 1 regardless of regularization
        let cost = array![[0.0_f64]];
        let config = UnbalancedOtConfig {
            epsilon: 0.01,
            tau: 100.0,
            max_iter: 2000,
            tol: 1e-8,
            ..Default::default()
        };
        let result = unbalanced_sinkhorn(&a, &b, &cost, &config).expect("1x1 ok");
        assert!(
            (result.transport_plan[[0, 0]] - 1.0).abs() < 0.2,
            "1x1 transport plan should be close to 1, got {}",
            result.transport_plan[[0, 0]]
        );
        // Cost should be near 0 (zero cost matrix)
        assert!(
            result.cost < 0.5,
            "1x1 cost with zero cost matrix should be small, got {}",
            result.cost
        );
    }

    // ------------------------------------------------------------------
    // Convenience: log_sum_exp
    // ------------------------------------------------------------------

    #[test]
    fn test_log_sum_exp_vec() {
        let vals = vec![1.0_f64, 2.0, 3.0];
        let lse = log_sum_exp_vec(&vals);
        let expected = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((lse - expected).abs() < 1e-10, "lse mismatch");
    }
}