1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use rand::Rng;
use special::Beta as SBeta;
use crate::data::{BernoulliSuffStat, Booleable, DataOrSuffStat};
use crate::dist::{Bernoulli, Beta};
use crate::traits::*;
/// Treats a `Beta` as a distribution over `Bernoulli` distributions by
/// identifying each `Bernoulli` with its success probability `p`.
impl Rv<Bernoulli> for Beta {
    /// Log density of `x` under this Beta, i.e. the Beta's `f64` log
    /// density evaluated at `x.p()`.
    fn ln_f(&self, x: &Bernoulli) -> f64 {
        <Self as Rv<f64>>::ln_f(self, &x.p())
    }

    /// Draws a success probability from this Beta and wraps it in a
    /// `Bernoulli`.
    ///
    /// # Panics
    ///
    /// Panics if `Bernoulli::new` rejects the sampled probability.
    fn draw<R: Rng>(&self, rng: &mut R) -> Bernoulli {
        let weight = <Self as Rv<f64>>::draw(self, rng);
        Bernoulli::new(weight).expect("Failed to draw valid weight")
    }
}
/// A `Bernoulli` lies in the Beta's support when its success probability is
/// strictly inside the open interval `(0, 1)`.
impl Support<Bernoulli> for Beta {
    fn supports(&self, x: &Bernoulli) -> bool {
        let p = x.p();
        // Both comparisons are false for NaN, so NaN is unsupported.
        p > 0.0 && p < 1.0
    }
}
// Marks the Beta as a continuous distribution over `Bernoulli` values. The
// empty body means every `ContinuousDistr` method uses the trait's provided
// default implementation (presumably built on the `Rv<Bernoulli>` impl —
// confirm against the trait definition).
impl ContinuousDistr<Bernoulli> for Beta {}
/// Beta is the conjugate prior for the Bernoulli success probability.
///
/// Observing `k` successes out of `n` trials under a `Beta(α, β)` prior
/// yields a `Beta(α + k, β + n − k)` posterior.
impl<X: Booleable> ConjugatePrior<X, Bernoulli> for Beta {
    type Posterior = Self;
    /// The prior's `ln B(α, β)`, the normalizer of the marginal likelihood.
    type LnMCache = f64;
    /// `(ln P(y = true), ln P(y = false))` under the posterior predictive.
    type LnPpCache = (f64, f64);

    // Allowed because the conventional mathematical names (`n`, `k`, `a`,
    // `b`, `x`) are clearer here than longer identifiers.
    #[allow(clippy::many_single_char_names)]
    fn posterior(&self, x: &DataOrSuffStat<X, Bernoulli>) -> Self {
        // `n` = number of observations, `k` = number of `true` observations.
        let (n, k) = match x {
            DataOrSuffStat::Data(ref xs) => {
                let mut stat = BernoulliSuffStat::new();
                xs.iter().for_each(|x| stat.observe(x));
                (stat.n(), stat.k())
            }
            DataOrSuffStat::SuffStat(ref stat) => (stat.n(), stat.k()),
            DataOrSuffStat::None => (0, 0),
        };

        let a = self.alpha() + k as f64;
        let b = self.beta() + (n - k) as f64;
        Beta::new(a, b).expect("Invalid posterior parameters")
    }

    #[inline]
    fn ln_m_cache(&self) -> Self::LnMCache {
        self.alpha().ln_beta(self.beta())
    }

    /// Log marginal likelihood of `x`:
    /// `ln B(α + k, β + n − k) − ln B(α, β)`.
    fn ln_m_with_cache(
        &self,
        cache: &Self::LnMCache,
        x: &DataOrSuffStat<X, Bernoulli>,
    ) -> f64 {
        let post = self.posterior(x);
        post.alpha().ln_beta(post.beta()) - cache
    }

    /// Log posterior-predictive probabilities of `true` and `false`.
    ///
    /// With posterior parameters `(a, b)`, `P(true) = a / (a + b)`. The logs
    /// are computed as `ln p = -ln(1 + b/a)` and `ln(1 - p) = -ln(1 + a/b)`
    /// via `ln_1p`. This avoids the cancellation in `(1.0 - p).ln()` that
    /// loses precision (or yields `-inf`) when `p` is close to 1.
    #[inline]
    fn ln_pp_cache(&self, x: &DataOrSuffStat<X, Bernoulli>) -> Self::LnPpCache {
        let post = self.posterior(x);
        // Both are strictly positive because `Beta::new` validated them.
        let (a, b) = (post.alpha(), post.beta());
        (-(b / a).ln_1p(), -(a / b).ln_1p())
    }

    /// Looks up the cached log predictive probability for `y`.
    fn ln_pp_with_cache(&self, cache: &Self::LnPpCache, y: &X) -> f64 {
        if y.into_bool() {
            cache.0
        } else {
            cache.1
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1E-12;

    /// Uniform `Beta(1, 1)` prior shared by the posterior tests.
    fn uniform_prior() -> Beta {
        Beta::new(1.0, 1.0).unwrap()
    }

    #[test]
    fn posterior_from_data_bool() {
        // Three `true` and two `false` observations: Beta(1 + 3, 1 + 2).
        let data = vec![false, true, false, true, true];
        let obs = DataOrSuffStat::Data::<bool, Bernoulli>(&data);

        let post = uniform_prior().posterior(&obs);

        assert::close(post.alpha(), 4.0, TOL);
        assert::close(post.beta(), 3.0, TOL);
    }

    #[test]
    fn posterior_from_data_u16() {
        // Same observations encoded as 0/1 integers must give the same posterior.
        let data: Vec<u16> = vec![0, 1, 0, 1, 1];
        let obs = DataOrSuffStat::Data::<u16, Bernoulli>(&data);

        let post = uniform_prior().posterior(&obs);

        assert::close(post.alpha(), 4.0, TOL);
        assert::close(post.beta(), 3.0, TOL);
    }
}