rv/dist/dirichlet/
categorical_prior.rs

1use rand::Rng;
2
3use crate::data::{CategoricalDatum, CategoricalSuffStat, extract_stat_then};
4use crate::dist::{Categorical, Dirichlet, SymmetricDirichlet};
5use crate::misc::ln_gammafn;
6use crate::prelude::CategoricalData;
7use crate::traits::{ConjugatePrior, HasDensity, HasSuffStat, Sampleable};
8
9impl HasDensity<Categorical> for SymmetricDirichlet {
10    fn ln_f(&self, x: &Categorical) -> f64 {
11        self.ln_f(&x.weights())
12    }
13}
14
15impl Sampleable<Categorical> for SymmetricDirichlet {
16    fn draw<R: Rng>(&self, mut rng: &mut R) -> Categorical {
17        let weights: Vec<f64> = self.draw(&mut rng);
18        Categorical::new(&weights).expect("Invalid draw")
19    }
20}
21
22impl<X: CategoricalDatum> ConjugatePrior<X, Categorical>
23    for SymmetricDirichlet
24{
25    type Posterior = Dirichlet;
26    type MCache = f64;
27    type PpCache = (Vec<f64>, f64);
28
29    fn empty_stat(&self) -> <Categorical as HasSuffStat<X>>::Stat {
30        CategoricalSuffStat::new(self.k())
31    }
32
33    fn posterior(&self, x: &CategoricalData<X>) -> Self::Posterior {
34        extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
35            let alphas: Vec<f64> =
36                stat.counts().iter().map(|&ct| self.alpha() + ct).collect();
37
38            Dirichlet::new(alphas).unwrap()
39        })
40    }
41
42    #[inline]
43    fn ln_m_cache(&self) -> Self::MCache {
44        let sum_alpha = self.alpha() * self.k() as f64;
45        let a = ln_gammafn(sum_alpha);
46        let d = ln_gammafn(self.alpha()) * self.k() as f64;
47        a - d
48    }
49
50    fn ln_m_with_cache(
51        &self,
52        cache: &Self::MCache,
53        x: &CategoricalData<X>,
54    ) -> f64 {
55        let sum_alpha = self.alpha() * self.k() as f64;
56
57        extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
58            let b = ln_gammafn(sum_alpha + stat.n() as f64);
59            let c = stat
60                .counts()
61                .iter()
62                .fold(0.0, |acc, &ct| acc + ln_gammafn(self.alpha() + ct));
63
64            -b + c + cache
65        })
66    }
67
68    #[inline]
69    fn ln_pp_cache(&self, x: &CategoricalData<X>) -> Self::PpCache {
70        let post = self.posterior(x);
71        let norm = post.alphas().iter().fold(0.0, |acc, &a| acc + a);
72        (post.alphas, norm.ln())
73    }
74
75    fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 {
76        let ix = y.into_usize();
77        cache.0[ix].ln() - cache.1
78    }
79}
80
81impl HasDensity<Categorical> for Dirichlet {
82    fn ln_f(&self, x: &Categorical) -> f64 {
83        self.ln_f(&x.weights())
84    }
85}
86
87impl Sampleable<Categorical> for Dirichlet {
88    fn draw<R: Rng>(&self, mut rng: &mut R) -> Categorical {
89        let weights: Vec<f64> = self.draw(&mut rng);
90        Categorical::new(&weights).expect("Invalid draw")
91    }
92}
93
94impl<X: CategoricalDatum> ConjugatePrior<X, Categorical> for Dirichlet {
95    type Posterior = Self;
96    type MCache = (f64, f64);
97    type PpCache = (Vec<f64>, f64);
98
99    fn empty_stat(&self) -> <Categorical as HasSuffStat<X>>::Stat {
100        CategoricalSuffStat::new(self.k())
101    }
102
103    fn posterior(&self, x: &CategoricalData<X>) -> Self::Posterior {
104        extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
105            let alphas: Vec<f64> = self
106                .alphas()
107                .iter()
108                .zip(stat.counts().iter())
109                .map(|(&a, &ct)| a + ct)
110                .collect();
111
112            Dirichlet::new(alphas).unwrap()
113        })
114    }
115
116    #[inline]
117    fn ln_m_cache(&self) -> Self::MCache {
118        let sum_alpha = self.alphas().iter().fold(0.0, |acc, &a| acc + a);
119        let a = ln_gammafn(sum_alpha);
120        let d = self
121            .alphas()
122            .iter()
123            .fold(0.0, |acc, &a| acc + ln_gammafn(a));
124        (sum_alpha, a - d)
125    }
126
127    fn ln_m_with_cache(
128        &self,
129        cache: &Self::MCache,
130        x: &CategoricalData<X>,
131    ) -> f64 {
132        let (sum_alpha, ln_norm) = cache;
133        extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
134            let b = ln_gammafn(sum_alpha + stat.n() as f64);
135            let c = self
136                .alphas()
137                .iter()
138                .zip(stat.counts().iter())
139                .map(|(&a, &ct)| ln_gammafn(a + ct))
140                .sum::<f64>();
141
142            -b + c + ln_norm
143        })
144    }
145
146    #[inline]
147    fn ln_pp_cache(&self, x: &CategoricalData<X>) -> Self::PpCache {
148        let post = self.posterior(x);
149        let norm = post.alphas().iter().fold(0.0, |acc, &a| acc + a);
150        (post.alphas, norm.ln())
151    }
152
153    fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 {
154        let ix = y.into_usize();
155        cache.0[ix].ln() - cache.1
156    }
157}
158
#[cfg(test)]
mod test {
    use super::*;
    use crate::data::DataOrSuffStat;
    use crate::test_conjugate_prior;

    // Tolerance for value comparisons; also the minimum gap for two
    // log-weights to count as distinct.
    const TOL: f64 = 1E-12;

    type CategoricalData<'a, X> = DataOrSuffStat<'a, X, Categorical>;

    mod dir {
        use super::*;

        test_conjugate_prior!(
            u8,
            Categorical,
            Dirichlet,
            Dirichlet::new(vec![1.0, 2.0]).unwrap(),
            n = 1_000_000
        );
    }

    mod symmetric {
        use super::*;

        test_conjugate_prior!(
            u8,
            Categorical,
            SymmetricDirichlet,
            SymmetricDirichlet::jeffreys(2).unwrap(),
            n = 1_000_000
        );

        /// Ten observations with per-category counts `[1, 4, 5]`.
        fn ten_obs() -> Vec<u8> {
            vec![0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
        }

        /// Twenty-two sorted observations with per-category counts
        /// `[2, 7, 13]`.
        fn twenty_two_obs() -> Vec<u8> {
            let mut xs = vec![0_u8; 2];
            xs.resize(2 + 7, 1);
            xs.resize(2 + 7 + 13, 2);
            xs
        }

        /// Nine observations over four categories, used for posterior draws.
        fn nine_obs() -> Vec<u8> {
            vec![0, 1, 2, 1, 2, 3, 0, 1, 1]
        }

        /// Draws a categorical from the posterior of `SymmetricDirichlet(1, 4)`
        /// after observing `nine_obs()`.
        fn posterior_draw() -> Categorical {
            let mut rng = rand::rng();
            let xs = nine_obs();
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let prior = SymmetricDirichlet::new(1.0, 4).unwrap();
            let posterior = prior.posterior(&data);
            posterior.draw(&mut rng)
        }

        /// Asserts every log-weight of `ctgrl` is strictly negative.
        fn assert_ln_weights_negative(ctgrl: &Categorical) {
            assert!(ctgrl.ln_weights().iter().all(|&lw| lw < 0.0));
        }

        /// Asserts the first four log-weights of `ctgrl` are pairwise
        /// separated by more than `TOL`.
        fn assert_ln_weights_distinct(ctgrl: &Categorical) {
            const PAIRS: [(usize, usize); 6] =
                [(0, 1), (1, 2), (2, 3), (0, 2), (0, 3), (1, 3)];

            let ln_weights = ctgrl.ln_weights();
            for &(i, j) in &PAIRS {
                assert!((ln_weights[i] - ln_weights[j]).abs() > TOL);
            }
        }

        #[test]
        fn marginal_likelihood_u8_1() {
            let xs = ten_obs();
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let csd = SymmetricDirichlet::new(1.0, 3).unwrap();
            let m = csd.ln_m(&data);

            assert::close(-11.328_521_741_971_9, m, TOL);
        }

        #[test]
        fn marginal_likelihood_u8_2() {
            let xs = twenty_two_obs();
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let csd = SymmetricDirichlet::new(0.8, 3).unwrap();
            let m = csd.ln_m(&data);

            assert::close(-22.437_719_300_855_2, m, TOL);
        }

        #[test]
        fn marginal_likelihood_u8_3() {
            let xs = twenty_two_obs();
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let csd = SymmetricDirichlet::new(4.5, 3).unwrap();
            let m = csd.ln_m(&data);

            assert::close(-22.420_386_389_729_3, m, TOL);
        }

        #[test]
        fn symmetric_prior_draw_log_weights_should_all_be_negative() {
            let mut rng = rand::rng();
            let csd = SymmetricDirichlet::new(1.0, 4).unwrap();
            let ctgrl: Categorical = csd.draw(&mut rng);

            assert_ln_weights_negative(&ctgrl);
        }

        #[test]
        fn symmetric_prior_draw_log_weights_should_be_unique() {
            let mut rng = rand::rng();
            let csd = SymmetricDirichlet::new(1.0, 4).unwrap();
            let ctgrl: Categorical = csd.draw(&mut rng);

            assert_ln_weights_distinct(&ctgrl);
        }

        #[test]
        fn symmetric_posterior_draw_log_weights_should_all_be_negative() {
            let ctgrl = posterior_draw();
            assert_ln_weights_negative(&ctgrl);
        }

        #[test]
        fn symmetric_posterior_draw_log_weights_should_be_unique() {
            let ctgrl = posterior_draw();
            assert_ln_weights_distinct(&ctgrl);
        }

        #[test]
        fn predictive_probability_value_1() {
            let csd = SymmetricDirichlet::new(1.0, 3).unwrap();
            let xs = ten_obs();
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let lp = csd.ln_pp(&0, &data);
            assert::close(lp, -1.871_802_176_901_59, TOL);
        }

        #[test]
        fn predictive_probability_value_2() {
            let csd = SymmetricDirichlet::new(1.0, 3).unwrap();
            let xs = ten_obs();
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let lp = csd.ln_pp(&1, &data);
            assert::close(lp, -0.955_511_445_027_44, TOL);
        }

        #[test]
        fn predictive_probability_value_3() {
            let csd = SymmetricDirichlet::new(2.5, 3).unwrap();
            let xs = ten_obs();
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let lp = csd.ln_pp(&0, &data);
            assert::close(lp, -1.609_437_912_434_1, TOL);
        }

        #[test]
        fn predictive_probability_value_4() {
            let csd = SymmetricDirichlet::new(0.25, 3).unwrap();
            let xs = twenty_two_obs();
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let lp = csd.ln_pp(&0, &data);
            assert::close(lp, -2.313_634_929_180_62, TOL);
        }

        #[test]
        fn csd_loglike_value_1() {
            let csd = SymmetricDirichlet::new(0.5, 3).unwrap();
            let cat = Categorical::new(&[0.2, 0.3, 0.5]).unwrap();

            let lf = csd.ln_f(&cat);
            assert::close(lf, -0.084_598_117_749_354_22, TOL);
        }
    }
}