rust_optimal_transport/exact/mod.rs

1mod ffi;
2mod utils;
3
4use ndarray::prelude::*;
5use std::error::Error;
6use std::fmt;
7
8use super::error::OTError;
9use super::OTSolver;
10use ffi::emd_c;
11use utils::*;
12
/// Return codes from the FastTransport network simplex solver.
///
/// FastTransport returns `1` ([`FastTransportErrorCode::IsOptimal`]) on success. The raw
/// integer code is converted with the `From<i32>` impl below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FastTransportErrorCode {
    /// No feasible flow exists for the problem (raw code `0`).
    IsInfeasible,
    /// The problem is feasible and bounded (raw code `1`).
    /// Optimal flow and node potentials (primal and dual solutions) found.
    IsOptimal,
    /// Objective function of the problem is unbounded (raw code `2`),
    /// i.e. there is a directed cycle having negative total cost and infinite
    /// upper bound.
    IsUnbounded,
    /// Maximum iterations reached by the solver (raw code `3`).
    ///
    /// Also produced for any raw code outside `0..=3`.
    IsMaxIterReached,
}

impl Error for FastTransportErrorCode {}

impl fmt::Display for FastTransportErrorCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            FastTransportErrorCode::IsInfeasible => "Network simplex infeasible!",
            FastTransportErrorCode::IsOptimal => "Optimal solution found!",
            FastTransportErrorCode::IsUnbounded => "Network simplex unbounded!",
            FastTransportErrorCode::IsMaxIterReached => "Max iteration reached!",
        };
        f.write_str(msg)
    }
}

impl From<i32> for FastTransportErrorCode {
    /// Maps a raw solver return code to its variant.
    ///
    /// Any code outside `0..=3` falls back to `IsMaxIterReached`, so an unrecognised
    /// code is never reported as `IsOptimal`.
    fn from(e: i32) -> Self {
        match e {
            0 => FastTransportErrorCode::IsInfeasible,
            1 => FastTransportErrorCode::IsOptimal,
            2 => FastTransportErrorCode::IsUnbounded,
            3 => FastTransportErrorCode::IsMaxIterReached,
            // Unrecognised code: deliberately not mapped to success.
            _ => FastTransportErrorCode::IsMaxIterReached,
        }
    }
}
54
55/// Solves the exact OT Earth Movers Distance using the FastTransport LP solver
56/// source_weights: Weights on samples from the source distribution
57/// target_weights: Weights on samples from the target distribution
58/// cost: Distance between samples in the source and target distributions
59/// num_iter_max: maximum number of iterations before stopping the optimization algorithm if it has
60/// not converged (default = 100000)
61pub struct EarthMovers<'a> {
62    source_weights: &'a mut Array1<f64>,
63    target_weights: &'a mut Array1<f64>,
64    cost: &'a mut Array2<f64>,
65    max_iter: i32,
66}
67
68impl<'a> EarthMovers<'a> {
69    pub fn new(
70        source_weights: &'a mut Array1<f64>,
71        target_weights: &'a mut Array1<f64>,
72        cost: &'a mut Array2<f64>,
73    ) -> Self {
74        Self {
75            source_weights,
76            target_weights,
77            cost,
78            max_iter: 100000,
79        }
80    }
81
82    pub fn iterations<'b>(&'b mut self, max_iter: i32) -> &'b mut Self {
83        self.max_iter = max_iter;
84        self
85    }
86}
87
88impl<'a> OTSolver for EarthMovers<'a> {
89    fn check_shape(&self) -> Result<(), OTError> {
90        let mshape = self.cost.shape();
91        let m0 = mshape[0];
92        let m1 = mshape[1];
93        let dim_a = self.source_weights.len();
94        let dim_b = self.target_weights.len();
95
96        if dim_a != m0 || dim_b != m1 {
97            return Err(OTError::WeightDimensionError {
98                dim_a,
99                dim_b,
100                dim_m_0: m0,
101                dim_m_1: m1,
102            });
103        }
104
105        Ok(())
106    }
107
108    fn solve(&mut self) -> Result<Array2<f64>, OTError> {
109        self.check_shape()?;
110
111        *self.target_weights *= self.source_weights.sum() / self.target_weights.sum();
112
113        emd(
114            self.source_weights,
115            self.target_weights,
116            self.cost,
117            Some(self.max_iter),
118        )
119    }
120}
121
122/// a: Source sample weights (defaults to uniform weight if empty)
123/// b: Target sample weights (defaults to uniform weight if empty)
124/// M: Loss matrix (row-major)
125/// num_iter_max: maximum number of iterations before stopping the optimization algorithm if it has
126/// not converged (default = 100000)
127#[allow(non_snake_case)]
128pub(crate) fn emd(
129    a: &mut Array1<f64>,
130    b: &mut Array1<f64>,
131    M: &mut Array2<f64>,
132    num_iter_max: Option<i32>,
133) -> Result<Array2<f64>, OTError> {
134    // Defaults
135    let iterations = match num_iter_max {
136        Some(val) => val,
137        None => 100000,
138    };
139
140    // Call FastTransport via wrapper
141    let (G, _cost, mut _u, mut _v, result_code) = emd_c(a, b, M, iterations);
142
143    // Propogate errors if there are any
144    check_result(FastTransportErrorCode::from(result_code))?;
145
146    Ok(G)
147}
148
#[cfg(test)]
mod tests {

    use crate::OTSolver;
    use ndarray::array;

    #[allow(non_snake_case)] // `M` follows the usual OT notation for the cost matrix
    #[test]
    fn test_emd() {
        let mut a = array![0.5, 0.5];
        let mut b = array![0.5, 0.5];
        let mut M = array![[0.0, 1.0], [1.0, 0.0]];

        let gamma = super::emd(&mut a, &mut b, &mut M, None)
            .unwrap_or_else(|error| panic!("{:?}", error));

        let truth = array![[0.5, 0.0], [0.0, 0.5]];

        assert_eq!(gamma, truth);
    }

    #[test]
    fn test_earthmovers_builder() {
        let mut a = array![0.5, 0.5];
        let mut b = array![0.5, 0.5];
        let mut m = array![[0.0, 1.0], [1.0, 0.0]];

        let test = super::EarthMovers::new(&mut a, &mut b, &mut m)
            .solve()
            .unwrap_or_else(|error| panic!("{:?}", error));

        let truth = array![[0.5, 0.0], [0.0, 0.5]];

        assert_eq!(test, truth);
    }

    #[test]
    fn test_earthmovers_iterations() {
        let mut a = array![0.5, 0.5];
        let mut b = array![0.5, 0.5];
        let mut m = array![[0.0, 1.0], [1.0, 0.0]];

        let test = super::EarthMovers::new(&mut a, &mut b, &mut m)
            .iterations(1000)
            .solve()
            .unwrap_or_else(|error| panic!("{:?}", error));

        assert_eq!(test, array![[0.5, 0.0], [0.0, 0.5]]);
    }

    #[test]
    fn test_earthmovers_rescales_target_mass() {
        let mut a = array![0.5, 0.5];
        let mut b = array![1.0, 1.0];
        let mut m = array![[0.0, 1.0], [1.0, 0.0]];

        let test = super::EarthMovers::new(&mut a, &mut b, &mut m)
            .solve()
            .unwrap_or_else(|error| panic!("{:?}", error));

        // `solve` rescales the target weights in place to match the source mass.
        assert_eq!(b, array![0.5, 0.5]);
        assert_eq!(test, array![[0.5, 0.0], [0.0, 0.5]]);
    }

    #[test]
    fn test_earthmovers_shape_mismatch() {
        let mut a = array![0.2, 0.3, 0.5];
        let mut b = array![0.5, 0.5];
        let mut m = array![[0.0, 1.0], [1.0, 0.0]];

        let result = super::EarthMovers::new(&mut a, &mut b, &mut m).solve();

        assert!(matches!(
            result,
            Err(super::OTError::WeightDimensionError { .. })
        ));
    }
}