1#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use crate::impl_display;
6use crate::misc::ln_gammafn;
7use crate::misc::vec_to_string;
8use crate::traits::{
9 ContinuousDistr, HasDensity, Parameterized, Sampleable, Support,
10};
11use rand::Rng;
12use rand_distr::Gamma as RGamma;
13use std::fmt;
14use std::sync::OnceLock;
15
16mod categorical_prior;
17
/// Dirichlet distribution whose `k` concentration weights all share the
/// single value `alpha`.
///
/// Equality compares only `alpha` and `k`; the cached log-gamma value is an
/// implementation detail.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct SymmetricDirichlet {
    /// Concentration weight shared by every category.
    alpha: f64,
    /// Number of categories (the length of the vectors this distribution
    /// draws).
    k: usize,
    /// Lazily computed `ln_gammafn(alpha)`. Skipped by serde and replaced
    /// with an empty cell whenever `alpha` is changed through a setter.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_gamma_alpha: OnceLock<f64>,
}
34
/// Plain-data parameter set of a symmetric Dirichlet distribution, as
/// emitted and consumed by the `Parameterized` implementation of
/// `SymmetricDirichlet`.
///
/// Derives the common traits so parameter sets can be logged, duplicated and
/// compared.
#[derive(Debug, Clone, PartialEq)]
pub struct SymmetricDirichletParameters {
    /// Concentration weight shared by every category.
    pub alpha: f64,
    /// Number of categories.
    pub k: usize,
}
39
40impl Parameterized for SymmetricDirichlet {
41 type Parameters = SymmetricDirichletParameters;
42
43 fn emit_params(&self) -> Self::Parameters {
44 Self::Parameters {
45 alpha: self.alpha(),
46 k: self.k(),
47 }
48 }
49
50 fn from_params(params: Self::Parameters) -> Self {
51 Self::new_unchecked(params.alpha, params.k)
52 }
53}
54
55impl PartialEq for SymmetricDirichlet {
56 fn eq(&self, other: &Self) -> bool {
57 self.alpha == other.alpha && self.k == other.k
58 }
59}
60
/// Reasons a `SymmetricDirichlet` can be rejected by its checked
/// constructors and setters.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum SymmetricDirichletError {
    /// The requested number of categories `k` was zero.
    KIsZero,
    /// `alpha` was less than or equal to zero (this also covers negative
    /// infinity, because the `<= 0.0` check runs before the finiteness check).
    AlphaTooLow { alpha: f64 },
    /// `alpha` was NaN or positive infinity.
    AlphaNotFinite { alpha: f64 },
}
72
73impl SymmetricDirichlet {
74 #[inline]
80 pub fn new(alpha: f64, k: usize) -> Result<Self, SymmetricDirichletError> {
81 if k == 0 {
82 Err(SymmetricDirichletError::KIsZero)
83 } else if alpha <= 0.0 {
84 Err(SymmetricDirichletError::AlphaTooLow { alpha })
85 } else if !alpha.is_finite() {
86 Err(SymmetricDirichletError::AlphaNotFinite { alpha })
87 } else {
88 Ok(Self {
89 alpha,
90 k,
91 ln_gamma_alpha: OnceLock::new(),
92 })
93 }
94 }
95
96 #[inline]
99 #[must_use]
100 pub fn new_unchecked(alpha: f64, k: usize) -> Self {
101 Self {
102 alpha,
103 k,
104 ln_gamma_alpha: OnceLock::new(),
105 }
106 }
107
108 #[inline]
118 pub fn jeffreys(k: usize) -> Result<Self, SymmetricDirichletError> {
119 if k == 0 {
120 Err(SymmetricDirichletError::KIsZero)
121 } else {
122 Ok(Self {
123 alpha: 0.5,
124 k,
125 ln_gamma_alpha: OnceLock::new(),
126 })
127 }
128 }
129
130 #[inline]
140 pub fn alpha(&self) -> f64 {
141 self.alpha
142 }
143
144 #[inline]
169 pub fn set_alpha(
170 &mut self,
171 alpha: f64,
172 ) -> Result<(), SymmetricDirichletError> {
173 if alpha <= 0.0 {
174 Err(SymmetricDirichletError::AlphaTooLow { alpha })
175 } else if !alpha.is_finite() {
176 Err(SymmetricDirichletError::AlphaNotFinite { alpha })
177 } else {
178 self.set_alpha_unchecked(alpha);
179 self.ln_gamma_alpha = OnceLock::new();
180 Ok(())
181 }
182 }
183
184 #[inline]
186 pub fn set_alpha_unchecked(&mut self, alpha: f64) {
187 self.alpha = alpha;
188 self.ln_gamma_alpha = OnceLock::new();
189 }
190
191 #[inline]
201 pub fn k(&self) -> usize {
202 self.k
203 }
204
205 #[inline]
206 fn ln_gamma_alpha(&self) -> f64 {
207 *self.ln_gamma_alpha.get_or_init(|| ln_gammafn(self.alpha))
208 }
209}
210
211impl From<&SymmetricDirichlet> for String {
212 fn from(symdir: &SymmetricDirichlet) -> String {
213 format!("SymmetricDirichlet({}; α: {})", symdir.k, symdir.alpha)
214 }
215}
216
// Invokes the crate's `impl_display!` macro for `SymmetricDirichlet`.
// NOTE(review): presumably this builds `Display` on top of the
// `From<&SymmetricDirichlet> for String` conversion — confirm against the
// macro definition.
impl_display!(SymmetricDirichlet);
218
219impl Sampleable<Vec<f64>> for SymmetricDirichlet {
220 fn draw<R: Rng>(&self, rng: &mut R) -> Vec<f64> {
221 let g = RGamma::new(self.alpha, 1.0).unwrap();
222 let mut xs: Vec<f64> = (0..self.k).map(|_| rng.sample(g)).collect();
223 let z: f64 = xs.iter().sum();
224 xs.iter_mut().for_each(|x| *x /= z);
225 xs
226 }
227}
228
229impl HasDensity<Vec<f64>> for SymmetricDirichlet {
230 fn ln_f(&self, x: &Vec<f64>) -> f64 {
231 let kf = self.k as f64;
232 let sum_ln_gamma = self.ln_gamma_alpha() * kf;
233 let ln_gamma_sum = ln_gammafn(self.alpha * kf);
234
235 let am1 = self.alpha - 1.0;
236 let term = x.iter().fold(0.0, |acc, &xi| am1.mul_add(xi.ln(), acc));
237
238 term - (sum_ln_gamma - ln_gamma_sum)
239 }
240}
241
/// Reasons a `Dirichlet` can be rejected by its checked constructors.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum DirichletError {
    /// The requested number of categories `k` was zero.
    KIsZero,
    /// The vector of weights passed to `Dirichlet::new` was empty.
    AlphasEmpty,
    /// The weight at index `ix` was less than or equal to zero.
    /// `Dirichlet::symmetric` always reports `ix: 0`.
    AlphaTooLow { ix: usize, alpha: f64 },
    /// The weight at index `ix` was NaN or positive infinity.
    /// `Dirichlet::symmetric` always reports `ix: 0`.
    AlphaNotFinite { ix: usize, alpha: f64 },
}
255
/// Dirichlet distribution with one concentration weight per category.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Dirichlet {
    /// Concentration weights; the number of categories is `alphas.len()`.
    pub(crate) alphas: Vec<f64>,
}
265
/// Plain-data parameter set of a Dirichlet distribution, as emitted and
/// consumed by the `Parameterized` implementation of `Dirichlet`.
///
/// Derives the common traits so parameter sets can be logged, duplicated and
/// compared.
#[derive(Debug, Clone, PartialEq)]
pub struct DirichletParameters {
    /// Concentration weights, one per category.
    pub alphas: Vec<f64>,
}
269
270impl Parameterized for Dirichlet {
271 type Parameters = DirichletParameters;
272
273 fn emit_params(&self) -> Self::Parameters {
274 Self::Parameters {
275 alphas: self.alphas().clone(),
276 }
277 }
278
279 fn from_params(params: Self::Parameters) -> Self {
280 Self::new_unchecked(params.alphas)
281 }
282}
283
284impl From<SymmetricDirichlet> for Dirichlet {
285 fn from(symdir: SymmetricDirichlet) -> Self {
286 Dirichlet::new_unchecked(vec![symdir.alpha; symdir.k])
287 }
288}
289
290impl From<&SymmetricDirichlet> for Dirichlet {
291 fn from(symdir: &SymmetricDirichlet) -> Self {
292 Dirichlet::new_unchecked(vec![symdir.alpha; symdir.k])
293 }
294}
295
296impl Dirichlet {
297 pub fn new(alphas: Vec<f64>) -> Result<Self, DirichletError> {
299 if alphas.is_empty() {
300 return Err(DirichletError::AlphasEmpty);
301 }
302
303 alphas.iter().enumerate().try_for_each(|(ix, &alpha)| {
304 if alpha <= 0.0 {
305 Err(DirichletError::AlphaTooLow { ix, alpha })
306 } else if !alpha.is_finite() {
307 Err(DirichletError::AlphaNotFinite { ix, alpha })
308 } else {
309 Ok(())
310 }
311 })?;
312
313 Ok(Dirichlet { alphas })
314 }
315
316 #[inline]
319 #[must_use]
320 pub fn new_unchecked(alphas: Vec<f64>) -> Self {
321 Dirichlet { alphas }
322 }
323
324 #[inline]
345 pub fn symmetric(alpha: f64, k: usize) -> Result<Self, DirichletError> {
346 if k == 0 {
347 Err(DirichletError::KIsZero)
348 } else if alpha <= 0.0 {
349 Err(DirichletError::AlphaTooLow { ix: 0, alpha })
350 } else if !alpha.is_finite() {
351 Err(DirichletError::AlphaNotFinite { ix: 0, alpha })
352 } else {
353 Ok(Dirichlet {
354 alphas: vec![alpha; k],
355 })
356 }
357 }
358
359 #[inline]
381 pub fn jeffreys(k: usize) -> Result<Self, DirichletError> {
382 if k == 0 {
383 Err(DirichletError::KIsZero)
384 } else {
385 Ok(Dirichlet::new_unchecked(vec![0.5; k]))
386 }
387 }
388
389 #[inline]
391 #[must_use]
392 pub fn k(&self) -> usize {
393 self.alphas.len()
394 }
395
396 #[inline]
398 #[must_use]
399 pub fn alphas(&self) -> &Vec<f64> {
400 &self.alphas
401 }
402}
403
404impl From<&Dirichlet> for String {
405 fn from(dir: &Dirichlet) -> String {
406 format!("Dir(α: {})", vec_to_string(&dir.alphas, 5))
407 }
408}
409
// Invokes the crate's `impl_display!` macro for `Dirichlet`.
// NOTE(review): presumably this builds `Display` on top of the
// `From<&Dirichlet> for String` conversion — confirm against the macro
// definition.
impl_display!(Dirichlet);
411
// Empty impl: no `ContinuousDistr` method is overridden for
// `SymmetricDirichlet`, so only whatever the trait itself provides applies.
impl ContinuousDistr<Vec<f64>> for SymmetricDirichlet {}
413
414impl Support<Vec<f64>> for SymmetricDirichlet {
415 fn supports(&self, x: &Vec<f64>) -> bool {
416 if x.len() == self.k {
417 let sum = x.iter().fold(0.0, |acc, &xi| acc + xi);
418 x.iter().all(|&xi| xi > 0.0) && (1.0 - sum).abs() < 1E-12
419 } else {
420 false
421 }
422 }
423}
424
425impl Sampleable<Vec<f64>> for Dirichlet {
426 fn draw<R: Rng>(&self, rng: &mut R) -> Vec<f64> {
427 let gammas: Vec<RGamma<f64>> = self
428 .alphas
429 .iter()
430 .map(|&alpha| RGamma::new(alpha, 1.0).unwrap())
431 .collect();
432 let mut xs: Vec<f64> = gammas.iter().map(|g| rng.sample(g)).collect();
433 let z: f64 = xs.iter().sum();
434 xs.iter_mut().for_each(|x| *x /= z);
435 xs
436 }
437}
438
439impl HasDensity<Vec<f64>> for Dirichlet {
440 fn ln_f(&self, x: &Vec<f64>) -> f64 {
441 let sum_ln_gamma: f64 = self
443 .alphas
444 .iter()
445 .fold(0.0, |acc, &alpha| acc + ln_gammafn(alpha));
446
447 let ln_gamma_sum: f64 = ln_gammafn(self.alphas.iter().sum::<f64>());
448
449 let term = x
450 .iter()
451 .zip(self.alphas.iter())
452 .fold(0.0, |acc, (&xi, &alpha)| {
453 (alpha - 1.0).mul_add(xi.ln(), acc)
454 });
455
456 term - (sum_ln_gamma - ln_gamma_sum)
457 }
458}
459
// Empty impl: no `ContinuousDistr` method is overridden for `Dirichlet`, so
// only whatever the trait itself provides applies.
impl ContinuousDistr<Vec<f64>> for Dirichlet {}
461
462impl Support<Vec<f64>> for Dirichlet {
463 fn supports(&self, x: &Vec<f64>) -> bool {
464 if x.len() == self.alphas.len() {
465 let sum = x.iter().fold(0.0, |acc, &xi| acc + xi);
466 x.iter().all(|&xi| xi > 0.0) && (1.0 - sum).abs() < 1E-12
467 } else {
468 false
469 }
470 }
471}
472
// Both error enums derive `Debug` and have `Display` impls in this file, so
// the empty bodies rely entirely on the default `Error` methods.
impl std::error::Error for SymmetricDirichletError {}
impl std::error::Error for DirichletError {}
475
476#[cfg_attr(coverage_nightly, coverage(off))]
477impl fmt::Display for SymmetricDirichletError {
478 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
479 match self {
480 Self::AlphaTooLow { alpha } => {
481 write!(f, "alpha ({alpha}) must be greater than zero")
482 }
483 Self::AlphaNotFinite { alpha } => {
484 write!(f, "alpha ({alpha}) was non-finite")
485 }
486 Self::KIsZero => write!(f, "k must be greater than zero"),
487 }
488 }
489}
490
491impl fmt::Display for DirichletError {
492 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
493 match self {
494 Self::KIsZero => write!(f, "k must be greater than zero"),
495 Self::AlphasEmpty => write!(f, "alphas vector was empty"),
496 Self::AlphaTooLow { ix, alpha } => {
497 write!(f, "Invalid alpha at index {ix}: {alpha} <= 0.0")
498 }
499 Self::AlphaNotFinite { ix, alpha } => {
500 write!(f, "Non-finite alpha at index {ix}: {alpha}")
501 }
502 }
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_basic_impls, verify_cache_resets};

    // Absolute tolerance for comparing log-densities to reference values.
    const TOL: f64 = 1E-12;

    // Tests for the general `Dirichlet`.
    mod dir {
        use super::*;

        test_basic_impls!(Vec<f64>, Dirichlet, Dirichlet::jeffreys(4).unwrap());

        #[test]
        fn properly_sized_points_on_simplex_should_be_in_support() {
            let dist = Dirichlet::symmetric(1.0, 4).unwrap();
            let uniform: Vec<f64> = vec![0.25, 0.25, 0.25, 0.25];
            let skewed: Vec<f64> = vec![0.1, 0.2, 0.3, 0.4];
            assert!(dist.supports(&uniform));
            assert!(dist.supports(&skewed));
        }

        #[test]
        fn improperly_sized_points_should_not_be_in_support() {
            let dist = Dirichlet::symmetric(1.0, 3).unwrap();
            let too_long_a: Vec<f64> = vec![0.25, 0.25, 0.25, 0.25];
            let too_long_b: Vec<f64> = vec![0.1, 0.2, 0.7, 0.4];
            assert!(!dist.supports(&too_long_a));
            assert!(!dist.supports(&too_long_b));
        }

        #[test]
        fn properly_sized_points_off_simplex_should_not_be_in_support() {
            let dist = Dirichlet::symmetric(1.0, 4).unwrap();
            let off_a: Vec<f64> = vec![0.25, 0.25, 0.26, 0.25];
            let off_b: Vec<f64> = vec![0.1, 0.3, 0.3, 0.4];
            assert!(!dist.supports(&off_a));
            assert!(!dist.supports(&off_b));
        }

        #[test]
        fn draws_should_be_in_support() {
            let mut rng = rand::rng();
            let dist = Dirichlet::jeffreys(10).unwrap();
            (0..100).for_each(|_| {
                let draw: Vec<f64> = dist.draw(&mut rng);
                assert!(dist.supports(&draw));
            });
        }

        #[test]
        fn sample_should_return_the_proper_number_of_draws() {
            let mut rng = rand::rng();
            let dist = Dirichlet::jeffreys(3).unwrap();
            let draws: Vec<Vec<f64>> = dist.sample(88, &mut rng);
            assert_eq!(draws.len(), 88);
        }

        #[test]
        fn log_pdf_symmetric() {
            let dist = Dirichlet::symmetric(1.0, 3).unwrap();
            let point: Vec<f64> = vec![0.2, 0.3, 0.5];
            let ln_density = dist.ln_pdf(&point);
            assert::close(ln_density, std::f64::consts::LN_2, TOL);
        }

        #[test]
        fn log_pdf_jeffreys() {
            let dist = Dirichlet::jeffreys(3).unwrap();
            let point: Vec<f64> = vec![0.2, 0.3, 0.5];
            let ln_density = dist.ln_pdf(&point);
            assert::close(ln_density, -0.084_598_117_749_354_22, TOL);
        }

        #[test]
        fn log_pdf() {
            let dist = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
            let point: Vec<f64> = vec![0.2, 0.3, 0.5];
            let ln_density = dist.ln_pdf(&point);
            assert::close(ln_density, 1.504_077_396_776_273_7, TOL);
        }
    }

    // Tests for `SymmetricDirichlet`.
    mod symdir {
        use std::f64::consts::PI;

        use super::*;

        test_basic_impls!(
            Vec<f64>,
            SymmetricDirichlet,
            SymmetricDirichlet::jeffreys(4).unwrap()
        );

        #[test]
        fn sample_should_return_the_proper_number_of_draws() {
            let mut rng = rand::rng();
            let dist = SymmetricDirichlet::jeffreys(3).unwrap();
            let draws: Vec<Vec<f64>> = dist.sample(88, &mut rng);
            assert_eq!(draws.len(), 88);
        }

        #[test]
        fn log_pdf_jeffreys() {
            let dist = SymmetricDirichlet::jeffreys(3).unwrap();
            let point: Vec<f64> = vec![0.2, 0.3, 0.5];
            let ln_density = dist.ln_pdf(&point);
            assert::close(ln_density, -0.084_598_117_749_354_22, TOL);
        }

        #[test]
        fn properly_sized_points_off_simplex_should_not_be_in_support() {
            let dist = SymmetricDirichlet::new(1.0, 4).unwrap();
            let off_a: Vec<f64> = vec![0.25, 0.25, 0.26, 0.25];
            let off_b: Vec<f64> = vec![0.1, 0.3, 0.3, 0.4];
            assert!(!dist.supports(&off_a));
            assert!(!dist.supports(&off_b));
        }

        #[test]
        fn draws_should_be_in_support() {
            let mut rng = rand::rng();
            let dist = SymmetricDirichlet::jeffreys(10).unwrap();
            (0..100).for_each(|_| {
                let draw: Vec<f64> = dist.draw(&mut rng);
                assert!(dist.supports(&draw));
            });
        }

        verify_cache_resets!(
            [unchecked],
            ln_f_is_same_after_reset_unchecked_alpha_identically,
            set_alpha_unchecked,
            SymmetricDirichlet::new(1.2, 2).unwrap(),
            vec![0.1_f64, 0.9_f64],
            1.2,
            PI
        );

        verify_cache_resets!(
            [checked],
            ln_f_is_same_after_reset_checked_alpha_identically,
            set_alpha,
            SymmetricDirichlet::new(1.2, 2).unwrap(),
            vec![0.1_f64, 0.9_f64],
            1.2,
            PI
        );
    }

    #[test]
    fn emit_and_from_params_are_identity() {
        let original = SymmetricDirichlet::new(1.5, 7).unwrap();
        let params = original.emit_params();
        let rebuilt = SymmetricDirichlet::from_params(params);
        assert_eq!(original, rebuilt);
    }
}