#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::data::{CategoricalDatum, CategoricalSuffStat};
use crate::impl_display;
use crate::misc::{argmax, ln_pflip, logsumexp, vec_to_string};
use crate::traits::*;
use rand::Rng;
use std::fmt;
/// A categorical distribution over the category indices `0..k`.
///
/// Probabilities are stored in log space. `new` normalizes its input and
/// `from_ln_weights` validates normalization; `new_unchecked` performs no
/// checks at all.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Categorical {
// Log-probability of each category; index `i` is category `i`.
ln_weights: Vec<f64>,
}
/// Errors that can occur when constructing a [`Categorical`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum CategoricalError {
/// The value at index `ix` is not finite. `ln` is true when the offending
/// value was a log weight rather than a raw weight.
NonFiniteWeight { ix: usize, ln: bool, weight: f64 },
/// The raw weight at index `ix` is negative.
NegativeWeight { ix: usize, weight: f64 },
/// The weights are not normalized. When `ln` is true, `sum` is the absolute
/// value of the log-sum-exp of the log weights (which should be zero).
WeightsDoNotSumToOne { ln: bool, sum: f64 },
/// The weights vector was empty. (The variant name is misspelled; it is kept
/// as-is because it is part of the public API.)
EmptyWights,
}
impl Categorical {
    /// Create a categorical distribution from non-negative, unnormalized
    /// weights.
    ///
    /// The weights are moved to log space and normalized with `logsumexp`,
    /// so they do not need to sum to one. A zero weight is accepted and
    /// becomes a log weight of `-inf` (a zero-probability category).
    ///
    /// # Errors
    ///
    /// - `CategoricalError::EmptyWights` if `weights` is empty.
    /// - `CategoricalError::NegativeWeight` if any weight is `< 0`
    ///   (this includes `-inf`).
    /// - `CategoricalError::NonFiniteWeight` with `ln: false` if any weight
    ///   is NaN or `+inf`.
    ///
    /// NOTE(review): an all-zero `weights` slice passes these checks; the
    /// result then depends on how `logsumexp` treats an all-`-inf` input and
    /// may not be a valid distribution — confirm.
    pub fn new(weights: &[f64]) -> Result<Self, CategoricalError> {
        if weights.is_empty() {
            return Err(CategoricalError::EmptyWights);
        }

        weights.iter().enumerate().try_for_each(|(ix, &weight)| {
            // The sign check comes first, so `-inf` is reported as a negative
            // weight rather than a non-finite one.
            if weight < 0.0 {
                Err(CategoricalError::NegativeWeight { ix, weight })
            } else if !weight.is_finite() {
                Err(CategoricalError::NonFiniteWeight {
                    ix,
                    ln: false,
                    weight,
                })
            } else {
                Ok(())
            }
        })?;

        let ln_weights: Vec<f64> = weights.iter().map(|w| w.ln()).collect();
        let ln_norm = logsumexp(&ln_weights);
        let normed_weights =
            ln_weights.iter().map(|lnw| lnw - ln_norm).collect();

        Ok(Categorical::new_unchecked(normed_weights))
    }

    /// Create a categorical distribution from log weights that are already
    /// normalized, i.e. whose log-sum-exp is (numerically) zero.
    ///
    /// # Errors
    ///
    /// - `CategoricalError::EmptyWights` if `ln_weights` is empty.
    /// - `CategoricalError::NonFiniteWeight` with `ln: true` if any log
    ///   weight is NaN or infinite. This also rejects `-inf`, so a
    ///   zero-probability category cannot be expressed through this
    ///   constructor (use `new` or `new_unchecked` instead).
    /// - `CategoricalError::WeightsDoNotSumToOne` with `ln: true` if
    ///   `|logsumexp(ln_weights)|` is not below `1e-11`.
    pub fn from_ln_weights(
        ln_weights: Vec<f64>,
    ) -> Result<Self, CategoricalError> {
        // Tolerance on |logsumexp(ln_weights)|; equal to the literal 10E-12.
        const LN_SUM_TOL: f64 = 1E-11;

        if ln_weights.is_empty() {
            return Err(CategoricalError::EmptyWights);
        }

        ln_weights
            .iter()
            .enumerate()
            .try_for_each(|(ix, &weight)| {
                if weight.is_finite() {
                    Ok(())
                } else {
                    // These are log weights, so the error must say so;
                    // `Display` then reports a "non-finite ln weight".
                    Err(CategoricalError::NonFiniteWeight {
                        ix,
                        ln: true,
                        weight,
                    })
                }
            })?;

        let sum = logsumexp(&ln_weights).abs();
        if sum < LN_SUM_TOL {
            Ok(Categorical { ln_weights })
        } else {
            Err(CategoricalError::WeightsDoNotSumToOne { ln: true, sum })
        }
    }

    /// Create a categorical distribution from log weights without any
    /// validation.
    ///
    /// The caller is responsible for `ln_weights` being non-empty,
    /// normalized log-probabilities; nothing here checks it.
    #[inline]
    pub fn new_unchecked(ln_weights: Vec<f64>) -> Self {
        Categorical { ln_weights }
    }

    /// Uniform distribution over `k` categories (each log weight `ln(1/k)`).
    ///
    /// NOTE(review): `k == 0` yields a distribution with no categories,
    /// which the validated constructors would reject — callers should pass
    /// `k >= 1`.
    #[inline]
    pub fn uniform(k: usize) -> Self {
        let lnp = (1.0 / k as f64).ln();
        Categorical::new_unchecked(vec![lnp; k])
    }

    /// The category probabilities (the exponentiated log weights) as a newly
    /// allocated `Vec`.
    #[inline]
    pub fn weights(&self) -> Vec<f64> {
        self.ln_weights.iter().map(|&w| w.exp()).collect()
    }

    /// Number of categories.
    #[inline]
    pub fn k(&self) -> usize {
        self.ln_weights.len()
    }

    /// Borrow the log weights (log-probabilities), one per category.
    #[inline]
    pub fn ln_weights(&self) -> &Vec<f64> {
        &self.ln_weights
    }
}
impl From<&Categorical> for String {
    /// Human-readable summary: the number of categories followed by the
    /// category probabilities as rendered by `vec_to_string(.., 5)` (see
    /// that helper for how the `5` is interpreted).
    fn from(cat: &Categorical) -> String {
        let weights = vec_to_string(&cat.weights(), 5);
        format!("Categorical({}, {})", cat.k(), weights)
    }
}

// NOTE(review): `impl_display!` is defined elsewhere in the crate; presumably
// it implements `fmt::Display` in terms of the `String` conversion above —
// confirm.
impl_display!(Categorical);
impl<X: CategoricalDatum> Rv<X> for Categorical {
    /// Log-probability of category `x`.
    ///
    /// Panics if `x` does not index one of the categories.
    fn ln_f(&self, x: &X) -> f64 {
        self.ln_weights[x.into_usize()]
    }

    /// Draw a single category using `ln_pflip` over the stored log weights.
    fn draw<R: Rng>(&self, rng: &mut R) -> X {
        let picked = ln_pflip(&self.ln_weights, 1, true, rng);
        <X as CategoricalDatum>::from_usize(picked[0])
    }

    /// Draw `n` categories using `ln_pflip` over the stored log weights.
    fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<X> {
        let picked = ln_pflip(&self.ln_weights, n, true, rng);
        picked
            .iter()
            .copied()
            .map(<X as CategoricalDatum>::from_usize)
            .collect()
    }
}
impl<X: CategoricalDatum> Support<X> for Categorical {
    /// `x` is in the support exactly when its index names one of the
    /// distribution's categories.
    fn supports(&self, x: &X) -> bool {
        self.ln_weights.get(x.into_usize()).is_some()
    }
}

// Empty impl: relies entirely on the trait's provided methods.
impl<X: CategoricalDatum> DiscreteDistr<X> for Categorical {}
impl<X: CategoricalDatum> Cdf<X> for Categorical {
    /// `P(X <= x)`: the summed probability of categories `0..=x`.
    ///
    /// An index past the last category simply includes every category.
    fn cdf(&self, x: &X) -> f64 {
        let n_included = x.into_usize() + 1;
        let mut total = 0.0;
        for ln_w in self.ln_weights.iter().take(n_included) {
            total += ln_w.exp();
        }
        total
    }
}
impl<X: CategoricalDatum> Mode<X> for Categorical {
    /// The category with the largest weight.
    ///
    /// Returns `None` when the maximum is not unique (several categories tie)
    /// or when `argmax` yields no index at all (e.g. a distribution with no
    /// categories, as produced by `Categorical::uniform(0)`).
    fn mode(&self) -> Option<X> {
        let max_ixs = argmax(&self.ln_weights);
        // Only a single arg-max defines a mode. Matching the slice also avoids
        // the out-of-bounds panic that `max_ixs[0]` would hit when it is empty.
        match max_ixs[..] {
            [ix] => Some(CategoricalDatum::from_usize(ix)),
            _ => None,
        }
    }
}
impl Entropy for Categorical {
    /// Shannon entropy in nats, `-Σ p_i ln p_i`.
    ///
    /// Categories with zero probability (log weight `-inf`, e.g. from a zero
    /// weight passed to `Categorical::new`) contribute nothing, following the
    /// convention `0 · ln 0 = 0`. Evaluating the product directly would give
    /// `0 · -inf = NaN` and poison the whole sum.
    fn entropy(&self) -> f64 {
        self.ln_weights.iter().fold(0.0, |acc, &ln_w| {
            let p = ln_w.exp();
            if p > 0.0 {
                acc - p * ln_w
            } else {
                acc
            }
        })
    }
}
impl<X: CategoricalDatum> HasSuffStat<X> for Categorical {
    type Stat = CategoricalSuffStat;

    /// A fresh sufficient statistic sized for this distribution's categories.
    fn empty_suffstat(&self) -> Self::Stat {
        let n_categories = self.ln_weights.len();
        CategoricalSuffStat::new(n_categories)
    }
}
impl KlDivergence for Categorical {
    /// KL divergence `KL(self || other) = Σ p_i (ln p_i - ln q_i)` in nats.
    ///
    /// Terms where `p_i == 0` are skipped (`0 · ln(0 / q_i) = 0`); computing
    /// them directly yields `0 · -inf = NaN`. A category with `p_i > 0` but
    /// `q_i == 0` still contributes `+inf`, as it should.
    ///
    /// NOTE(review): the log weights are zipped, so if the two distributions
    /// have different `k` the extra categories of the longer one are silently
    /// ignored.
    fn kl(&self, other: &Self) -> f64 {
        self.ln_weights
            .iter()
            .zip(other.ln_weights.iter())
            .fold(0.0, |acc, (&ws, &wo)| {
                let p = ws.exp();
                if p > 0.0 {
                    acc + p * (ws - wo)
                } else {
                    acc
                }
            })
    }
}
impl fmt::Display for CategoricalError {
    /// Render a one-line, human-readable description of the error. The `ln`
    /// flag selects between the log-weight and raw-weight wording.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::NonFiniteWeight { ix, ln, weight } => {
                if ln {
                    write!(f, "non-finite ln weight at index {}: {}", ix, weight)
                } else {
                    write!(f, "non-finite weight at index {}: {}", ix, weight)
                }
            }
            Self::NegativeWeight { ix, weight } => {
                write!(f, "negative weight at index {}: {}", ix, weight)
            }
            Self::WeightsDoNotSumToOne { ln, sum } => {
                if ln {
                    write!(f, "ln weights sum to {}, should sum to zero", sum)
                } else {
                    write!(f, "weights sum to {}, should sum to one", sum)
                }
            }
            Self::EmptyWights => write!(f, "empty weights vector"),
        }
    }
}

// Standard error type with no underlying source; the message comes from
// `Display`.
impl std::error::Error for CategoricalError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::misc::x2_test;
    use crate::test_basic_impls;

    const TOL: f64 = 1E-12;
    const N_TRIES: usize = 5;
    const X2_PVAL: f64 = 0.2;

    test_basic_impls!([categorical] Categorical::uniform(3));

    // Normalized log weights log-sum-exp to 0, i.e. the weights sum to 1.
    #[test]
    fn ln_weights_should_logsumexp_to_zero() {
        let weights: Vec<f64> = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let cat = Categorical::new(&weights).unwrap();
        assert::close(logsumexp(&cat.ln_weights), 0.0, TOL);
    }

    #[test]
    fn uniform_ln_weights_should_logsumexp_to_zero() {
        let cat = Categorical::uniform(5);
        let ln_weight = (1_f64 / 5.0).ln();
        cat.ln_weights
            .iter()
            .for_each(|&ln_w| assert::close(ln_w, ln_weight, TOL));
        assert::close(logsumexp(&cat.ln_weights), 0.0, TOL);
    }

    #[test]
    fn ln_f_should_be_ln_weight() {
        let cat = Categorical::new(&[2.0, 1.0, 2.0, 4.0, 3.0]).unwrap();
        assert::close(cat.ln_f(&0_u8), -1.791759469228055, TOL);
        assert::close(cat.ln_f(&1_u8), -2.4849066497880004, TOL);
        assert::close(cat.ln_f(&2_u8), -1.791759469228055, TOL);
        assert::close(cat.ln_f(&3_u8), -1.0986122886681098, TOL);
        assert::close(cat.ln_f(&4_u8), -1.3862943611198906, TOL);
    }

    #[test]
    fn ln_pmf_should_be_ln_weight() {
        let cat = Categorical::new(&[2.0, 1.0, 2.0, 4.0, 3.0]).unwrap();
        assert::close(cat.ln_pmf(&0_u16), -1.791759469228055, TOL);
        assert::close(cat.ln_pmf(&1_u16), -2.4849066497880004, TOL);
        assert::close(cat.ln_pmf(&2_u16), -1.791759469228055, TOL);
        assert::close(cat.ln_pmf(&3_u16), -1.0986122886681098, TOL);
        assert::close(cat.ln_pmf(&4_u16), -1.3862943611198906, TOL);
    }

    #[test]
    fn draw_should_return_numbers_in_0_to_k() {
        let mut rng = rand::thread_rng();
        let k = 5;
        let cat = Categorical::uniform(k);
        let mut counts = vec![0; k];
        for _ in 0..1000 {
            let ix: usize = cat.draw(&mut rng);
            // Check the bound before indexing so an out-of-range draw fails
            // this assertion instead of panicking inside `counts[ix]`.
            assert!(ix < k);
            counts[ix] += 1;
        }
        assert!(counts.iter().all(|&ct| ct > 0));
    }

    #[test]
    fn sample_should_return_the_correct_number_of_draws() {
        let mut rng = rand::thread_rng();
        let cat = Categorical::uniform(5);
        let xs: Vec<u8> = cat.sample(103, &mut rng);
        assert_eq!(xs.len(), 103);
    }

    #[test]
    fn should_contain_zero_to_one_minus_k() {
        let k = 3;
        let cat = Categorical::uniform(k);
        assert!(cat.supports(&0_usize));
        assert!(cat.supports(&1_usize));
        assert!(cat.supports(&2_usize));
        assert!(!cat.supports(&3_usize));
    }

    #[test]
    fn uniform_mode_does_not_exist() {
        let mode: Option<u8> = Categorical::uniform(4).mode();
        assert!(mode.is_none());
    }

    #[test]
    fn mode() {
        let cat = Categorical::new(&[1.0, 2.0, 3.0, 1.0]).unwrap();
        let mode: usize = cat.mode().unwrap();
        assert_eq!(mode, 2);
    }

    // Chi-squared goodness of fit; passes if any of N_TRIES samples fits.
    #[test]
    fn draw_test() {
        let mut rng = rand::thread_rng();
        let cat = Categorical::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let ps: Vec<f64> = vec![0.1, 0.2, 0.3, 0.4];
        let passes = (0..N_TRIES).fold(0, |acc, _| {
            let mut f_obs: Vec<u32> = vec![0; 4];
            let xs: Vec<usize> = cat.sample(1000, &mut rng);
            xs.iter().for_each(|&x| f_obs[x] += 1);
            let (_, p) = x2_test(&f_obs, &ps);
            if p > X2_PVAL {
                acc + 1
            } else {
                acc
            }
        });
        assert!(passes > 0);
    }

    #[test]
    fn kl() {
        let cat1 = Categorical::new(&[
            0.2280317, 0.1506706, 0.33620052, 0.13911904, 0.14597815,
        ])
        .unwrap();
        let cat2 = Categorical::new(&[
            0.30050657, 0.04237857, 0.20973238, 0.32858568, 0.1187968,
        ])
        .unwrap();
        assert::close(cat1.kl(&cat2), 0.1973394327976612, 1E-7);
        assert::close(cat2.kl(&cat1), 0.18814408198625582, 1E-7);
    }

    #[test]
    fn cdf() {
        let cat = Categorical::new(&[1.0, 2.0, 4.0, 3.0]).unwrap();
        assert::close(cat.cdf(&0_u8), 0.1, TOL);
        assert::close(cat.cdf(&1_u8), 0.3, TOL);
        assert::close(cat.cdf(&2_u8), 0.7, TOL);
        assert::close(cat.cdf(&3_u8), 1.0, TOL);
    }
}