#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::consts::LN_2PI_E;
use crate::data::PoissonSuffStat;
use crate::misc::ln_fact;
use crate::traits::*;
use crate::{clone_cache_f64, impl_display};
use once_cell::sync::OnceCell;
use rand::Rng;
use rand_distr::Poisson as RPossion;
use special::Gamma as _;
use std::fmt;
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Poisson {
rate: f64,
#[cfg_attr(feature = "serde1", serde(skip))]
ln_rate: OnceCell<f64>,
}
impl Clone for Poisson {
fn clone(&self) -> Self {
Poisson {
rate: self.rate,
ln_rate: clone_cache_f64!(self, ln_rate),
}
}
}
impl PartialEq for Poisson {
fn eq(&self, other: &Poisson) -> bool {
self.rate == other.rate
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum PoissonError {
RateTooLow { rate: f64 },
RateNotFinite { rate: f64 },
}
impl Poisson {
#[inline]
pub fn new(rate: f64) -> Result<Self, PoissonError> {
if rate <= 0.0 {
Err(PoissonError::RateTooLow { rate })
} else if !rate.is_finite() {
Err(PoissonError::RateNotFinite { rate })
} else {
Ok(Self::new_unchecked(rate))
}
}
#[inline]
pub fn new_unchecked(rate: f64) -> Self {
Poisson {
rate,
ln_rate: OnceCell::new(),
}
}
#[inline]
fn ln_rate(&self) -> f64 {
*self.ln_rate.get_or_init(|| self.rate.ln())
}
#[inline]
pub fn rate(&self) -> f64 {
self.rate
}
#[inline]
pub fn set_rate(&mut self, rate: f64) -> Result<(), PoissonError> {
if rate <= 0.0 {
Err(PoissonError::RateTooLow { rate })
} else if !rate.is_finite() {
Err(PoissonError::RateNotFinite { rate })
} else {
self.set_rate_unchecked(rate);
Ok(())
}
}
#[inline]
pub fn set_rate_unchecked(&mut self, rate: f64) {
self.rate = rate;
self.ln_rate = OnceCell::new();
}
}
impl From<&Poisson> for String {
fn from(pois: &Poisson) -> String {
format!("Poisson(λ: {})", pois.rate)
}
}
impl_display!(Poisson);
macro_rules! impl_traits {
($kind:ty) => {
impl Rv<$kind> for Poisson {
fn ln_f(&self, x: &$kind) -> f64 {
let kf = f64::from(*x);
kf * self.ln_rate() - self.rate - ln_fact(*x as usize)
}
fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
let pois = RPossion::new(self.rate).unwrap();
let x: u64 = rng.sample(pois) as u64;
x as $kind
}
fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
let pois = RPossion::new(self.rate).unwrap();
(0..n)
.map(|_| {
let x: u64 = rng.sample(pois) as u64;
x as $kind
})
.collect()
}
}
impl Support<$kind> for Poisson {
#[allow(unused_comparisons)]
fn supports(&self, x: &$kind) -> bool {
*x >= 0
}
}
impl DiscreteDistr<$kind> for Poisson {}
impl Cdf<$kind> for Poisson {
fn cdf(&self, x: &$kind) -> f64 {
let kf = f64::from(*x);
1.0 - (self.rate).inc_gamma(kf + 1.0)
}
}
impl HasSuffStat<$kind> for Poisson {
type Stat = PoissonSuffStat;
fn empty_suffstat(&self) -> Self::Stat {
PoissonSuffStat::new()
}
}
impl Mode<($kind, $kind)> for Poisson {
fn mode(&self) -> Option<($kind, $kind)> {
let left = self.rate.ceil() as $kind - 1;
let right = self.rate.floor() as $kind;
Some((left, right))
}
}
impl Mode<$kind> for Poisson {
fn mode(&self) -> Option<$kind> {
Some(self.rate.ceil() as $kind - 1)
}
}
};
}
impl Mean<f64> for Poisson {
fn mean(&self) -> Option<f64> {
Some(self.rate)
}
}
impl Variance<f64> for Poisson {
fn variance(&self) -> Option<f64> {
Some(self.rate)
}
}
impl Skewness for Poisson {
fn skewness(&self) -> Option<f64> {
Some(self.rate.sqrt().recip())
}
}
impl Kurtosis for Poisson {
fn kurtosis(&self) -> Option<f64> {
Some(self.rate.recip())
}
}
impl KlDivergence for Poisson {
fn kl(&self, other: &Poisson) -> f64 {
self.rate() * (self.ln_rate() - other.ln_rate()) + other.rate()
- self.rate()
}
}
impl Entropy for Poisson {
fn entropy(&self) -> f64 {
if self.rate() < 200.0 {
let mid = self.rate().floor() as u32;
crate::misc::entropy::count_entropy(&self, mid)
} else {
(0.5) * (LN_2PI_E + self.ln_rate())
- (12.0 * self.rate()).recip()
- (24.0 * self.rate().powi(2)).recip()
- 19.0 * (360.0 * self.rate().powi(3)).recip()
}
}
}
impl_traits!(u8);
impl_traits!(u16);
impl_traits!(u32);
impl std::error::Error for PoissonError {}
impl fmt::Display for PoissonError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::RateTooLow { rate } => {
write!(f, "rate ({}) must be greater than zero", rate)
}
Self::RateNotFinite { rate } => {
write!(f, "non-finite rate: {}", rate)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::misc::x2_test;
use crate::test_basic_impls;
use std::f64;
const TOL: f64 = 1E-12;
const N_TRIES: usize = 5;
const X2_PVAL: f64 = 0.2;
fn brute_force_kl(fx: &Poisson, fy: &Poisson, x_max: u32) -> f64 {
(0..=x_max)
.map(|x| {
let lnfx = fx.ln_f(&x);
let lnfy = fy.ln_f(&x);
lnfx.exp() * (lnfx - lnfy)
})
.sum()
}
test_basic_impls!([count] Poisson::new(0.5).unwrap());
#[test]
fn new() {
assert::close(Poisson::new(0.001).unwrap().rate, 0.001, TOL);
assert::close(Poisson::new(1.234).unwrap().rate, 1.234, TOL);
}
#[test]
fn new_should_reject_non_finite_rate() {
assert!(Poisson::new(f64::INFINITY).is_err());
assert!(Poisson::new(f64::NAN).is_err());
}
#[test]
fn new_should_reject_rate_lteq_zero() {
assert!(Poisson::new(0.0).is_err());
assert!(Poisson::new(-1E-12).is_err());
assert!(Poisson::new(-1E12).is_err());
assert!(Poisson::new(f64::NEG_INFINITY).is_err());
}
#[test]
fn ln_pmf() {
let pois = Poisson::new(5.3).unwrap();
assert::close(pois.ln_pmf(&1_u32), -3.6322931794419238, TOL);
assert::close(pois.ln_pmf(&5_u32), -1.7489576399916658, TOL);
assert::close(pois.ln_pmf(&11_u32), -4.457_532_819_735_049, TOL);
}
#[test]
fn pmf_preserved_after_rate_set_reset() {
let x: u32 = 3;
let mut pois = Poisson::new(5.3).unwrap();
let pmf_1 = pois.pmf(&x);
let ln_pmf_1 = pois.ln_pmf(&x);
pois.set_rate(1.2).unwrap();
assert!((pmf_1 - pois.pmf(&x)).abs() > 1e-4);
assert!((ln_pmf_1 - pois.ln_pmf(&x)).abs() > 1e-4);
pois.set_rate(5.3).unwrap();
assert_eq!(pmf_1, pois.pmf(&x));
assert_eq!(ln_pmf_1, pois.ln_pmf(&x));
}
#[test]
fn cdf_low() {
let pois = Poisson::new(5.3).unwrap();
assert::close(pois.cdf(&1_u32), 0.031447041613534364, TOL);
}
#[test]
fn cdf_mid() {
let pois = Poisson::new(5.3).unwrap();
assert::close(pois.cdf(&5_u32), 0.563_473_392_288_071_7, TOL);
}
#[test]
fn cdf_high() {
let pois = Poisson::new(5.3).unwrap();
assert::close(pois.cdf(&15_u32), 0.999_866_999_508_350_3, TOL);
}
#[test]
fn mean() {
let m1 = Poisson::new(1.5).unwrap().mean().unwrap();
assert::close(m1, 1.5, TOL);
let m2 = Poisson::new(33.2).unwrap().mean().unwrap();
assert::close(m2, 33.2, TOL);
}
#[test]
fn variance() {
let v1 = Poisson::new(1.5).unwrap().variance().unwrap();
assert::close(v1, 1.5, TOL);
let v2 = Poisson::new(33.2).unwrap().variance().unwrap();
assert::close(v2, 33.2, TOL);
}
#[test]
fn skewness() {
let s = Poisson::new(5.3).unwrap().skewness().unwrap();
assert::close(s, 0.4343722427630694, TOL);
}
#[test]
fn kurtosis() {
let k = Poisson::new(5.3).unwrap().kurtosis().unwrap();
assert::close(k, 0.18867924528301888, TOL);
}
#[test]
fn draw_test() {
let mut rng = rand::thread_rng();
let pois = Poisson::new(2.0).unwrap();
let k: usize = (0..100)
.position(|x| pois.pmf(&(x as u32)) < f64::EPSILON)
.unwrap_or(99)
+ 1;
let ps: Vec<f64> = (0..k).map(|x| pois.pmf(&(x as u32))).collect();
let passes = (0..N_TRIES).fold(0, |acc, _| {
let mut f_obs: Vec<u32> = vec![0; k];
let xs: Vec<u32> = pois.sample(1000, &mut rng);
xs.iter().for_each(|&x| f_obs[x as usize] += 1);
let (_, p) = x2_test(&f_obs, &ps);
if p > X2_PVAL {
acc + 1
} else {
acc
}
});
assert!(passes > 0);
}
#[test]
fn kl_divergence_vs_brute() {
let prior = crate::dist::Gamma::new(1.0, 1.0).unwrap();
let mut rng = rand::thread_rng();
for _ in 0..10 {
let pois_x: Poisson = prior.draw(&mut rng);
let pois_y: Poisson = prior.draw(&mut rng);
let kl_true = pois_x.kl(&pois_y);
let kl_est = brute_force_kl(&pois_x, &pois_y, 1_000);
assert::close(kl_true, kl_est, TOL);
}
}
#[test]
fn entropy_value_checks() {
let rates = vec![0.1, 0.5, 1.0, 2.2, 3.4, 10.2, 131.4];
let hs = vec![
0.3336769965012327,
0.9276374674957975,
1.3048422422562516,
1.758957749331246,
1.9995315141091008,
2.571495552115918,
3.857424953514813,
];
rates.iter().zip(hs.iter()).for_each(|(rate, h)| {
let pois = Poisson::new(*rate).unwrap();
assert::close(*h, pois.entropy(), TOL);
})
}
#[test]
fn mode_value_checks() {
{
let pois = Poisson::new(2.0).unwrap();
let mode: (u32, u32) = pois.mode().unwrap();
assert_eq!(mode, (1, 2));
}
{
let pois = Poisson::new(2.1).unwrap();
let mode: (u32, u32) = pois.mode().unwrap();
assert_eq!(mode, (2, 2));
}
{
let pois = Poisson::new(2.1).unwrap();
let mode: u32 = pois.mode().unwrap();
assert_eq!(mode, 2);
}
{
let pois = Poisson::new(2.0).unwrap();
let mode: u32 = pois.mode().unwrap();
assert_eq!(mode, 1);
}
}
}