rv/dist/gamma/poisson_prior.rs

1use rand::Rng;
2
3use crate::data::PoissonSuffStat;
4use crate::dist::poisson::PoissonError;
5use crate::dist::{Gamma, Poisson};
6use crate::misc::ln_binom;
7use crate::traits::*;
8
9impl HasDensity<Poisson> for Gamma {
10    fn ln_f(&self, x: &Poisson) -> f64 {
11        match x.mean() {
12            Some(mean) => self.ln_f(&mean),
13            None => f64::NEG_INFINITY,
14        }
15    }
16}
17
18impl Sampleable<Poisson> for Gamma {
19    fn draw<R: Rng>(&self, mut rng: &mut R) -> Poisson {
20        let mean: f64 = self.draw(&mut rng);
21        match Poisson::new(mean) {
22            Ok(pois) => pois,
23            Err(PoissonError::RateTooLow { .. }) => {
24                Poisson::new_unchecked(f64::EPSILON)
25            }
26            Err(err) => panic!("Failed to draw Possion: {}", err),
27        }
28    }
29}
30
31impl Support<Poisson> for Gamma {
32    fn supports(&self, x: &Poisson) -> bool {
33        match x.mean() {
34            Some(mean) => mean > 0.0 && !mean.is_infinite(),
35            None => false,
36        }
37    }
38}
39
// Marker impl: treats `Gamma` as a continuous distribution over Poisson rates.
// The empty body means it relies entirely on the trait's provided methods.
impl ContinuousDistr<Poisson> for Gamma {}
41
42macro_rules! impl_traits {
43    ($kind: ty) => {
44        impl ConjugatePrior<$kind, Poisson> for Gamma {
45            type Posterior = Self;
46            type MCache = f64;
47            type PpCache = (f64, f64, f64);
48
49            fn posterior(&self, x: &DataOrSuffStat<$kind, Poisson>) -> Self {
50                let (n, sum) = match x {
51                    DataOrSuffStat::Data(ref xs) => {
52                        let mut stat = PoissonSuffStat::new();
53                        xs.iter().for_each(|x| stat.observe(x));
54                        (stat.n(), stat.sum())
55                    }
56                    DataOrSuffStat::SuffStat(ref stat) => {
57                        (stat.n(), stat.sum())
58                    }
59                };
60
61                let a = self.shape() + sum;
62                let b = self.rate() + (n as f64);
63                Self::new(a, b).expect("Invalid posterior parameters")
64            }
65
66            #[inline]
67            fn ln_m_cache(&self) -> Self::MCache {
68                let z0 = self
69                    .shape()
70                    .mul_add(-self.ln_rate(), self.ln_gamma_shape());
71                z0
72            }
73
74            fn ln_m_with_cache(
75                &self,
76                cache: &Self::MCache,
77                x: &DataOrSuffStat<$kind, Poisson>,
78            ) -> f64 {
79                let stat: PoissonSuffStat = match x {
80                    DataOrSuffStat::Data(ref xs) => {
81                        let mut stat = PoissonSuffStat::new();
82                        xs.iter().for_each(|x| stat.observe(x));
83                        stat
84                    }
85                    DataOrSuffStat::SuffStat(ref stat) => (*stat).clone(),
86                };
87
88                let data_or_suff: DataOrSuffStat<$kind, Poisson> =
89                    DataOrSuffStat::SuffStat(&stat);
90                let post = self.posterior(&data_or_suff);
91
92                let zn = post
93                    .shape()
94                    .mul_add(-post.ln_rate(), post.ln_gamma_shape());
95
96                zn - cache - stat.sum_ln_fact()
97            }
98
99            #[inline]
100            fn ln_pp_cache(
101                &self,
102                x: &DataOrSuffStat<$kind, Poisson>,
103            ) -> Self::PpCache {
104                let post = self.posterior(x);
105                let r = post.shape();
106                let p = 1.0 / (1.0 + post.rate());
107                (r, p, p.ln())
108            }
109
110            fn ln_pp_with_cache(
111                &self,
112                cache: &Self::PpCache,
113                y: &$kind,
114            ) -> f64 {
115                let (r, p, ln_p) = cache;
116                let k = f64::from(*y);
117                let bnp = ln_binom(k + r - 1.0, k);
118                bnp + (1.0 - p).ln() * r + k * ln_p
119            }
120        }
121    };
122}
123
124impl_traits!(u8);
125impl_traits!(u16);
126impl_traits!(u32);
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_conjugate_prior;

    const TOL: f64 = 1E-12;

    test_conjugate_prior!(u32, Poisson, Gamma, Gamma::new(2.0, 1.2).unwrap());

    #[test]
    fn posterior_from_data() {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5];
        let xs = DataOrSuffStat::Data::<u8, Poisson>(&data);
        let posterior = Gamma::new(1.0, 1.0).unwrap().posterior(&xs);

        assert::close(posterior.shape(), 16.0, TOL);
        assert::close(posterior.rate(), 6.0, TOL);
    }

    #[test]
    fn ln_m_no_data() {
        let dist = Gamma::new(1.0, 1.0).unwrap();
        let new_vec = Vec::new();
        let data: DataOrSuffStat<u8, Poisson> = DataOrSuffStat::from(&new_vec);
        assert::close(dist.ln_m(&data), 0.0, TOL);
    }

    #[test]
    fn ln_m_data() {
        let dist = Gamma::new(1.0, 1.0).unwrap();
        let inputs: [u8; 5] = [0, 1, 2, 3, 4];
        let expected: [f64; 5] = [
            -std::f64::consts::LN_2,
            -2.197_224_577_336_219_6,
            -4.446_565_155_811_452,
            -7.171_720_824_816_601,
            -10.267_902_068_569_033,
        ];

        // Cumulative sufficient statistics over the input sequence.
        let suff_stats: Vec<PoissonSuffStat> = inputs
            .iter()
            .scan(PoissonSuffStat::new(), |acc, x| {
                acc.observe(x);
                Some(acc.clone())
            })
            .collect();

        suff_stats
            .iter()
            .zip(expected.iter())
            .for_each(|(ss, exp)| {
                let data: DataOrSuffStat<u8, Poisson> =
                    DataOrSuffStat::SuffStat(ss);
                let r = dist.ln_m(&data);
                assert::close(r, *exp, TOL);
            });
    }

    #[test]
    fn ln_pp_no_data() {
        let dist = Gamma::new(1.0, 1.0).unwrap();
        let empty: Vec<u8> = Vec::new();
        let no_data: DataOrSuffStat<u8, Poisson> = DataOrSuffStat::from(&empty);
        let inputs: [u8; 5] = [0, 1, 2, 3, 4];
        let expected: [f64; 5] = [
            -std::f64::consts::LN_2,
            -1.386_294_361_119_890_6,
            -2.079_441_541_679_835_7,
            -2.772_588_722_239_781,
            -3.465_735_902_799_726_5,
        ];

        for (x, exp) in inputs.iter().zip(expected.iter()) {
            assert::close(dist.ln_pp(x, &no_data), *exp, TOL);
        }
    }

    #[test]
    fn ln_pp_data() {
        let data: [u8; 10] = [5, 7, 8, 1, 0, 2, 2, 5, 1, 4];
        let mut suff_stat = PoissonSuffStat::new();
        data.iter().for_each(|d| suff_stat.observe(d));

        let doss = DataOrSuffStat::SuffStat::<u8, Poisson>(&suff_stat);

        let dist = Gamma::new(1.0, 1.0).unwrap();
        let inputs: [u8; 5] = [0, 1, 2, 3, 4];
        let expected: [f64; 5] = [
            -3.132_409_571_626_673,
            -2.033_797_282_958_563_5,
            -1.600_933_200_662_284_5,
            -1.546_865_979_392_009,
            -1.754_505_344_170_253_6,
        ];

        for (i, e) in inputs.iter().zip(expected.iter()) {
            assert::close(dist.ln_pp(i, &doss), *e, TOL);
        }
    }

    #[test]
    fn cannot_draw_zero_rate() {
        let mut rng = rand::thread_rng();
        let dist = Gamma::new(1.0, 1e-10).unwrap();
        let stream =
            <Gamma as Sampleable<Poisson>>::sample_stream(&dist, &mut rng);
        assert!(stream.take(10_000).all(|pois| pois.rate() > 0.0));
    }

    #[test]
    fn tiny_shape_draws_are_clamped_to_positive_rate() {
        // With a minuscule shape the Gamma rate draws are expected to
        // underflow to 0.0, which `Poisson::new` rejects. This exercises the
        // `RateTooLow` fallback in `draw`, which the test above (tiny rate,
        // hence huge draws) never reaches.
        let mut rng = rand::thread_rng();
        let dist = Gamma::new(1e-10, 1.0).unwrap();
        let stream =
            <Gamma as Sampleable<Poisson>>::sample_stream(&dist, &mut rng);
        assert!(stream.take(1_000).all(|pois| pois.rate() > 0.0));
    }
}