rust_optimal_transport/
utils.rs1use ndarray::prelude::*;
2use ndarray_linalg::cholesky::*;
3use ndarray_rand::rand::{thread_rng, Rng};
4use ndarray_rand::rand_distr::StandardNormal;
5
6use anyhow::anyhow;
7use thiserror::Error;
8
9#[derive(Error, Debug)]
11pub enum DistributionError {
12 #[error("Oops!")]
13 Oops(String),
14 #[error(transparent)]
15 Other(#[from] anyhow::Error),
16}
17
18pub fn get_1D_gauss_histogram(
23 n: i32,
24 mean: f64,
25 std: f64,
26) -> Result<Array1<f64>, DistributionError> {
27 let x = Array1::<f64>::range(0.0, n as f64, 1.0);
28 let var = std.powf(2.0);
29 let denom = 2.0 * var;
30 let diff = &x - mean;
31 let numerator = -&diff * &diff;
32 let mut result: Array1<f64> = numerator.iter().map(|val| (val / denom).exp()).collect();
33 let summed_val = result.sum();
34
35 result /= summed_val;
36
37 Ok(result)
40}
41
42pub fn sample_2D_gauss(
47 n: i32,
48 mean: &Array1<f64>,
49 cov: &Array2<f64>,
50) -> Result<Array2<f64>, DistributionError> {
51 let cov_shape = cov.shape();
52
53 if n <= 0 {
54 return Err(DistributionError::Oops(
55 "n is not greater than zero".to_string(),
56 ));
57 }
58
59 if mean.is_empty() || cov.is_empty() {
60 return Err(DistributionError::Oops(
61 "zero length mean or covariance".to_string(),
62 ));
63 }
64
65 if cov_shape[0] != mean.len() && cov_shape[1] != mean.len() {
66 return Err(DistributionError::Oops(
67 "covariance dimensions do not match mean dimensions".to_string(),
68 ));
69 }
70
71 let mut rng = thread_rng();
72 let mut samples = Array2::<f64>::zeros((n as usize, 2));
73 for mut row in samples.axis_iter_mut(Axis(0)) {
74 row[0] = rng.sample(StandardNormal);
75 row[1] = rng.sample(StandardNormal);
76 }
77
78 let epsilon = 0.0001;
80 let cov_perturbed = cov + Array2::<f64>::eye(cov_shape[0]) * epsilon;
81
82 let lower = match cov_perturbed.cholesky(UPLO::Lower) {
84 Ok(val) => val,
85 Err(_) => return Err(DistributionError::Other(anyhow!("oops!"))),
86 };
87
88 Ok(mean + samples.dot(&lower))
89}
90
#[cfg(test)]
mod tests {
    use ndarray::array;

    #[test]
    fn test_get_1d_gauss_hist() {
        let n = 50;
        let mean = 20.0;
        let std = 5.0;

        let hist = match super::get_1D_gauss_histogram(n, mean, std) {
            Ok(val) => val,
            Err(err) => panic!("{:?}", err),
        };

        // One bin per integer position 0..n, normalized to a probability vector.
        assert_eq!(hist.len(), n as usize);
        assert!(hist.iter().all(|&p| p.is_finite() && p > 0.0));
        assert!((hist.sum() - 1.0).abs() < 1e-12);

        // The mode sits on the bin at `mean`, and the weights are symmetric around it.
        let peak = hist
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).expect("weights are finite"))
            .map(|(i, _)| i)
            .expect("histogram is non-empty");
        assert_eq!(peak, 20);
        assert!((hist[15] - hist[25]).abs() < 1e-15);
    }

    #[test]
    fn test_sample_2d_gauss() {
        let n = 50;
        let mean = array![0.0, 0.0];
        let covariance = array![[1.0, 0.0], [0.0, 1.0]];

        let samples = match super::sample_2D_gauss(n, &mean, &covariance) {
            Ok(val) => val,
            Err(err) => panic!("{:?}", err),
        };

        // One row per sample, one column per dimension of the mean.
        assert_eq!(samples.dim(), (n as usize, 2));
        assert!(samples.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_sample_2d_gauss_rejects_non_positive_n() {
        let mean = array![0.0, 0.0];
        let covariance = array![[1.0, 0.0], [0.0, 1.0]];

        assert!(super::sample_2D_gauss(0, &mean, &covariance).is_err());
        assert!(super::sample_2D_gauss(-3, &mean, &covariance).is_err());
    }
}