rust_optimal_transport/exact/
mod.rs1mod ffi;
2mod utils;
3
4use ndarray::prelude::*;
5use std::error::Error;
6use std::fmt;
7
8use super::error::OTError;
9use super::OTSolver;
10use ffi::emd_c;
11use utils::*;
12
/// Termination status reported by the exact network-simplex solver.
///
/// Built from the raw integer status code returned by the solver backend via
/// `From<i32>`. Despite the type's name, one variant
/// ([`FastTransportErrorCode::IsOptimal`]) is the success status.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FastTransportErrorCode {
    /// The problem is infeasible (raw code `0`).
    IsInfeasible,
    /// An optimal solution was found (raw code `1`).
    IsOptimal,
    /// The problem is unbounded (raw code `2`).
    IsUnbounded,
    /// The iteration limit was reached (raw code `3`; also the fallback for
    /// any unrecognised code).
    IsMaxIterReached,
}

impl Error for FastTransportErrorCode {}

impl fmt::Display for FastTransportErrorCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::IsInfeasible => "Network simplex infeasible!",
            Self::IsOptimal => "Optimal solution found!",
            Self::IsUnbounded => "Network simplex unbounded!",
            Self::IsMaxIterReached => "Max iteration reached!",
        };
        f.write_str(message)
    }
}

impl From<i32> for FastTransportErrorCode {
    /// Maps the solver's raw status code onto a variant.
    ///
    /// Codes `0..=3` map to the variants in declaration order. Any other value
    /// is reported as [`FastTransportErrorCode::IsMaxIterReached`] instead of
    /// panicking, so an unexpected code still surfaces as a non-optimal status.
    // NOTE(review): the 0..=3 numbering presumably mirrors the C backend's
    // problem-type enum — confirm against the ffi wrapper if it changes.
    fn from(e: i32) -> Self {
        match e {
            0 => Self::IsInfeasible,
            1 => Self::IsOptimal,
            2 => Self::IsUnbounded,
            // Code 3 and every unrecognised code.
            _ => Self::IsMaxIterReached,
        }
    }
}
54
55pub struct EarthMovers<'a> {
62 source_weights: &'a mut Array1<f64>,
63 target_weights: &'a mut Array1<f64>,
64 cost: &'a mut Array2<f64>,
65 max_iter: i32,
66}
67
68impl<'a> EarthMovers<'a> {
69 pub fn new(
70 source_weights: &'a mut Array1<f64>,
71 target_weights: &'a mut Array1<f64>,
72 cost: &'a mut Array2<f64>,
73 ) -> Self {
74 Self {
75 source_weights,
76 target_weights,
77 cost,
78 max_iter: 100000,
79 }
80 }
81
82 pub fn iterations<'b>(&'b mut self, max_iter: i32) -> &'b mut Self {
83 self.max_iter = max_iter;
84 self
85 }
86}
87
88impl<'a> OTSolver for EarthMovers<'a> {
89 fn check_shape(&self) -> Result<(), OTError> {
90 let mshape = self.cost.shape();
91 let m0 = mshape[0];
92 let m1 = mshape[1];
93 let dim_a = self.source_weights.len();
94 let dim_b = self.target_weights.len();
95
96 if dim_a != m0 || dim_b != m1 {
97 return Err(OTError::WeightDimensionError {
98 dim_a,
99 dim_b,
100 dim_m_0: m0,
101 dim_m_1: m1,
102 });
103 }
104
105 Ok(())
106 }
107
108 fn solve(&mut self) -> Result<Array2<f64>, OTError> {
109 self.check_shape()?;
110
111 *self.target_weights *= self.source_weights.sum() / self.target_weights.sum();
112
113 emd(
114 self.source_weights,
115 self.target_weights,
116 self.cost,
117 Some(self.max_iter),
118 )
119 }
120}
121
122#[allow(non_snake_case)]
128pub(crate) fn emd(
129 a: &mut Array1<f64>,
130 b: &mut Array1<f64>,
131 M: &mut Array2<f64>,
132 num_iter_max: Option<i32>,
133) -> Result<Array2<f64>, OTError> {
134 let iterations = match num_iter_max {
136 Some(val) => val,
137 None => 100000,
138 };
139
140 let (G, _cost, mut _u, mut _v, result_code) = emd_c(a, b, M, iterations);
142
143 check_result(FastTransportErrorCode::from(result_code))?;
145
146 Ok(G)
147}
148
#[cfg(test)]
mod tests {
    use crate::OTSolver;
    use ndarray::array;

    /// Two points with zero cost on the diagonal: all mass should stay put.
    #[test]
    fn test_emd() {
        let mut source = array![0.5, 0.5];
        let mut target = array![0.5, 0.5];
        let mut cost = array![[0.0, 1.0], [1.0, 0.0]];

        let gamma = super::emd(&mut source, &mut target, &mut cost, None)
            .unwrap_or_else(|error| panic!("{:?}", error));

        assert_eq!(gamma, array![[0.5, 0.0], [0.0, 0.5]]);
    }

    /// The builder path must produce the same diagonal plan as `emd`.
    #[test]
    fn test_earthmovers_builder() {
        let mut source = array![0.5, 0.5];
        let mut target = array![0.5, 0.5];
        let mut cost = array![[0.0, 1.0], [1.0, 0.0]];

        let mut solver = super::EarthMovers::new(&mut source, &mut target, &mut cost);
        let plan = solver
            .solve()
            .unwrap_or_else(|error| panic!("{:?}", error));

        assert_eq!(plan, array![[0.5, 0.0], [0.0, 0.5]]);
    }
}