#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::dist::{InvWishart, MvGaussian};
use crate::impl_display;
use crate::traits::*;
use nalgebra::{DMatrix, DVector};
use rand::Rng;
use std::fmt;
mod mvg_prior;
/// Normal-inverse-Wishart distribution: a distribution over multivariate
/// Gaussians ([`MvGaussian`]), i.e. over a (mean, covariance) pair.
///
/// Sampling first draws a covariance `Σ ~ InvWishart(scale, df)` and then a
/// mean `μ ~ N(mu, Σ / k)`; the sample is the Gaussian `N(μ, Σ)`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NormalInvWishart {
    /// μ: prior location of the Gaussian mean. Its length is the
    /// dimensionality of the distribution.
    mu: DVector<f64>,
    /// κ: the mean is drawn with covariance `Σ / k`. Must be positive.
    k: f64,
    /// ν: degrees of freedom of the inverse-Wishart over the covariance.
    /// Must be at least the number of dimensions.
    df: usize,
    /// Ψ: scale matrix of the inverse-Wishart over the covariance. Must be
    /// square with side equal to the number of dimensions.
    scale: DMatrix<f64>,
}
/// Errors returned when constructing or updating a [`NormalInvWishart`]
/// with parameters that violate its constraints.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum NormalInvWishartError {
    /// `k` was rejected because it is not greater than zero.
    KTooLow { k: f64 },
    /// `df` is smaller than the number of dimensions (the length of `mu`).
    DfLessThanDimensions { df: usize, ndims: usize },
    /// The scale matrix is not square.
    ScaleMatrixNotSquare {
        /// Number of rows of the offending matrix.
        nrows: usize,
        /// Number of columns of the offending matrix.
        ncols: usize,
    },
    /// The length of `mu` differs from the number of rows of the scale matrix.
    MuScaleDimensionMismatch {
        /// Number of entries in `mu`.
        n_mu: usize,
        /// Number of rows of the scale matrix.
        n_scale: usize,
    },
}
/// Check that a full set of Normal-inverse-Wishart parameters is consistent.
///
/// The checks run in this order, and the first failure is returned:
/// 1. `k` must be strictly positive (NaN is rejected too);
/// 2. `df` must be at least `mu.len()`;
/// 3. `scale` must be square;
/// 4. `scale` must have as many rows as `mu` has entries.
///
/// Positive-definiteness of `scale` is *not* checked here.
///
/// # Errors
/// Returns the [`NormalInvWishartError`] variant matching the first violated
/// constraint.
fn validate_params(
    mu: &DVector<f64>,
    k: f64,
    df: usize,
    scale: &DMatrix<f64>,
) -> Result<(), NormalInvWishartError> {
    let ndims = mu.len();
    // `k <= 0.0` is false for NaN, so NaN must be rejected explicitly;
    // otherwise it would silently propagate into every density and draw.
    if k.is_nan() || k <= 0.0 {
        Err(NormalInvWishartError::KTooLow { k })
    } else if df < ndims {
        Err(NormalInvWishartError::DfLessThanDimensions { df, ndims })
    } else if !scale.is_square() {
        Err(NormalInvWishartError::ScaleMatrixNotSquare {
            nrows: scale.nrows(),
            ncols: scale.ncols(),
        })
    } else if ndims != scale.nrows() {
        Err(NormalInvWishartError::MuScaleDimensionMismatch {
            n_mu: ndims,
            n_scale: scale.nrows(),
        })
    } else {
        Ok(())
    }
}
impl NormalInvWishart {
    /// Create a Normal-inverse-Wishart distribution from its parameters.
    ///
    /// # Arguments
    /// * `mu` - μ, prior location of the Gaussian mean; its length is the
    ///   number of dimensions.
    /// * `k` - κ; the mean is drawn with covariance `Σ / k`. Must be positive.
    /// * `df` - ν, degrees of freedom of the inverse-Wishart over `Σ`. Must be
    ///   at least `mu.len()`.
    /// * `scale` - Ψ, scale matrix of the inverse-Wishart. Must be square
    ///   with side `mu.len()`.
    ///
    /// # Errors
    /// Returns a [`NormalInvWishartError`] describing the first parameter
    /// constraint that is violated.
    #[inline]
    pub fn new(
        mu: DVector<f64>,
        k: f64,
        df: usize,
        scale: DMatrix<f64>,
    ) -> Result<Self, NormalInvWishartError> {
        validate_params(&mu, k, df, &scale)?;
        Ok(NormalInvWishart { mu, k, df, scale })
    }

    /// Create a Normal-inverse-Wishart distribution without validating the
    /// parameters. The caller must uphold the constraints documented on
    /// [`NormalInvWishart::new`].
    #[inline]
    pub fn new_unchecked(
        mu: DVector<f64>,
        k: f64,
        df: usize,
        scale: DMatrix<f64>,
    ) -> Self {
        NormalInvWishart { mu, k, df, scale }
    }

    /// Number of dimensions, i.e. the length of `mu`.
    #[inline]
    pub fn ndims(&self) -> usize {
        self.mu.len()
    }

    /// μ, the prior location of the Gaussian mean.
    #[inline]
    pub fn mu(&self) -> &DVector<f64> {
        &self.mu
    }

    /// κ, the divisor applied to the covariance when drawing the mean.
    #[inline]
    pub fn k(&self) -> f64 {
        self.k
    }

    /// Set κ.
    ///
    /// # Errors
    /// Returns [`NormalInvWishartError::KTooLow`], leaving `self` unchanged,
    /// if `k` is not strictly positive (this includes NaN).
    #[inline]
    pub fn set_k(&mut self, k: f64) -> Result<(), NormalInvWishartError> {
        // `k <= 0.0` is false for NaN, so NaN must be rejected explicitly.
        if k.is_nan() || k <= 0.0 {
            Err(NormalInvWishartError::KTooLow { k })
        } else {
            self.set_k_unchecked(k);
            Ok(())
        }
    }

    /// Set κ without validation.
    #[inline]
    pub fn set_k_unchecked(&mut self, k: f64) {
        self.k = k;
    }

    /// ν, the degrees of freedom of the inverse-Wishart.
    #[inline]
    pub fn df(&self) -> usize {
        self.df
    }

    /// Set ν.
    ///
    /// # Errors
    /// Returns [`NormalInvWishartError::DfLessThanDimensions`], leaving
    /// `self` unchanged, if `df` is smaller than [`Self::ndims`].
    #[inline]
    pub fn set_df(&mut self, df: usize) -> Result<(), NormalInvWishartError> {
        let ndims = self.ndims();
        if df < ndims {
            Err(NormalInvWishartError::DfLessThanDimensions { df, ndims })
        } else {
            self.set_df_unchecked(df);
            Ok(())
        }
    }

    /// Set ν without validation.
    #[inline]
    pub fn set_df_unchecked(&mut self, df: usize) {
        self.df = df;
    }

    /// Ψ, the scale matrix of the inverse-Wishart.
    #[inline]
    pub fn scale(&self) -> &DMatrix<f64> {
        &self.scale
    }

    /// Set Ψ, validating it together with the current `mu`, `k`, and `df`.
    ///
    /// # Errors
    /// Returns the first [`NormalInvWishartError`] reported by parameter
    /// validation. On error, `self` is unchanged.
    #[inline]
    pub fn set_scale(
        &mut self,
        scale: DMatrix<f64>,
    ) -> Result<(), NormalInvWishartError> {
        validate_params(&self.mu, self.k, self.df, &scale)?;
        self.scale = scale;
        Ok(())
    }

    /// Set Ψ without validation.
    #[inline]
    pub fn set_scale_unchecked(&mut self, scale: DMatrix<f64>) {
        self.scale = scale;
    }

    /// Misspelled alias of [`NormalInvWishart::set_scale_unchecked`], kept so
    /// that existing callers continue to compile. Prefer the correctly
    /// spelled method in new code.
    #[inline]
    pub fn set_scale_unnchecked(&mut self, scale: DMatrix<f64>) {
        self.set_scale_unchecked(scale);
    }

    /// Set μ, validating it together with the current `k`, `df`, and `scale`.
    ///
    /// # Errors
    /// Returns the first [`NormalInvWishartError`] reported by parameter
    /// validation (e.g. a length that does not match `scale`). On error,
    /// `self` is unchanged.
    #[inline]
    pub fn set_mu(
        &mut self,
        mu: DVector<f64>,
    ) -> Result<(), NormalInvWishartError> {
        validate_params(&mu, self.k, self.df, &self.scale)?;
        self.mu = mu;
        Ok(())
    }

    /// Set μ without validation.
    #[inline]
    pub fn set_mu_unchecked(&mut self, mu: DVector<f64>) {
        self.mu = mu;
    }
}
impl From<&NormalInvWishart> for String {
    /// Render the distribution as a multi-line `NIW ( ... )` summary using
    /// the symbols μ (mean location), κ, ν (degrees of freedom), and Ψ
    /// (scale matrix).
    fn from(niw: &NormalInvWishart) -> String {
        // The trailing `)` closes the parenthesis opened by `NIW (`.
        format!(
            "NIW (\n μ: {}\n κ: {}\n ν: {}\n Ψ: {}\n)",
            niw.mu, niw.k, niw.df, niw.scale
        )
    }
}

impl_display!(NormalInvWishart);
impl Rv<MvGaussian> for NormalInvWishart {
    /// Log density of a multivariate Gaussian under this distribution:
    /// `ln N(x.mu | mu, x.cov / k) + ln InvWishart(x.cov | scale, df)`.
    fn ln_f(&self, x: &MvGaussian) -> f64 {
        let sigma = x.cov().to_owned() / self.k;
        let mvg = MvGaussian::new_unchecked(self.mu.clone(), sigma);
        let iw = InvWishart::new_unchecked(self.scale.clone(), self.df);
        mvg.ln_f(x.mu()) + iw.ln_f(x.cov())
    }

    /// Draw a multivariate Gaussian: first `Σ ~ InvWishart(scale, df)`, then
    /// `μ ~ N(mu, Σ / k)`, returning `N(μ, Σ)`.
    fn draw<R: Rng>(&self, rng: &mut R) -> MvGaussian {
        let iw = InvWishart::new_unchecked(self.scale.clone(), self.df);
        // `rng` is implicitly reborrowed for each call; no `&mut &mut R`
        // double indirection is needed.
        let sigma = iw.draw(rng);
        let mvg =
            MvGaussian::new_unchecked(self.mu.clone(), sigma.clone() / self.k);
        let mu = mvg.draw(rng);
        MvGaussian::new(mu, sigma).expect(
            "an inverse-Wishart draw should be a valid covariance for a \
             mean of matching dimension",
        )
    }
}
impl Support<MvGaussian> for NormalInvWishart {
    /// A multivariate Gaussian is supported when its mean has as many entries
    /// as this distribution's `mu` and nalgebra's `cholesky()` succeeds on
    /// its covariance.
    fn supports(&self, x: &MvGaussian) -> bool {
        if x.mu().len() != self.ndims() {
            return false;
        }
        let cov = x.cov().clone();
        cov.cholesky().is_some()
    }
}

impl ContinuousDistr<MvGaussian> for NormalInvWishart {}
impl std::error::Error for NormalInvWishartError {}
impl fmt::Display for NormalInvWishartError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::KTooLow { k } => {
write!(f, "k ({}) must be greater than zero", k)
}
Self::DfLessThanDimensions { df, ndims } => write!(
f,
"df, the degrees of freedom must be greater than or \
equal to the number of dimensions, but {} < {}",
df, ndims
),
Self::ScaleMatrixNotSquare { nrows, ncols } => write!(
f,
"The scale matrix is not square: {} x {}",
nrows, ncols
),
Self::MuScaleDimensionMismatch { n_mu, n_scale } => write!(
f,
"The mu vector (nrows = {}) must have the same \
number of entries as the scale matrix has columns/rows \
(ndims = {}). ",
n_mu, n_scale
),
}
}
}