rust_optimal_transport/
utils.rs

1use ndarray::prelude::*;
2use ndarray_linalg::cholesky::*;
3use ndarray_rand::rand::{thread_rng, Rng};
4use ndarray_rand::rand_distr::StandardNormal;
5
6use anyhow::anyhow;
7use thiserror::Error;
8
9// TODO: Add additional error cases for DistributionError enum
10#[derive(Error, Debug)]
11pub enum DistributionError {
12    #[error("Oops!")]
13    Oops(String),
14    #[error(transparent)]
15    Other(#[from] anyhow::Error),
16}
17
18/// Returns a 1D histogram for a gaussian distribution
19/// n: number of bins in histogram
20/// mean: mean value of distribution
21/// std: standard distribution of distribution
22pub fn get_1D_gauss_histogram(
23    n: i32,
24    mean: f64,
25    std: f64,
26) -> Result<Array1<f64>, DistributionError> {
27    let x = Array1::<f64>::range(0.0, n as f64, 1.0);
28    let var = std.powf(2.0);
29    let denom = 2.0 * var;
30    let diff = &x - mean;
31    let numerator = -&diff * &diff;
32    let mut result: Array1<f64> = numerator.iter().map(|val| (val / denom).exp()).collect();
33    let summed_val = result.sum();
34
35    result /= summed_val;
36
37    // TODO: add error handling
38
39    Ok(result)
40}
41
42/// Returns n samples drawn from a 2D gaussian distribution
43/// n: number of samples to take
44/// mean: mean values (x,y) of distribution
45/// cov: covariance matrix of the distribution
46pub fn sample_2D_gauss(
47    n: i32,
48    mean: &Array1<f64>,
49    cov: &Array2<f64>,
50) -> Result<Array2<f64>, DistributionError> {
51    let cov_shape = cov.shape();
52
53    if n <= 0 {
54        return Err(DistributionError::Oops(
55            "n is not greater than zero".to_string(),
56        ));
57    }
58
59    if mean.is_empty() || cov.is_empty() {
60        return Err(DistributionError::Oops(
61            "zero length mean or covariance".to_string(),
62        ));
63    }
64
65    if cov_shape[0] != mean.len() && cov_shape[1] != mean.len() {
66        return Err(DistributionError::Oops(
67            "covariance dimensions do not match mean dimensions".to_string(),
68        ));
69    }
70
71    let mut rng = thread_rng();
72    let mut samples = Array2::<f64>::zeros((n as usize, 2));
73    for mut row in samples.axis_iter_mut(Axis(0)) {
74        row[0] = rng.sample(StandardNormal);
75        row[1] = rng.sample(StandardNormal);
76    }
77
78    // add small perturbation to covariance matrix for numerical stability
79    let epsilon = 0.0001;
80    let cov_perturbed = cov + Array2::<f64>::eye(cov_shape[0]) * epsilon;
81
82    // Compute cholesky decomposition
83    let lower = match cov_perturbed.cholesky(UPLO::Lower) {
84        Ok(val) => val,
85        Err(_) => return Err(DistributionError::Other(anyhow!("oops!"))),
86    };
87
88    Ok(mean + samples.dot(&lower))
89}
90
#[cfg(test)]
// The functions under test have mixed-case names and the test names mirror them.
#[allow(non_snake_case)]
mod tests {

    use ndarray::array;

    /// The histogram has one entry per bin, is a valid probability vector,
    /// and peaks at the bin equal to the mean.
    #[test]
    fn test_get_1D_gauss_hist() {
        let n = 50;
        let mean = 20.0;
        let std = 5.0;

        let result = super::get_1D_gauss_histogram(n, mean, std)
            .expect("valid parameters must produce a histogram");

        assert_eq!(result.len(), 50);
        assert!(result.iter().all(|&p| p.is_finite() && p >= 0.0));
        assert!((result.sum() - 1.0).abs() < 1e-12);

        let peak = result
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).expect("histogram values are not NaN"))
            .map(|(idx, _)| idx);
        assert_eq!(peak, Some(20));
    }

    /// Sampling returns one finite 2D point per requested sample.
    #[test]
    fn test_sample_2D_gauss() {
        let n = 50;
        let mean = array![0.0, 0.0];
        let covariance = array![[1.0, 0.0], [0.0, 1.0]];

        let result = super::sample_2D_gauss(n, &mean, &covariance)
            .expect("valid parameters must produce samples");

        assert_eq!(result.dim(), (50, 2));
        assert!(result.iter().all(|v| v.is_finite()));
    }

    /// A non-positive sample count is reported as an error.
    #[test]
    fn test_sample_2D_gauss_rejects_non_positive_n() {
        let mean = array![0.0, 0.0];
        let covariance = array![[1.0, 0.0], [0.0, 1.0]];

        assert!(super::sample_2D_gauss(0, &mean, &covariance).is_err());
        assert!(super::sample_2D_gauss(-3, &mean, &covariance).is_err());
    }
}