rust_optimal_transport/
metrics.rs

1use ndarray::prelude::*;
2use ndarray_einsum_beta::*;
3
/// Distance metric selectable when building a pairwise distance matrix.
///
/// The enum is a plain tag with no payload, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetricType {
    /// Squared Euclidean distance.
    SqEuclidean,
    /// Euclidean distance.
    Euclidean,
}
8
9/// Compute distance between samples in x1 and x2
10/// x1: matrix with n1 samples of size d
11/// x2: matrix with n2 samples of size d
12/// metric: choice of distance metric
13pub fn dist(x1: &Array2<f64>, x2: &Array2<f64>, metric: MetricType) -> Array2<f64> {
14    match metric {
15        MetricType::SqEuclidean => euclidean_distances(x1, x2, true),
16        MetricType::Euclidean => euclidean_distances(x1, x2, false),
17    }
18}
19
20/// Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair
21/// of vectors
22/// X: matrix of nsamples x nfeatures
23/// Y: matrix of nsamples x nfeatures
24/// squared: Return squared Euclidean distances
25fn euclidean_distances(x: &Array2<f64>, y: &Array2<f64>, squared: bool) -> Array2<f64> {
26    // einsum('ij,ij->i', X, X)
27    // repeated i and j for both x and y inout matrices : multiply those components
28    // - element-wise multiplication
29    // ommitted j in output : sum along j axis
30    // - summation in j
31    let a2 = einsum("ij,ij->i", &[x, x]).unwrap();
32    // einsum('ij,ij->i', Y, Y)
33    let b2 = einsum("ij,ij->i", &[y, y]).unwrap();
34
35    let mut c = (x.dot(&y.t())) * -2f64;
36
37    // c += a2[:, None]
38    for (mut row, a2val) in c.axis_iter_mut(Axis(0)).zip(&a2.t()) {
39        for ele in row.iter_mut() {
40            *ele += a2val;
41        }
42    }
43
44    // c += b2[None, :]
45    for (mut col, b2val) in c.axis_iter_mut(Axis(1)).zip(&b2) {
46        for ele in col.iter_mut() {
47            *ele += b2val;
48        }
49    }
50
51    // c = nx.maximum(c, 0)
52    for val in c.iter_mut() {
53        if *val <= 0f64 {
54            *val = 0f64;
55        }
56    }
57
58    if !squared {
59        // np.sqrt(c)
60        for val in c.iter_mut() {
61            *val = val.powf(0.5);
62        }
63    }
64
65    if x == y {
66        // ones matrix with diagonals set to zero
67        let mut anti_diag = Array2::<f64>::ones((a2.len(), b2.len()));
68        for ele in anti_diag.diag_mut().iter_mut() {
69            *ele = 0f64;
70        }
71
72        c = c * anti_diag;
73    }
74
75    c
76}
77
#[cfg(test)]
mod tests {
    use super::{dist, euclidean_distances, MetricType};
    use ndarray::prelude::*;

    /// Euclidean distance between the zero vector and a length-5 vector of fives,
    /// i.e. sqrt(5 * 5^2) = sqrt(125).
    const SQRT_125: f64 = 11.180339887498949;

    /// Fixture: three all-zero samples and three all-five samples, each with five features.
    fn zeros_and_fives() -> (Array2<f64>, Array2<f64>) {
        (Array2::<f64>::zeros((3, 5)), Array2::from_elem((3, 5), 5.0))
    }

    #[test]
    fn test_euclidean_distances() {
        let (x, y) = zeros_and_fives();

        let distance = euclidean_distances(&x, &y, false);

        assert_eq!(distance, Array2::from_elem((3, 3), SQRT_125));
    }

    #[test]
    fn test_euclidean_distances_squared() {
        let (x, y) = zeros_and_fives();

        let distance = euclidean_distances(&x, &y, true);

        assert_eq!(distance, Array2::from_elem((3, 3), 125.0));
    }

    #[test]
    fn test_euclidean_distances_self() {
        // Distances of a sample set to itself: zero diagonal, symmetric off-diagonal.
        let x = array![[0.0, 0.0], [3.0, 4.0]];

        let distance = euclidean_distances(&x, &x, true);

        assert_eq!(distance, array![[0.0, 25.0], [25.0, 0.0]]);
    }

    #[test]
    fn test_dist() {
        let (x, y) = zeros_and_fives();

        let m = dist(&x, &y, MetricType::Euclidean);

        assert_eq!(m, Array2::from_elem((3, 3), SQRT_125));
    }

    #[test]
    fn test_dist_sq_euclidean() {
        let (x, y) = zeros_and_fives();

        let m = dist(&x, &y, MetricType::SqEuclidean);

        assert_eq!(m, Array2::from_elem((3, 3), 125.0));
    }
}