#[cfg(feature = "serde_support")]
use serde_derive::{Deserialize, Serialize};
use crate::data::CategoricalDatum;
use crate::data::DataOrSuffStat;
use crate::dist::{Bernoulli, Categorical, Gaussian, Poisson};
use crate::traits::SuffStat;
use nalgebra::{DMatrix, DVector};
use special::Gamma as SGamma;
/// Sufficient statistic for a sequence of Bernoulli trials.
///
/// Stores the total number of trials together with the number of those
/// trials that were successes.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct BernoulliSuffStat {
    /// Total number of trials.
    n: usize,
    /// Number of successful trials.
    k: usize,
}

impl BernoulliSuffStat {
    /// Creates an empty statistic: zero trials and zero successes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the total number of trials.
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the number of successful trials.
    pub fn k(&self) -> usize {
        self.k
    }
}

impl Default for BernoulliSuffStat {
    /// The default statistic is the empty one, identical to [`BernoulliSuffStat::new`].
    fn default() -> Self {
        Self { n: 0, k: 0 }
    }
}
/// Accumulates `bool` observations, where `true` counts as a success.
impl SuffStat<bool> for BernoulliSuffStat {
    /// Returns the number of observations currently accounted for.
    fn n(&self) -> usize {
        self.n
    }

    /// Adds the observation `x` to the statistic.
    fn observe(&mut self, x: &bool) {
        self.n += 1;
        if *x {
            self.k += 1;
        }
    }

    /// Removes a previously observed `x` from the statistic.
    ///
    /// Forgetting the last remaining observation, or forgetting from an
    /// already-empty statistic, resets the statistic to empty. This mirrors
    /// the Poisson and Gaussian statistics in this file and keeps the
    /// unsigned trial counter from underflowing (a panic in debug builds, a
    /// wrap to `usize::MAX` in release builds).
    fn forget(&mut self, x: &bool) {
        if self.n > 1 {
            self.n -= 1;
            // Assumes `x` was observed earlier, so a success is available to remove.
            if *x {
                self.k -= 1;
            }
        } else {
            self.n = 0;
            self.k = 0;
        }
    }
}
/// Wraps a borrowed `BernoulliSuffStat` in the `SuffStat` variant of
/// `DataOrSuffStat` for `bool` data.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl then supplies the matching `Into`, so existing `.into()` call sites
/// keep working while `DataOrSuffStat::from(..)` becomes available as well.
impl<'a> From<&'a BernoulliSuffStat> for DataOrSuffStat<'a, bool, Bernoulli> {
    fn from(stat: &'a BernoulliSuffStat) -> Self {
        DataOrSuffStat::SuffStat(stat)
    }
}
/// Wraps a borrowed vector of `bool` observations in the `Data` variant of
/// `DataOrSuffStat` for the Bernoulli distribution.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl then supplies the matching `Into`, so existing `.into()` call sites
/// keep working.
impl<'a> From<&'a Vec<bool>> for DataOrSuffStat<'a, bool, Bernoulli> {
    fn from(xs: &'a Vec<bool>) -> Self {
        DataOrSuffStat::Data(xs)
    }
}
/// Implements, for one integer type `$kind`, the conversions into
/// `DataOrSuffStat<'_, $kind, Bernoulli>` and `SuffStat<$kind>` for
/// `BernoulliSuffStat`.
///
/// A value equal to `1` is counted as a success; any other value only
/// changes the trial count.
macro_rules! impl_bernoulli_suffstat {
    ($kind:ty) => {
        // `From` is implemented instead of `Into`; the blanket impl in the
        // standard library provides the corresponding `Into` for free.
        impl<'a> From<&'a BernoulliSuffStat>
            for DataOrSuffStat<'a, $kind, Bernoulli>
        {
            fn from(stat: &'a BernoulliSuffStat) -> Self {
                DataOrSuffStat::SuffStat(stat)
            }
        }

        impl<'a> From<&'a Vec<$kind>> for DataOrSuffStat<'a, $kind, Bernoulli> {
            fn from(xs: &'a Vec<$kind>) -> Self {
                DataOrSuffStat::Data(xs)
            }
        }

        impl SuffStat<$kind> for BernoulliSuffStat {
            fn n(&self) -> usize {
                self.n
            }

            fn observe(&mut self, x: &$kind) {
                self.n += 1;
                if *x == 1 {
                    self.k += 1;
                }
            }

            fn forget(&mut self, x: &$kind) {
                // Forgetting the last observation (or forgetting from an
                // empty statistic) resets to empty instead of letting the
                // unsigned trial counter underflow, matching the Poisson
                // and Gaussian statistics in this file.
                if self.n > 1 {
                    self.n -= 1;
                    // Assumes `x` was observed earlier, so a success is
                    // available to remove.
                    if *x == 1 {
                        self.k -= 1;
                    }
                } else {
                    self.n = 0;
                    self.k = 0;
                }
            }
        }
    };
}

// Integer types accepted as Bernoulli observations.
impl_bernoulli_suffstat!(u8);
impl_bernoulli_suffstat!(u16);
impl_bernoulli_suffstat!(u32);
impl_bernoulli_suffstat!(u64);
impl_bernoulli_suffstat!(usize);
impl_bernoulli_suffstat!(i8);
impl_bernoulli_suffstat!(i16);
impl_bernoulli_suffstat!(i32);
impl_bernoulli_suffstat!(i64);
impl_bernoulli_suffstat!(isize);
/// Sufficient statistic for categorical data.
///
/// Stores the number of observations and one running count per category.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct CategoricalSuffStat {
    /// Total number of observations.
    n: usize,
    /// Per-category counts, indexed by category.
    counts: Vec<f64>,
}

impl CategoricalSuffStat {
    /// Creates an empty statistic with `k` categories, all counts at zero.
    pub fn new(k: usize) -> Self {
        let counts = vec![0.0; k];
        Self { n: 0, counts }
    }

    /// Returns the total number of observations.
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the per-category counts.
    pub fn counts(&self) -> &Vec<f64> {
        &self.counts
    }
}
/// Wraps a borrowed `CategoricalSuffStat` in the `SuffStat` variant of
/// `DataOrSuffStat` for any categorical datum type `X`.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl then supplies the matching `Into`, so existing `.into()` call sites
/// keep working.
impl<'a, X> From<&'a CategoricalSuffStat> for DataOrSuffStat<'a, X, Categorical>
where
    X: CategoricalDatum,
{
    fn from(stat: &'a CategoricalSuffStat) -> Self {
        DataOrSuffStat::SuffStat(stat)
    }
}
macro_rules! impl_into_dos_for_categorical {
($kind: ty) => {
impl<'a> Into<DataOrSuffStat<'a, $kind, Categorical>>
for &'a Vec<$kind>
{
fn into(self) -> DataOrSuffStat<'a, $kind, Categorical> {
DataOrSuffStat::Data(self)
}
}
};
}
impl_into_dos_for_categorical!(bool);
impl_into_dos_for_categorical!(usize);
impl_into_dos_for_categorical!(u8);
impl_into_dos_for_categorical!(u16);
impl_into_dos_for_categorical!(u32);
/// Accumulates categorical observations by incrementing the count at the
/// index given by `into_usize()`.
impl<X: CategoricalDatum> SuffStat<X> for CategoricalSuffStat {
    /// Returns the number of observations currently accounted for.
    fn n(&self) -> usize {
        self.n
    }

    /// Adds the observation `x` to the statistic.
    ///
    /// # Panics
    ///
    /// Panics if the index of `x` is not smaller than the number of
    /// categories.
    fn observe(&mut self, x: &X) {
        let ix = x.into_usize();
        self.n += 1;
        self.counts[ix] += 1.0;
    }

    /// Removes a previously observed `x` from the statistic.
    ///
    /// Forgetting the last remaining observation, or forgetting from an
    /// already-empty statistic, resets every count to zero. This mirrors the
    /// Poisson and Gaussian statistics in this file and keeps the unsigned
    /// observation counter from underflowing (a panic in debug builds, a wrap
    /// to `usize::MAX` in release builds).
    fn forget(&mut self, x: &X) {
        let ix = x.into_usize();
        if self.n > 1 {
            self.n -= 1;
            self.counts[ix] -= 1.0;
        } else {
            self.n = 0;
            self.counts.iter_mut().for_each(|ct| *ct = 0.0);
        }
    }
}
/// Sufficient statistic for Poisson-distributed counts.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct PoissonSuffStat {
    /// Number of observations.
    n: usize,
    /// Sum of the observed values.
    sum: f64,
    /// Sum of `ln(x!)` over the observed values.
    sum_log_fact: f64,
}

impl PoissonSuffStat {
    /// Creates an empty statistic with every accumulator at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of observations.
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the sum of the observed values.
    pub fn sum(&self) -> f64 {
        self.sum
    }

    /// Returns the sum of the log-factorials of the observed values.
    pub fn sum_log_fact(&self) -> f64 {
        self.sum_log_fact
    }
}

impl Default for PoissonSuffStat {
    /// The default statistic is the empty one, identical to [`PoissonSuffStat::new`].
    fn default() -> Self {
        PoissonSuffStat {
            n: 0,
            sum: 0.0,
            sum_log_fact: 0.0,
        }
    }
}
/// Implements, for one unsigned integer type `$kind`, the conversions into
/// `DataOrSuffStat<'_, $kind, Poisson>` and `SuffStat<$kind>` for
/// `PoissonSuffStat`.
macro_rules! impl_poisson_suffstat {
    ($kind:ty) => {
        // `From` is implemented instead of `Into`; the blanket impl in the
        // standard library provides the corresponding `Into` for free.
        impl<'a> From<&'a PoissonSuffStat>
            for DataOrSuffStat<'a, $kind, Poisson>
        {
            fn from(stat: &'a PoissonSuffStat) -> Self {
                DataOrSuffStat::SuffStat(stat)
            }
        }

        impl<'a> From<&'a Vec<$kind>> for DataOrSuffStat<'a, $kind, Poisson> {
            fn from(xs: &'a Vec<$kind>) -> Self {
                DataOrSuffStat::Data(xs)
            }
        }

        impl SuffStat<$kind> for PoissonSuffStat {
            fn n(&self) -> usize {
                self.n
            }

            fn observe(&mut self, x: &$kind) {
                let xf = f64::from(*x);
                self.n += 1;
                self.sum += xf;
                // ln(x!) = ln Γ(x + 1). The `+ 1` is done in f64: adding it
                // in `$kind` overflows when `x` is the type's maximum value.
                self.sum_log_fact += (xf + 1.0).ln_gamma().0;
            }

            fn forget(&mut self, x: &$kind) {
                if self.n > 1 {
                    let xf = f64::from(*x);
                    self.n -= 1;
                    self.sum -= xf;
                    // Same overflow-free form as in `observe`.
                    self.sum_log_fact -= (xf + 1.0).ln_gamma().0;
                } else {
                    // Forgetting the last observation resets to empty.
                    self.n = 0;
                    self.sum = 0.0;
                    self.sum_log_fact = 0.0;
                }
            }
        }
    };
}

// Unsigned integer types accepted as Poisson observations.
impl_poisson_suffstat!(u8);
impl_poisson_suffstat!(u16);
impl_poisson_suffstat!(u32);
/// Sufficient statistic for univariate Gaussian data.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct GaussianSuffStat {
    /// Number of observations.
    n: usize,
    /// Sum of the observed values.
    sum_x: f64,
    /// Sum of the squares of the observed values.
    sum_x_sq: f64,
}

impl GaussianSuffStat {
    /// Creates an empty statistic with every accumulator at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of observations.
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the sum of the observed values.
    pub fn sum_x(&self) -> f64 {
        self.sum_x
    }

    /// Returns the sum of the squared observed values.
    pub fn sum_x_sq(&self) -> f64 {
        self.sum_x_sq
    }
}

impl Default for GaussianSuffStat {
    /// The default statistic is the empty one, identical to [`GaussianSuffStat::new`].
    fn default() -> Self {
        Self {
            n: 0,
            sum_x: 0.0,
            sum_x_sq: 0.0,
        }
    }
}
macro_rules! impl_gaussian_suffstat {
($kind:ty) => {
impl<'a> Into<DataOrSuffStat<'a, $kind, Gaussian>>
for &'a GaussianSuffStat
{
fn into(self) -> DataOrSuffStat<'a, $kind, Gaussian> {
DataOrSuffStat::SuffStat(self)
}
}
impl<'a> Into<DataOrSuffStat<'a, $kind, Gaussian>> for &'a Vec<$kind> {
fn into(self) -> DataOrSuffStat<'a, $kind, Gaussian> {
DataOrSuffStat::Data(self)
}
}
impl SuffStat<$kind> for GaussianSuffStat {
fn n(&self) -> usize {
self.n
}
fn observe(&mut self, x: &$kind) {
let xf = f64::from(*x);
self.n += 1;
self.sum_x += xf;
self.sum_x_sq += xf.powi(2);
}
fn forget(&mut self, x: &$kind) {
if self.n > 1 {
let xf = f64::from(*x);
self.n -= 1;
self.sum_x -= xf;
self.sum_x_sq -= xf.powi(2);
} else {
self.n = 0;
self.sum_x = 0.0;
self.sum_x_sq = 0.0;
}
}
}
};
}
impl_gaussian_suffstat!(f32);
impl_gaussian_suffstat!(f64);
/// Sufficient statistic for multivariate Gaussian data.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct MvGaussianSuffStat {
    /// Number of observations.
    n: usize,
    /// Element-wise sum of the observed vectors.
    sum_x: DVector<f64>,
    /// Sum of the outer products `x * xᵀ` of the observed vectors.
    sum_x_sq: DMatrix<f64>,
}

impl MvGaussianSuffStat {
    /// Creates an empty statistic for `dims`-dimensional data, with a zero
    /// vector and a zero `dims x dims` matrix as accumulators.
    pub fn new(dims: usize) -> Self {
        let sum_x = DVector::zeros(dims);
        let sum_x_sq = DMatrix::zeros(dims, dims);
        Self {
            n: 0,
            sum_x,
            sum_x_sq,
        }
    }

    /// Returns the number of observations.
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the accumulated sum of the observed vectors.
    pub fn sum_x(&self) -> &DVector<f64> {
        &self.sum_x
    }

    /// Returns the accumulated sum of outer products.
    pub fn sum_x_sq(&self) -> &DMatrix<f64> {
        &self.sum_x_sq
    }
}
/// Accumulates vector observations as a running sum and a running sum of
/// outer products.
impl SuffStat<DVector<f64>> for MvGaussianSuffStat {
    /// Returns the number of observations currently accounted for.
    fn n(&self) -> usize {
        self.n
    }

    /// Adds the observation `x` to the statistic.
    ///
    /// The first observation replaces the accumulators outright, so they take
    /// on the dimension of `x`; later observations are added in place.
    fn observe(&mut self, x: &DVector<f64>) {
        self.n += 1;
        if self.n == 1 {
            self.sum_x = x.clone();
            self.sum_x_sq = x * x.transpose();
        } else {
            self.sum_x += x;
            self.sum_x_sq += x * x.transpose();
        }
    }

    /// Removes a previously observed `x` from the statistic.
    ///
    /// Forgetting the last remaining observation, or forgetting from an
    /// already-empty statistic, resets the accumulators to zeros of the
    /// current dimension. The count is checked before it is decremented so
    /// that an empty statistic cannot underflow the unsigned counter (a panic
    /// in debug builds, a wrap to `usize::MAX` in release builds).
    fn forget(&mut self, x: &DVector<f64>) {
        if self.n > 1 {
            self.n -= 1;
            self.sum_x -= x;
            self.sum_x_sq -= x * x.transpose();
        } else {
            let dims = self.sum_x.len();
            self.n = 0;
            self.sum_x = DVector::zeros(dims);
            self.sum_x_sq = DMatrix::zeros(dims, dims);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Checks on `BernoulliSuffStat` construction and observation.
    mod bernoulli {
        use super::*;

        #[test]
        fn new_should_be_empty() {
            let BernoulliSuffStat { n, k } = BernoulliSuffStat::new();
            assert_eq!(n, 0);
            assert_eq!(k, 0);
        }

        #[test]
        fn observe_1() {
            // An integer `1` is a success.
            let mut stat = BernoulliSuffStat::new();
            stat.observe(&1_u8);
            assert_eq!((stat.n, stat.k), (1, 1));
        }

        #[test]
        fn observe_true() {
            // A `true` is a success.
            let mut stat = BernoulliSuffStat::new();
            stat.observe(&true);
            assert_eq!((stat.n, stat.k), (1, 1));
        }

        #[test]
        fn observe_0() {
            // An integer `0` is a trial without a success.
            let mut stat = BernoulliSuffStat::new();
            stat.observe(&0_i8);
            assert_eq!((stat.n, stat.k), (1, 0));
        }

        #[test]
        fn observe_false() {
            // A `false` is a trial without a success.
            let mut stat = BernoulliSuffStat::new();
            stat.observe(&false);
            assert_eq!((stat.n, stat.k), (1, 0));
        }
    }

    // Checks on `CategoricalSuffStat` construction.
    mod categorical {
        use super::*;

        #[test]
        fn new() {
            let sf = CategoricalSuffStat::new(4);
            assert_eq!(sf.counts.len(), 4);
            assert_eq!(sf.n, 0);
            for ct in &sf.counts {
                assert!(ct.abs() < 1E-12);
            }
        }
    }
}