rust_optimal_transport/
metrics.rs1use ndarray::prelude::*;
2use ndarray_einsum_beta::*;
3
/// Distance metric used by [`dist`] to build a pairwise cost matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetricType {
    /// Squared Euclidean distance, `||a - b||^2`.
    SqEuclidean,
    /// Euclidean distance, `||a - b||`.
    Euclidean,
}
8
9pub fn dist(x1: &Array2<f64>, x2: &Array2<f64>, metric: MetricType) -> Array2<f64> {
14 match metric {
15 MetricType::SqEuclidean => euclidean_distances(x1, x2, true),
16 MetricType::Euclidean => euclidean_distances(x1, x2, false),
17 }
18}
19
20fn euclidean_distances(x: &Array2<f64>, y: &Array2<f64>, squared: bool) -> Array2<f64> {
26 let a2 = einsum("ij,ij->i", &[x, x]).unwrap();
32 let b2 = einsum("ij,ij->i", &[y, y]).unwrap();
34
35 let mut c = (x.dot(&y.t())) * -2f64;
36
37 for (mut row, a2val) in c.axis_iter_mut(Axis(0)).zip(&a2.t()) {
39 for ele in row.iter_mut() {
40 *ele += a2val;
41 }
42 }
43
44 for (mut col, b2val) in c.axis_iter_mut(Axis(1)).zip(&b2) {
46 for ele in col.iter_mut() {
47 *ele += b2val;
48 }
49 }
50
51 for val in c.iter_mut() {
53 if *val <= 0f64 {
54 *val = 0f64;
55 }
56 }
57
58 if !squared {
59 for val in c.iter_mut() {
61 *val = val.powf(0.5);
62 }
63 }
64
65 if x == y {
66 let mut anti_diag = Array2::<f64>::ones((a2.len(), b2.len()));
68 for ele in anti_diag.diag_mut().iter_mut() {
69 *ele = 0f64;
70 }
71
72 c = c * anti_diag;
73 }
74
75 c
76}
77
#[cfg(test)]
mod tests {
    use ndarray::prelude::*;

    /// Asserts that two matrices have the same shape and agree element-wise
    /// within a small absolute tolerance (exact float equality is brittle).
    fn assert_close(actual: &Array2<f64>, expected: &Array2<f64>) {
        assert_eq!(actual.shape(), expected.shape());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-12, "{} != {}", a, e);
        }
    }

    #[test]
    fn test_euclidean_distances() {
        let x = Array2::<f64>::zeros((3, 5));
        let y = Array2::from_elem((3, 5), 5.0);
        let distance = super::euclidean_distances(&x, &y, false);

        // Each pair differs by 5 in all 5 coordinates: sqrt(5 * 5^2) = sqrt(125).
        assert_close(&distance, &Array2::from_elem((3, 3), 125f64.sqrt()));
    }

    #[test]
    fn test_squared_euclidean_distances() {
        let x = Array2::<f64>::zeros((3, 5));
        let y = Array2::from_elem((3, 5), 5.0);
        let distance = super::euclidean_distances(&x, &y, true);

        assert_close(&distance, &Array2::from_elem((3, 3), 125.0));
    }

    #[test]
    fn test_self_distances_have_zero_diagonal() {
        let x = array![[0.0, 0.0], [3.0, 4.0]];
        let distance = super::euclidean_distances(&x, &x, false);

        assert_close(&distance, &array![[0.0, 5.0], [5.0, 0.0]]);
    }

    #[test]
    fn test_dist() {
        let x = Array2::<f64>::zeros((3, 5));
        let y = Array2::from_elem((3, 5), 5.0);

        let cost = super::dist(&x, &y, super::MetricType::Euclidean);
        assert_close(&cost, &Array2::from_elem((3, 3), 125f64.sqrt()));

        let sq_cost = super::dist(&x, &y, super::MetricType::SqEuclidean);
        assert_close(&sq_cost, &Array2::from_elem((3, 3), 125.0));
    }
}