1use rand::Rng;
2
3use crate::data::{CategoricalDatum, CategoricalSuffStat, extract_stat_then};
4use crate::dist::{Categorical, Dirichlet, SymmetricDirichlet};
5use crate::misc::ln_gammafn;
6use crate::prelude::CategoricalData;
7use crate::traits::{ConjugatePrior, HasDensity, HasSuffStat, Sampleable};
8
9impl HasDensity<Categorical> for SymmetricDirichlet {
10 fn ln_f(&self, x: &Categorical) -> f64 {
11 self.ln_f(&x.weights())
12 }
13}
14
15impl Sampleable<Categorical> for SymmetricDirichlet {
16 fn draw<R: Rng>(&self, mut rng: &mut R) -> Categorical {
17 let weights: Vec<f64> = self.draw(&mut rng);
18 Categorical::new(&weights).expect("Invalid draw")
19 }
20}
21
22impl<X: CategoricalDatum> ConjugatePrior<X, Categorical>
23 for SymmetricDirichlet
24{
25 type Posterior = Dirichlet;
26 type MCache = f64;
27 type PpCache = (Vec<f64>, f64);
28
29 fn empty_stat(&self) -> <Categorical as HasSuffStat<X>>::Stat {
30 CategoricalSuffStat::new(self.k())
31 }
32
33 fn posterior(&self, x: &CategoricalData<X>) -> Self::Posterior {
34 extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
35 let alphas: Vec<f64> =
36 stat.counts().iter().map(|&ct| self.alpha() + ct).collect();
37
38 Dirichlet::new(alphas).unwrap()
39 })
40 }
41
42 #[inline]
43 fn ln_m_cache(&self) -> Self::MCache {
44 let sum_alpha = self.alpha() * self.k() as f64;
45 let a = ln_gammafn(sum_alpha);
46 let d = ln_gammafn(self.alpha()) * self.k() as f64;
47 a - d
48 }
49
50 fn ln_m_with_cache(
51 &self,
52 cache: &Self::MCache,
53 x: &CategoricalData<X>,
54 ) -> f64 {
55 let sum_alpha = self.alpha() * self.k() as f64;
56
57 extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
58 let b = ln_gammafn(sum_alpha + stat.n() as f64);
59 let c = stat
60 .counts()
61 .iter()
62 .fold(0.0, |acc, &ct| acc + ln_gammafn(self.alpha() + ct));
63
64 -b + c + cache
65 })
66 }
67
68 #[inline]
69 fn ln_pp_cache(&self, x: &CategoricalData<X>) -> Self::PpCache {
70 let post = self.posterior(x);
71 let norm = post.alphas().iter().fold(0.0, |acc, &a| acc + a);
72 (post.alphas, norm.ln())
73 }
74
75 fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 {
76 let ix = y.into_usize();
77 cache.0[ix].ln() - cache.1
78 }
79}
80
81impl HasDensity<Categorical> for Dirichlet {
82 fn ln_f(&self, x: &Categorical) -> f64 {
83 self.ln_f(&x.weights())
84 }
85}
86
87impl Sampleable<Categorical> for Dirichlet {
88 fn draw<R: Rng>(&self, mut rng: &mut R) -> Categorical {
89 let weights: Vec<f64> = self.draw(&mut rng);
90 Categorical::new(&weights).expect("Invalid draw")
91 }
92}
93
94impl<X: CategoricalDatum> ConjugatePrior<X, Categorical> for Dirichlet {
95 type Posterior = Self;
96 type MCache = (f64, f64);
97 type PpCache = (Vec<f64>, f64);
98
99 fn empty_stat(&self) -> <Categorical as HasSuffStat<X>>::Stat {
100 CategoricalSuffStat::new(self.k())
101 }
102
103 fn posterior(&self, x: &CategoricalData<X>) -> Self::Posterior {
104 extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
105 let alphas: Vec<f64> = self
106 .alphas()
107 .iter()
108 .zip(stat.counts().iter())
109 .map(|(&a, &ct)| a + ct)
110 .collect();
111
112 Dirichlet::new(alphas).unwrap()
113 })
114 }
115
116 #[inline]
117 fn ln_m_cache(&self) -> Self::MCache {
118 let sum_alpha = self.alphas().iter().fold(0.0, |acc, &a| acc + a);
119 let a = ln_gammafn(sum_alpha);
120 let d = self
121 .alphas()
122 .iter()
123 .fold(0.0, |acc, &a| acc + ln_gammafn(a));
124 (sum_alpha, a - d)
125 }
126
127 fn ln_m_with_cache(
128 &self,
129 cache: &Self::MCache,
130 x: &CategoricalData<X>,
131 ) -> f64 {
132 let (sum_alpha, ln_norm) = cache;
133 extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
134 let b = ln_gammafn(sum_alpha + stat.n() as f64);
135 let c = self
136 .alphas()
137 .iter()
138 .zip(stat.counts().iter())
139 .map(|(&a, &ct)| ln_gammafn(a + ct))
140 .sum::<f64>();
141
142 -b + c + ln_norm
143 })
144 }
145
146 #[inline]
147 fn ln_pp_cache(&self, x: &CategoricalData<X>) -> Self::PpCache {
148 let post = self.posterior(x);
149 let norm = post.alphas().iter().fold(0.0, |acc, &a| acc + a);
150 (post.alphas, norm.ln())
151 }
152
153 fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 {
154 let ix = y.into_usize();
155 cache.0[ix].ln() - cache.1
156 }
157}
158
#[cfg(test)]
mod test {
    use super::*;
    use crate::data::DataOrSuffStat;
    use crate::test_conjugate_prior;

    // Absolute tolerance for comparisons against the hard-coded reference
    // values below (presumably computed externally — TODO confirm source).
    const TOL: f64 = 1E-12;

    type CategoricalData<'a, X> = DataOrSuffStat<'a, X, Categorical>;

    // Generic conjugate-prior contract tests for the asymmetric Dirichlet.
    mod dir {
        use super::*;

        test_conjugate_prior!(
            u8,
            Categorical,
            Dirichlet,
            Dirichlet::new(vec![1.0, 2.0]).unwrap(),
            n = 1_000_000
        );
    }

    // Contract tests plus hand-checked values for the symmetric Dirichlet.
    mod symmetric {
        use super::*;

        test_conjugate_prior!(
            u8,
            Categorical,
            SymmetricDirichlet,
            SymmetricDirichlet::jeffreys(2).unwrap(),
            n = 1_000_000
        );

        // ln m(x) for a uniform prior (alpha = 1) over 3 categories.
        #[test]
        fn marginal_likelihood_u8_1() {
            let alpha = 1.0;
            let k = 3;
            let xs: Vec<u8> = vec![0, 1, 1, 1, 1, 2, 2, 2, 2, 2];
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let csd = SymmetricDirichlet::new(alpha, k).unwrap();
            let m = csd.ln_m(&data);

            assert::close(-11.328_521_741_971_9, m, TOL);
        }

        // ln m(x) with alpha < 1 and counts (2, 7, 13).
        #[test]
        fn marginal_likelihood_u8_2() {
            let alpha = 0.8;
            let k = 3;
            let mut xs: Vec<u8> = vec![0; 2];
            let mut xs1: Vec<u8> = vec![1; 7];
            let mut xs2: Vec<u8> = vec![2; 13];

            xs.append(&mut xs1);
            xs.append(&mut xs2);

            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let csd = SymmetricDirichlet::new(alpha, k).unwrap();
            let m = csd.ln_m(&data);

            assert::close(-22.437_719_300_855_2, m, TOL);
        }

        // ln m(x) with alpha > 1 and the same counts as above.
        #[test]
        fn marginal_likelihood_u8_3() {
            let alpha = 4.5;
            let k = 3;
            let mut xs: Vec<u8> = vec![0; 2];
            let mut xs1: Vec<u8> = vec![1; 7];
            let mut xs2: Vec<u8> = vec![2; 13];

            xs.append(&mut xs1);
            xs.append(&mut xs2);

            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let csd = SymmetricDirichlet::new(alpha, k).unwrap();
            let m = csd.ln_m(&data);

            assert::close(-22.420_386_389_729_3, m, TOL);
        }

        // Drawn weights are probabilities < 1, so their logs must be < 0.
        #[test]
        fn symmetric_prior_draw_log_weights_should_all_be_negative() {
            let mut rng = rand::rng();
            let csd = SymmetricDirichlet::new(1.0, 4).unwrap();
            let ctgrl: Categorical = csd.draw(&mut rng);

            assert!(ctgrl.ln_weights().iter().all(|lw| *lw < 0.0));
        }

        // A continuous draw should produce pairwise-distinct weights
        // (equal weights would suggest a degenerate sampler).
        #[test]
        fn symmetric_prior_draw_log_weights_should_be_unique() {
            let mut rng = rand::rng();
            let csd = SymmetricDirichlet::new(1.0, 4).unwrap();
            let ctgrl: Categorical = csd.draw(&mut rng);

            let ln_weights = ctgrl.ln_weights();

            assert!((ln_weights[0] - ln_weights[1]).abs() > TOL);
            assert!((ln_weights[1] - ln_weights[2]).abs() > TOL);
            assert!((ln_weights[2] - ln_weights[3]).abs() > TOL);
            assert!((ln_weights[0] - ln_weights[2]).abs() > TOL);
            assert!((ln_weights[0] - ln_weights[3]).abs() > TOL);
            assert!((ln_weights[1] - ln_weights[3]).abs() > TOL);
        }

        // Same negativity property for draws from the posterior.
        #[test]
        fn symmetric_posterior_draw_log_weights_should_all_be_negative() {
            let mut rng = rand::rng();

            let xs: Vec<u8> = vec![0, 1, 2, 1, 2, 3, 0, 1, 1];
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let csd = SymmetricDirichlet::new(1.0, 4).unwrap();
            let cd = csd.posterior(&data);
            let ctgrl: Categorical = cd.draw(&mut rng);

            assert!(ctgrl.ln_weights().iter().all(|lw| *lw < 0.0));
        }

        // Same uniqueness property for draws from the posterior.
        #[test]
        fn symmetric_posterior_draw_log_weights_should_be_unique() {
            let mut rng = rand::rng();

            let xs: Vec<u8> = vec![0, 1, 2, 1, 2, 3, 0, 1, 1];
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let csd = SymmetricDirichlet::new(1.0, 4).unwrap();
            let cd = csd.posterior(&data);
            let ctgrl: Categorical = cd.draw(&mut rng);

            let ln_weights = ctgrl.ln_weights();

            assert!((ln_weights[0] - ln_weights[1]).abs() > TOL);
            assert!((ln_weights[1] - ln_weights[2]).abs() > TOL);
            assert!((ln_weights[2] - ln_weights[3]).abs() > TOL);
            assert!((ln_weights[0] - ln_weights[2]).abs() > TOL);
            assert!((ln_weights[0] - ln_weights[3]).abs() > TOL);
            assert!((ln_weights[1] - ln_weights[3]).abs() > TOL);
        }

        // ln pp(0 | x): with alpha = 1 and count_0 = 1, this is
        // ln((1 + 1) / (3 + 10)) = ln(2/13).
        #[test]
        fn predictive_probability_value_1() {
            let csd = SymmetricDirichlet::new(1.0, 3).unwrap();

            let xs: Vec<u8> = vec![0, 1, 1, 1, 1, 2, 2, 2, 2, 2];
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let lp = csd.ln_pp(&0, &data);
            assert::close(lp, -1.871_802_176_901_59, TOL);
        }

        // ln pp(1 | x) for the same prior and data as above.
        #[test]
        fn predictive_probability_value_2() {
            let csd = SymmetricDirichlet::new(1.0, 3).unwrap();

            let xs: Vec<u8> = vec![0, 1, 1, 1, 1, 2, 2, 2, 2, 2];
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let lp = csd.ln_pp(&1, &data);
            assert::close(lp, -0.955_511_445_027_44, TOL);
        }

        // ln pp(0 | x) with a stronger (alpha = 2.5) symmetric prior.
        #[test]
        fn predictive_probability_value_3() {
            let csd = SymmetricDirichlet::new(2.5, 3).unwrap();
            let xs: Vec<u8> = vec![0, 1, 1, 1, 1, 2, 2, 2, 2, 2];
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let lp = csd.ln_pp(&0, &data);
            assert::close(lp, -1.609_437_912_434_1, TOL);
        }

        // ln pp(0 | x) with a sparse (alpha = 0.25) prior and a larger sample.
        #[test]
        fn predictive_probability_value_4() {
            let csd = SymmetricDirichlet::new(0.25, 3).unwrap();
            let xs: Vec<u8> = vec![
                0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
                2,
            ];
            let data: CategoricalData<u8> = DataOrSuffStat::Data(&xs);

            let lp = csd.ln_pp(&0, &data);
            assert::close(lp, -2.313_634_929_180_62, TOL);
        }

        // ln f of a fixed weight vector under the prior (HasDensity impl).
        #[test]
        fn csd_loglike_value_1() {
            let csd = SymmetricDirichlet::new(0.5, 3).unwrap();
            let cat = Categorical::new(&[0.2, 0.3, 0.5]).unwrap();
            let lf = csd.ln_f(&cat);
            assert::close(lf, -0.084_598_117_749_354_22, TOL);
        }
    }
}