#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::consts::HALF_LN_2PI_E;
use crate::consts::LN_2PI;
use crate::data::MvGaussianSuffStat;
use crate::impl_display;
use crate::traits::*;
use nalgebra::linalg::Cholesky;
use nalgebra::{DMatrix, DVector, Dynamic};
use once_cell::sync::OnceCell;
use rand::Rng;
use std::fmt;
/// Cached quantities derived from a covariance matrix: its Cholesky
/// factorization and its inverse. Stored in a `OnceCell` on [`MvGaussian`]
/// so they are computed at most once and rebuilt lazily after
/// deserialization.
#[derive(Clone, Debug)]
struct MvgCache {
    // Cholesky factorization of the covariance matrix Σ.
    pub cov_chol: Cholesky<f64, Dynamic>,
    // Inverse of Σ, derived from the Cholesky factorization.
    pub cov_inv: DMatrix<f64>,
}
impl MvgCache {
    /// Build the cache from a covariance matrix.
    ///
    /// # Errors
    /// Returns `CovNotPositiveSemiDefinite` if the matrix has no Cholesky
    /// factorization.
    pub fn from_cov(cov: &DMatrix<f64>) -> Result<Self, MvGaussianError> {
        cov.clone()
            .cholesky()
            .map(|cov_chol| {
                let cov_inv = cov_chol.inverse();
                MvgCache { cov_chol, cov_inv }
            })
            .ok_or(MvGaussianError::CovNotPositiveSemiDefinite)
    }

    /// Build the cache directly from an existing Cholesky factorization.
    #[inline]
    pub fn from_chol(cov_chol: Cholesky<f64, Dynamic>) -> Self {
        MvgCache {
            cov_inv: cov_chol.inverse(),
            cov_chol,
        }
    }

    /// Reconstruct the covariance matrix as L·Lᵀ from the Cholesky factor.
    #[inline]
    pub fn cov(&self) -> DMatrix<f64> {
        let lower = self.cov_chol.l();
        &lower * &lower.transpose()
    }
}
/// [Multivariate Gaussian (normal) distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution),
/// 𝒩(μ, Σ), over real vectors.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct MvGaussian {
    // Mean vector μ; its length defines the dimensionality.
    mu: DVector<f64>,
    // Covariance matrix Σ; constructors enforce that it is square,
    // positive semi-definite, and sized to match `mu`.
    cov: DMatrix<f64>,
    // Lazily computed Cholesky factorization and inverse of `cov`.
    // Skipped by serde and rebuilt on first use after deserialization.
    #[cfg_attr(
        feature = "serde1",
        serde(skip, default = "default_cache_none")
    )]
    cache: OnceCell<MvgCache>,
}
/// Serde default for the skipped `cache` field: an empty, uninitialized
/// cell, so the factorization is recomputed lazily after deserialization.
#[allow(dead_code)]
#[cfg(feature = "serde1")]
fn default_cache_none() -> OnceCell<MvgCache> {
    OnceCell::default()
}
impl PartialEq for MvGaussian {
    /// Two Gaussians are equal when their parameters (μ, Σ) match; the
    /// derived cache is intentionally ignored.
    fn eq(&self, other: &MvGaussian) -> bool {
        let same_mu = self.mu == other.mu;
        let same_cov = self.cov == other.cov;
        same_mu && same_cov
    }
}
/// Error cases for constructing or mutating an [`MvGaussian`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum MvGaussianError {
    /// The mean vector and covariance matrix have incompatible dimensions.
    MuCovDimensionMismatch {
        // Length of the mean vector.
        n_mu: usize,
        // Number of rows/columns of the covariance matrix.
        n_cov: usize,
    },
    /// The covariance matrix is not square.
    CovNotSquare {
        nrows: usize,
        ncols: usize,
    },
    /// The covariance matrix has no Cholesky factorization, i.e. it is
    /// not positive semi-definite.
    CovNotPositiveSemiDefinite,
    /// A zero-dimensional distribution was requested.
    ZeroDimension,
}
impl MvGaussian {
    /// Create a new multivariate Gaussian distribution 𝒩(μ, Σ).
    ///
    /// # Errors
    /// - `CovNotSquare` if `cov` is not a square matrix,
    /// - `MuCovDimensionMismatch` if `mu.len() != cov.nrows()`,
    /// - `CovNotPositiveSemiDefinite` if `cov` has no Cholesky
    ///   factorization.
    pub fn new(
        mu: DVector<f64>,
        cov: DMatrix<f64>,
    ) -> Result<Self, MvGaussianError> {
        let cov_rows = cov.nrows();
        let cov_cols = cov.ncols();
        if cov_rows != cov_cols {
            Err(MvGaussianError::CovNotSquare {
                nrows: cov_rows,
                ncols: cov_cols,
            })
        } else if mu.len() != cov_rows {
            Err(MvGaussianError::MuCovDimensionMismatch {
                n_mu: mu.len(),
                n_cov: cov_rows,
            })
        } else {
            // Building the cache up front both validates positive
            // semi-definiteness and avoids recomputing the factorization.
            let cache = OnceCell::from(MvgCache::from_cov(&cov)?);
            Ok(MvGaussian { mu, cov, cache })
        }
    }

    /// Create a new multivariate Gaussian from μ and a pre-computed
    /// Cholesky factorization of Σ; the covariance is reconstructed
    /// as L·Lᵀ.
    ///
    /// # Errors
    /// - `MuCovDimensionMismatch` if `mu` does not match the factor's
    ///   dimensionality.
    pub fn new_cholesky(
        mu: DVector<f64>,
        cov_chol: Cholesky<f64, Dynamic>,
    ) -> Result<Self, MvGaussianError> {
        let l = cov_chol.l();
        let cov = &l * &l.transpose();
        if mu.len() != cov.nrows() {
            Err(MvGaussianError::MuCovDimensionMismatch {
                n_mu: mu.len(),
                n_cov: cov.nrows(),
            })
        } else {
            let cache = OnceCell::from(MvgCache::from_chol(cov_chol));
            Ok(MvGaussian { mu, cov, cache })
        }
    }

    /// Like [`MvGaussian::new`] but without dimension checks.
    ///
    /// # Panics
    /// Panics if `cov` has no Cholesky factorization.
    #[inline]
    pub fn new_unchecked(mu: DVector<f64>, cov: DMatrix<f64>) -> Self {
        let cache = OnceCell::from(MvgCache::from_cov(&cov).unwrap());
        MvGaussian { mu, cov, cache }
    }

    /// Like [`MvGaussian::new_cholesky`] but without dimension checks.
    #[inline]
    pub fn new_cholesky_unchecked(
        mu: DVector<f64>,
        cov_chol: Cholesky<f64, Dynamic>,
    ) -> Self {
        let cache = OnceCell::from(MvgCache::from_chol(cov_chol));
        // Reconstruct Σ = L·Lᵀ from the freshly cached factor.
        let cov = cache.get().unwrap().cov();
        MvGaussian { mu, cov, cache }
    }

    /// Standard multivariate Gaussian: zero mean and identity covariance.
    ///
    /// # Errors
    /// - `ZeroDimension` if `dims == 0`.
    #[inline]
    pub fn standard(dims: usize) -> Result<Self, MvGaussianError> {
        if dims == 0 {
            Err(MvGaussianError::ZeroDimension)
        } else {
            let mu = DVector::zeros(dims);
            let cov = DMatrix::identity(dims, dims);
            // The identity matrix is positive definite, so the
            // factorization cannot fail here.
            let cov_chol = cov.clone().cholesky().unwrap();
            let cache = OnceCell::from(MvgCache::from_chol(cov_chol));
            Ok(MvGaussian { mu, cov, cache })
        }
    }

    /// Number of dimensions of the distribution (length of μ).
    #[inline]
    pub fn ndims(&self) -> usize {
        self.mu.len()
    }

    /// The mean vector, μ.
    #[inline]
    pub fn mu(&self) -> &DVector<f64> {
        &self.mu
    }

    /// The covariance matrix, Σ.
    #[inline]
    pub fn cov(&self) -> &DMatrix<f64> {
        &self.cov
    }

    /// Replace the mean vector.
    ///
    /// # Errors
    /// - `MuCovDimensionMismatch` if `mu` does not match the covariance
    ///   dimensionality.
    #[inline]
    pub fn set_mu(&mut self, mu: DVector<f64>) -> Result<(), MvGaussianError> {
        if mu.len() != self.cov.nrows() {
            Err(MvGaussianError::MuCovDimensionMismatch {
                n_mu: mu.len(),
                n_cov: self.cov.nrows(),
            })
        } else {
            self.mu = mu;
            Ok(())
        }
    }

    /// Replace the mean vector without checking dimensions.
    #[inline]
    pub fn set_mu_unchecked(&mut self, mu: DVector<f64>) {
        self.mu = mu;
    }

    /// Replace the covariance matrix, re-validating it and rebuilding the
    /// cached factorization.
    ///
    /// # Errors
    /// - `MuCovDimensionMismatch` if `cov` does not match μ's length,
    /// - `CovNotSquare` if `cov` is not square,
    /// - `CovNotPositiveSemiDefinite` if `cov` has no Cholesky
    ///   factorization.
    pub fn set_cov(
        &mut self,
        cov: DMatrix<f64>,
    ) -> Result<(), MvGaussianError> {
        let cov_rows = cov.nrows();
        if self.mu.len() != cov_rows {
            Err(MvGaussianError::MuCovDimensionMismatch {
                n_mu: self.mu.len(),
                n_cov: cov.nrows(),
            })
        } else if cov_rows != cov.ncols() {
            Err(MvGaussianError::CovNotSquare {
                nrows: cov_rows,
                ncols: cov.ncols(),
            })
        } else {
            // Validate the new covariance before committing anything, then
            // swap in the matrix and its cache together. `OnceCell::from`
            // replaces the previous new-cell-then-set-then-unwrap sequence
            // and matches the idiom used in `set_cov_unchecked`.
            let cache = MvgCache::from_cov(&cov)?;
            self.cov = cov;
            self.cache = OnceCell::from(cache);
            Ok(())
        }
    }

    /// Replace the covariance matrix without dimension checks.
    ///
    /// # Panics
    /// Panics if `cov` has no Cholesky factorization.
    #[inline]
    pub fn set_cov_unchecked(&mut self, cov: DMatrix<f64>) {
        let cache = MvgCache::from_cov(&cov).unwrap();
        self.cov = cov;
        self.cache = OnceCell::from(cache);
    }

    /// The cached Cholesky factor and covariance inverse, computed on
    /// first use (e.g. after deserialization, where the cache field is
    /// skipped).
    #[inline]
    fn cache(&self) -> &MvgCache {
        // The stored covariance was validated when it was set, so the
        // factorization is expected to succeed — NOTE(review): a value
        // deserialized from hand-edited data could still panic here.
        self.cache
            .get_or_try_init(|| MvgCache::from_cov(&self.cov))
            .unwrap()
    }
}
impl From<&MvGaussian> for String {
    /// Human-readable summary: dimensionality, mean, and covariance.
    fn from(mvg: &MvGaussian) -> String {
        // Fixed: the format string previously ended with a stray `)` —
        // the parenthesis opened by `Nₖ(` is already closed after the
        // dimension count.
        format!("Nₖ({})\n μ: {}\n σ: {}", mvg.ndims(), mvg.mu, mvg.cov)
    }
}
impl_display!(MvGaussian);
impl Rv<DVector<f64>> for MvGaussian {
    /// Log density: ln 𝒩(x; μ, Σ) = -½(ln|Σ| + (x-μ)ᵀΣ⁻¹(x-μ) + n·ln 2π).
    fn ln_f(&self, x: &DVector<f64>) -> f64 {
        let diff = x - &self.mu;
        // |Σ| = (∏ diag(L))² where L is the Cholesky factor. `l_dirty()`
        // exposes the raw factor storage (the strict upper triangle may
        // hold garbage), but only the diagonal is read here.
        let det: f64 = self
            .cache()
            .cov_chol
            .l_dirty()
            .diagonal()
            .row_iter()
            .fold(1.0, |acc, y| acc * y[0])
            .powi(2);
        let inv = &(self.cache().cov_inv);
        // Mahalanobis term (x-μ)ᵀ Σ⁻¹ (x-μ); the product is 1×1, so take
        // its single entry.
        let term: f64 = (diff.transpose() * inv * &diff)[0];
        -0.5 * (det.ln() + term + (diff.nrows() as f64) * LN_2PI)
    }

    /// Draw x = μ + L·z with z ~ 𝒩(0, I), using the cached Cholesky
    /// factor L.
    fn draw<R: Rng>(&self, rng: &mut R) -> DVector<f64> {
        let dims = self.mu.len();
        let norm = rand_distr::StandardNormal;
        // Independent standard-normal coordinates.
        let vals: Vec<f64> = (0..dims).map(|_| rng.sample(norm)).collect();
        let a = self.cache().cov_chol.l_dirty();
        let z: DVector<f64> = DVector::from_column_slice(&vals);
        DVector::from_fn(dims, |i, _| {
            let mut out: f64 = self.mu[i];
            // L is lower triangular, so only columns j ≤ i contribute;
            // this also avoids the possibly-garbage upper triangle of
            // the `l_dirty()` storage.
            for j in 0..=i {
                out += a[(i, j)] * z[j];
            }
            out
        })
    }
}
impl Support<DVector<f64>> for MvGaussian {
    /// A vector is in the support iff its dimensionality matches the
    /// distribution's.
    fn supports(&self, x: &DVector<f64>) -> bool {
        self.ndims() == x.len()
    }
}
impl ContinuousDistr<DVector<f64>> for MvGaussian {}
impl Mean<DVector<f64>> for MvGaussian {
    /// The mean of a multivariate Gaussian is its μ parameter.
    fn mean(&self) -> Option<DVector<f64>> {
        Some(self.mu.to_owned())
    }
}
impl Mode<DVector<f64>> for MvGaussian {
    /// The density is maximized at the mean, so the mode is μ.
    fn mode(&self) -> Option<DVector<f64>> {
        let mode = self.mu.clone();
        Some(mode)
    }
}
impl Variance<DMatrix<f64>> for MvGaussian {
    /// The (co)variance of a multivariate Gaussian is its Σ parameter.
    fn variance(&self) -> Option<DMatrix<f64>> {
        Some(self.cov().clone())
    }
}
impl Entropy for MvGaussian {
    /// Differential entropy: ½ ln|Σ| + (n/2)·ln(2πe).
    fn entropy(&self) -> f64 {
        // |Σ| = (∏ diag(L))² from the cached Cholesky factor. `l_dirty()`
        // exposes the raw factor storage, but only the diagonal is read.
        let det: f64 = self
            .cache()
            .cov_chol
            .l_dirty()
            .diagonal()
            .row_iter()
            .fold(1.0, |acc, x| acc * x[0])
            .powi(2);
        // HALF_LN_2PI_E is ½·ln(2πe), scaled by the dimensionality n.
        0.5 * det.ln() + HALF_LN_2PI_E * (self.cov.nrows() as f64)
    }
}
impl HasSuffStat<DVector<f64>> for MvGaussian {
    type Stat = MvGaussianSuffStat;

    /// A fresh (empty) sufficient statistic sized to this distribution's
    /// dimensionality.
    fn empty_suffstat(&self) -> Self::Stat {
        Self::Stat::new(self.ndims())
    }
}
impl std::error::Error for MvGaussianError {}
impl fmt::Display for MvGaussianError {
    /// Render a human-readable message for each error case.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MuCovDimensionMismatch { n_mu, n_cov } => {
                write!(
                    f,
                    "mean vector and covariance matrix do not align. \
                     mu is {} dimensions but cov is {} dimensions",
                    n_mu, n_cov
                )
            }
            Self::CovNotSquare { nrows, ncols } => {
                write!(
                    f,
                    "covariance matrix is not square ({} x {})",
                    nrows, ncols
                )
            }
            Self::CovNotPositiveSemiDefinite => {
                write!(f, "covariance is not positive semi-definite")
            }
            Self::ZeroDimension => {
                write!(f, "requested dimension is too low")
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dist::Gaussian;
    use crate::misc::{ks_test, mardia};
    use crate::test_basic_impls;

    // Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1E-12;
    // Number of attempts for the stochastic goodness-of-fit tests; the
    // test passes if any single try passes.
    const NTRIES: usize = 5;
    // p-value thresholds for the Kolmogorov–Smirnov and Mardia tests.
    const KS_PVAL: f64 = 0.2;
    const MARDIA_PVAL: f64 = 0.2;

    test_basic_impls!(MvGaussian::standard(3).unwrap(), DVector::zeros(3));

    // Constructor accepts a well-formed (mu, cov) pair.
    #[test]
    fn new() {
        let mu = DVector::zeros(3);
        let cov = DMatrix::identity(3, 3);
        assert!(MvGaussian::new(mu, cov).is_ok());
    }

    // Constructor rejects a covariance larger than the mean vector.
    #[test]
    fn new_should_reject_cov_too_big() {
        let mu = DVector::zeros(3);
        let cov = DMatrix::identity(4, 4);
        let mvg = MvGaussian::new(mu, cov);
        assert_eq!(
            mvg,
            Err(MvGaussianError::MuCovDimensionMismatch { n_mu: 3, n_cov: 4 })
        )
    }

    // Constructor rejects a covariance smaller than the mean vector.
    #[test]
    fn new_should_reject_cov_too_small() {
        let mu = DVector::zeros(3);
        let cov = DMatrix::identity(2, 2);
        let mvg = MvGaussian::new(mu, cov);
        assert_eq!(
            mvg,
            Err(MvGaussianError::MuCovDimensionMismatch { n_mu: 3, n_cov: 2 })
        )
    }

    // Constructor rejects a non-square covariance matrix.
    #[test]
    fn new_should_reject_cov_not_square() {
        let mu = DVector::zeros(3);
        let cov = DMatrix::identity(3, 2);
        let mvg = MvGaussian::new(mu, cov);
        assert_eq!(
            mvg,
            Err(MvGaussianError::CovNotSquare { nrows: 3, ncols: 2 })
        );
    }

    // Log density of the standard MVG at the origin (its mode).
    #[test]
    fn ln_f_standard_x_zeros() {
        let mvg = MvGaussian::standard(3).unwrap();
        let x = DVector::<f64>::zeros(3);
        assert::close(mvg.ln_f(&x), -2.756815599614018, TOL);
    }

    // Log density of the standard MVG away from the mode.
    #[test]
    fn ln_f_standard_x_nonzeros() {
        let mvg = MvGaussian::standard(3).unwrap();
        let x = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        assert::close(mvg.ln_f(&x), -26.906_815_599_614_02, TOL);
    }

    // Log density with a non-identity covariance, evaluated at the origin.
    #[test]
    fn ln_f_nonstandard_zeros() {
        let cov_vals = vec![
            1.01742788,
            0.36586652,
            -0.65620486,
            0.36586652,
            1.00564553,
            -0.42597261,
            -0.65620486,
            -0.42597261,
            1.27247972,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        let x = DVector::<f64>::zeros(3);
        assert::close(mvg.ln_f(&x), -24.602_370_253_215_66, TOL);
    }

    // Log density with a non-identity covariance, evaluated at the mean.
    #[test]
    fn ln_f_nonstandard_nonzeros() {
        let cov_vals = vec![
            1.01742788,
            0.36586652,
            -0.65620486,
            0.36586652,
            1.00564553,
            -0.42597261,
            -0.65620486,
            -0.42597261,
            1.27247972,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        let x = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        assert::close(mvg.ln_f(&x), -2.5915350538112296, TOL);
    }

    // `sample(n, ..)` yields exactly n draws.
    #[test]
    fn sample_returns_proper_number_of_draws() {
        let cov_vals = vec![
            1.01742788,
            0.36586652,
            -0.65620486,
            0.36586652,
            1.00564553,
            -0.42597261,
            -0.65620486,
            -0.42597261,
            1.27247972,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        let mut rng = rand::thread_rng();
        let xs = mvg.sample(103, &mut rng);
        assert_eq!(xs.len(), 103);
    }

    // Entropy of the standard 3-D MVG: (3/2)·ln(2πe).
    #[test]
    fn standard_entropy() {
        let mvg = MvGaussian::standard(3).unwrap();
        assert::close(mvg.entropy(), 4.2568155996140185, TOL);
    }

    // Entropy with a non-identity covariance.
    #[test]
    fn nonstandard_entropy() {
        let cov_vals = vec![
            1.01742788,
            0.36586652,
            -0.65620486,
            0.36586652,
            1.00564553,
            -0.42597261,
            -0.65620486,
            -0.42597261,
            1.27247972,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        assert::close(mvg.entropy(), 4.0915350538112305, TOL);
    }

    // Each marginal of a standard bivariate MVG should pass a KS test
    // against the standard univariate Gaussian CDF.
    #[test]
    fn standard_draw_marginals() {
        let mut rng = rand::thread_rng();
        let mvg = MvGaussian::standard(2).unwrap();
        let g = Gaussian::standard();
        let cdf = |x: f64| g.cdf(&x);
        // Retry up to NTRIES times; short-circuit once a try passes.
        let passed = (0..NTRIES).fold(false, |acc, _| {
            if acc {
                acc
            } else {
                let xys = mvg.sample(500, &mut rng);
                let xs: Vec<f64> = xys.iter().map(|xy| xy[0]).collect();
                let ys: Vec<f64> = xys.iter().map(|xy| xy[1]).collect();
                let (_, px) = ks_test(&xs, cdf);
                let (_, py) = ks_test(&ys, cdf);
                px > KS_PVAL && py > KS_PVAL
            }
        });
        assert!(passed);
    }

    // Mardia's multivariate normality test on standard MVG draws.
    #[test]
    fn standard_draw_mardia() {
        let mut rng = rand::thread_rng();
        let mvg = MvGaussian::standard(4).unwrap();
        let passed = (0..NTRIES).fold(false, |acc, _| {
            if acc {
                acc
            } else {
                let xys = mvg.sample(500, &mut rng);
                let (pa, pb) = mardia(&xys);
                pa > MARDIA_PVAL && pb > MARDIA_PVAL
            }
        });
        assert!(passed);
    }

    // Mardia's multivariate normality test with a non-identity covariance.
    #[test]
    fn nonstandard_draw_mardia() {
        let mut rng = rand::thread_rng();
        let cov_vals = vec![
            1.01742788,
            0.36586652,
            -0.65620486,
            0.36586652,
            1.00564553,
            -0.42597261,
            -0.65620486,
            -0.42597261,
            1.27247972,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        let passed = (0..NTRIES).fold(false, |acc, _| {
            if acc {
                acc
            } else {
                let xys = mvg.sample(500, &mut rng);
                let (pa, pb) = mardia(&xys);
                pa > MARDIA_PVAL && pb > MARDIA_PVAL
            }
        });
        assert!(passed);
    }
}