#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use once_cell::sync::OnceCell;
use rand::Rng;
use rand_distr::Normal;
use special::Error as _;
use std::f64::consts::SQRT_2;
use std::fmt;
use crate::consts::*;
use crate::data::GaussianSuffStat;
use crate::traits::*;
use crate::{clone_cache_f64, impl_display};
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Gaussian {
mu: f64,
sigma: f64,
#[cfg_attr(feature = "serde1", serde(skip))]
ln_sigma: OnceCell<f64>,
}
impl Clone for Gaussian {
fn clone(&self) -> Self {
Self {
mu: self.mu,
sigma: self.sigma,
ln_sigma: clone_cache_f64!(self, ln_sigma),
}
}
}
impl PartialEq for Gaussian {
fn eq(&self, other: &Gaussian) -> bool {
self.mu == other.mu && self.sigma == other.sigma
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum GaussianError {
MuNotFinite { mu: f64 },
SigmaTooLow { sigma: f64 },
SigmaNotFinite { sigma: f64 },
}
impl Gaussian {
pub fn new(mu: f64, sigma: f64) -> Result<Self, GaussianError> {
if !mu.is_finite() {
Err(GaussianError::MuNotFinite { mu })
} else if sigma <= 0.0 {
Err(GaussianError::SigmaTooLow { sigma })
} else if !sigma.is_finite() {
Err(GaussianError::SigmaNotFinite { sigma })
} else {
Ok(Gaussian {
mu,
sigma,
ln_sigma: OnceCell::new(),
})
}
}
#[inline]
pub fn new_unchecked(mu: f64, sigma: f64) -> Self {
Gaussian {
mu,
sigma,
ln_sigma: OnceCell::new(),
}
}
#[inline]
pub fn standard() -> Self {
Gaussian {
mu: 0.0,
sigma: 1.0,
ln_sigma: OnceCell::from(0.0),
}
}
#[inline]
pub fn mu(&self) -> f64 {
self.mu
}
#[inline]
pub fn set_mu(&mut self, mu: f64) -> Result<(), GaussianError> {
if !mu.is_finite() {
Err(GaussianError::MuNotFinite { mu })
} else {
self.set_mu_unchecked(mu);
Ok(())
}
}
#[inline]
pub fn set_mu_unchecked(&mut self, mu: f64) {
self.mu = mu;
}
#[inline]
pub fn sigma(&self) -> f64 {
self.sigma
}
#[inline]
pub fn set_sigma(&mut self, sigma: f64) -> Result<(), GaussianError> {
if sigma <= 0.0 {
Err(GaussianError::SigmaTooLow { sigma })
} else if !sigma.is_finite() {
Err(GaussianError::SigmaNotFinite { sigma })
} else {
self.set_sigma_unchecked(sigma);
Ok(())
}
}
#[inline]
pub fn set_sigma_unchecked(&mut self, sigma: f64) {
self.sigma = sigma;
self.ln_sigma = OnceCell::new();
}
#[inline]
fn ln_sigma(&self) -> f64 {
*self.ln_sigma.get_or_init(|| self.sigma.ln())
}
}
impl Default for Gaussian {
fn default() -> Self {
Gaussian::standard()
}
}
impl From<&Gaussian> for String {
fn from(gauss: &Gaussian) -> String {
format!("N(μ: {}, σ: {})", gauss.mu, gauss.sigma)
}
}
impl_display!(Gaussian);
macro_rules! impl_traits {
($kind:ty) => {
impl Rv<$kind> for Gaussian {
fn ln_f(&self, x: &$kind) -> f64 {
let k = (f64::from(*x) - self.mu) / self.sigma;
-self.ln_sigma() - 0.5 * k * k - HALF_LN_2PI
}
fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
let g = Normal::new(self.mu, self.sigma).unwrap();
rng.sample(g) as $kind
}
fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
let g = Normal::new(self.mu, self.sigma).unwrap();
(0..n).map(|_| rng.sample(g) as $kind).collect()
}
}
impl ContinuousDistr<$kind> for Gaussian {}
impl Support<$kind> for Gaussian {
fn supports(&self, x: &$kind) -> bool {
x.is_finite()
}
}
impl Cdf<$kind> for Gaussian {
fn cdf(&self, x: &$kind) -> f64 {
let errf =
((f64::from(*x) - self.mu) / (self.sigma * SQRT_2)).error();
0.5 * (1.0 + errf)
}
}
impl InverseCdf<$kind> for Gaussian {
fn invcdf(&self, p: f64) -> $kind {
if (p <= 0.0) || (1.0 <= p) {
panic!("P out of range");
}
let x =
self.mu + self.sigma * SQRT_2 * (2.0 * p - 1.0).inv_error();
x as $kind
}
}
impl Mean<$kind> for Gaussian {
fn mean(&self) -> Option<$kind> {
Some(self.mu as $kind)
}
}
impl Median<$kind> for Gaussian {
fn median(&self) -> Option<$kind> {
Some(self.mu as $kind)
}
}
impl Mode<$kind> for Gaussian {
fn mode(&self) -> Option<$kind> {
Some(self.mu as $kind)
}
}
impl HasSuffStat<$kind> for Gaussian {
type Stat = GaussianSuffStat;
fn empty_suffstat(&self) -> Self::Stat {
GaussianSuffStat::new()
}
}
};
}
impl Variance<f64> for Gaussian {
fn variance(&self) -> Option<f64> {
Some(self.sigma * self.sigma)
}
}
impl Entropy for Gaussian {
fn entropy(&self) -> f64 {
HALF_LN_2PI_E + self.ln_sigma()
}
}
impl Skewness for Gaussian {
fn skewness(&self) -> Option<f64> {
Some(0.0)
}
}
impl Kurtosis for Gaussian {
fn kurtosis(&self) -> Option<f64> {
Some(0.0)
}
}
impl KlDivergence for Gaussian {
#[allow(clippy::clippy::suspicious_operation_groupings)]
fn kl(&self, other: &Self) -> f64 {
let m1 = self.mu;
let m2 = other.mu;
let s1 = self.sigma;
let s2 = other.sigma;
let term1 = s2.ln() - s1.ln();
let term2 = (s1 * s1 + (m1 - m2) * (m1 - m2)) / (2.0 * s2 * s2);
term1 + term2 - 0.5
}
}
impl QuadBounds for Gaussian {
fn quad_bounds(&self) -> (f64, f64) {
self.interval(0.99999999999)
}
}
impl_traits!(f32);
impl_traits!(f64);
impl std::error::Error for GaussianError {}
impl fmt::Display for GaussianError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MuNotFinite { mu } => write!(f, "non-finite mu: {}", mu),
Self::SigmaTooLow { sigma } => {
write!(f, "sigma ({}) must be greater than zero", sigma)
}
Self::SigmaNotFinite { sigma } => {
write!(f, "non-finite sigma: {}", sigma)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_basic_impls;
use std::f64;
const TOL: f64 = 1E-12;
test_basic_impls!([continuous] Gaussian::standard());
#[test]
fn new() {
let gauss = Gaussian::new(1.2, 3.0).unwrap();
assert::close(gauss.mu, 1.2, TOL);
assert::close(gauss.sigma, 3.0, TOL);
}
#[test]
fn standard() {
let gauss = Gaussian::standard();
assert::close(gauss.mu, 0.0, TOL);
assert::close(gauss.sigma, 1.0, TOL);
}
#[test]
fn standard_gaussian_mean_should_be_zero() {
let mean: f64 = Gaussian::standard().mean().unwrap();
assert::close(mean, 0.0, TOL);
}
#[test]
fn standard_gaussian_variance_should_be_one() {
assert::close(Gaussian::standard().variance().unwrap(), 1.0, TOL);
}
#[test]
fn mean_should_be_mu() {
let mu = 3.4;
let mean: f64 = Gaussian::new(mu, 0.5).unwrap().mean().unwrap();
assert::close(mean, mu, TOL);
}
#[test]
fn median_should_be_mu() {
let mu = 3.4;
let median: f64 = Gaussian::new(mu, 0.5).unwrap().median().unwrap();
assert::close(median, mu, TOL);
}
#[test]
fn mode_should_be_mu() {
let mu = 3.4;
let mode: f64 = Gaussian::new(mu, 0.5).unwrap().mode().unwrap();
assert::close(mode, mu, TOL);
}
#[test]
fn variance_should_be_sigma_squared() {
let sigma = 0.5;
let gauss = Gaussian::new(3.4, sigma).unwrap();
assert::close(gauss.variance().unwrap(), sigma * sigma, TOL);
}
#[test]
fn draws_should_be_finite() {
let mut rng = rand::thread_rng();
let gauss = Gaussian::standard();
for _ in 0..100 {
let x: f64 = gauss.draw(&mut rng);
assert!(x.is_finite())
}
}
#[test]
fn sample_length() {
let mut rng = rand::thread_rng();
let gauss = Gaussian::standard();
let xs: Vec<f64> = gauss.sample(10, &mut rng);
assert_eq!(xs.len(), 10);
}
#[test]
fn standard_ln_pdf_at_zero() {
let gauss = Gaussian::standard();
assert::close(gauss.ln_pdf(&0.0_f64), -0.918_938_533_204_672_7, TOL);
}
#[test]
fn standard_ln_pdf_off_zero() {
let gauss = Gaussian::standard();
assert::close(gauss.ln_pdf(&2.1_f64), -3.1239385332046727, TOL);
}
#[test]
fn nonstandard_ln_pdf_on_mean() {
let gauss = Gaussian::new(-1.2, 0.33).unwrap();
assert::close(gauss.ln_pdf(&-1.2_f64), 0.18972409131693846, TOL);
}
#[test]
fn nonstandard_ln_pdf_off_mean() {
let gauss = Gaussian::new(-1.2, 0.33).unwrap();
assert::close(gauss.ln_pdf(&0.0_f32), -6.421_846_156_616_945, TOL);
}
#[test]
fn should_contain_finite_values() {
let gauss = Gaussian::standard();
assert!(gauss.supports(&0.0_f32));
assert!(gauss.supports(&10E8_f64));
assert!(gauss.supports(&-10E8_f64));
}
#[test]
fn should_not_contain_nan() {
let gauss = Gaussian::standard();
assert!(!gauss.supports(&f64::NAN));
}
#[test]
fn should_not_contain_positive_or_negative_infinity() {
let gauss = Gaussian::standard();
assert!(!gauss.supports(&f64::INFINITY));
assert!(!gauss.supports(&f64::NEG_INFINITY));
}
#[test]
fn skewness_should_be_zero() {
let gauss = Gaussian::new(-12.3, 45.6).unwrap();
assert::close(gauss.skewness().unwrap(), 0.0, TOL);
}
#[test]
fn kurtosis_should_be_zero() {
let gauss = Gaussian::new(-12.3, 45.6).unwrap();
assert::close(gauss.skewness().unwrap(), 0.0, TOL);
}
#[test]
fn cdf_at_mean_should_be_one_half() {
let mu1: f64 = 2.3;
let gauss1 = Gaussian::new(mu1, 0.2).unwrap();
assert::close(gauss1.cdf(&mu1), 0.5, TOL);
let mu2: f32 = -8.0;
let gauss2 = Gaussian::new(mu2.into(), 100.0).unwrap();
assert::close(gauss2.cdf(&mu2), 0.5, TOL);
}
#[test]
fn cdf_value_at_one() {
let gauss = Gaussian::standard();
assert::close(gauss.cdf(&1.0_f64), 0.841_344_746_068_542_9, TOL);
}
#[test]
fn cdf_value_at_neg_two() {
let gauss = Gaussian::standard();
assert::close(gauss.cdf(&-2.0_f64), 0.022750131948179195, TOL);
}
#[test]
fn quantile_at_one_half_should_be_mu() {
let mu = 1.2315;
let gauss = Gaussian::new(mu, 1.0).unwrap();
let x: f64 = gauss.quantile(0.5);
assert::close(x, mu, TOL);
}
#[test]
fn quantile_agree_with_cdf() {
let mut rng = rand::thread_rng();
let gauss = Gaussian::standard();
let xs: Vec<f64> = gauss.sample(100, &mut rng);
xs.iter().for_each(|x| {
let p = gauss.cdf(x);
let y: f64 = gauss.quantile(p);
assert::close(y, *x, TOL);
})
}
#[test]
fn quad_on_pdf_agrees_with_cdf_x() {
use crate::misc::quad_eps;
let ig = Gaussian::new(-2.3, 0.5).unwrap();
let pdf = |x: f64| ig.f(&x);
let mut rng = rand::thread_rng();
for _ in 0..100 {
let x: f64 = ig.draw(&mut rng);
let res = quad_eps(pdf, -10.0, x, Some(1e-13));
let cdf = ig.cdf(&x);
assert::close(res, cdf, 1e-7);
}
}
#[test]
fn standard_gaussian_entropy() {
let gauss = Gaussian::standard();
assert::close(gauss.entropy(), 1.4189385332046727, TOL);
}
#[test]
fn entropy() {
let gauss = Gaussian::new(3.0, 12.3).unwrap();
assert::close(gauss.entropy(), 3.9285377955830447, TOL);
}
#[test]
fn kl_of_idential_dsitrbutions_should_be_zero() {
let gauss = Gaussian::new(1.2, 3.4).unwrap();
assert::close(gauss.kl(&gauss), 0.0, TOL);
}
#[test]
fn kl() {
let g1 = Gaussian::new(1.0, 2.0).unwrap();
let g2 = Gaussian::new(2.0, 1.0).unwrap();
let kl = 0.5_f64.ln() + 5.0 / 2.0 - 0.5;
assert::close(g1.kl(&g2), kl, TOL);
}
#[test]
fn ln_f_after_set_mu_works() {
let mut gauss = Gaussian::standard();
assert::close(gauss.ln_pdf(&0.0_f64), -0.918_938_533_204_672_7, TOL);
gauss.set_mu(1.0).unwrap();
assert::close(gauss.ln_pdf(&1.0_f64), -0.918_938_533_204_672_7, TOL);
}
#[test]
fn ln_f_after_set_sigm_works() {
let mut gauss = Gaussian::new(-1.2, 5.0).unwrap();
gauss.set_sigma(0.33).unwrap();
assert::close(gauss.ln_pdf(&-1.2_f64), 0.18972409131693846, TOL);
assert::close(gauss.ln_pdf(&0.0_f32), -6.421_846_156_616_945, TOL);
}
}