rv/dist/
dirichlet.rs

1//! Dirichlet and Symmetric Dirichlet distributions over simplexes
2#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use crate::impl_display;
6use crate::misc::ln_gammafn;
7use crate::misc::vec_to_string;
8use crate::traits::{
9    ContinuousDistr, HasDensity, Parameterized, Sampleable, Support,
10};
11use rand::Rng;
12use rand_distr::Gamma as RGamma;
13use std::fmt;
14use std::sync::OnceLock;
15
16mod categorical_prior;
17
/// Symmetric [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution)
/// where all alphas are the same.
///
/// `SymmetricDirichlet { alpha, k }` is mathematically equivalent to
/// `Dirichlet { alphas: vec![alpha; k] }`. This version has some extra
/// optimizations to speed up computing the PDF and drawing random vectors.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct SymmetricDirichlet {
    /// The weight shared by every category
    alpha: f64,
    /// The number of categories, i.e. the length of each drawn vector
    k: usize,
    /// Cached `ln_gamma(alpha)`. Filled lazily, reset whenever `alpha` is
    /// set, skipped by serde, and ignored by `PartialEq`.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_gamma_alpha: OnceLock<f64>,
}
34
/// Plain parameter bundle for [`SymmetricDirichlet`], produced by
/// `emit_params` and consumed by `from_params`.
#[derive(Debug, Clone, PartialEq)]
pub struct SymmetricDirichletParameters {
    /// The weight shared by every category
    pub alpha: f64,
    /// The number of categories
    pub k: usize,
}
39
40impl Parameterized for SymmetricDirichlet {
41    type Parameters = SymmetricDirichletParameters;
42
43    fn emit_params(&self) -> Self::Parameters {
44        Self::Parameters {
45            alpha: self.alpha(),
46            k: self.k(),
47        }
48    }
49
50    fn from_params(params: Self::Parameters) -> Self {
51        Self::new_unchecked(params.alpha, params.k)
52    }
53}
54
55impl PartialEq for SymmetricDirichlet {
56    fn eq(&self, other: &Self) -> bool {
57        self.alpha == other.alpha && self.k == other.k
58    }
59}
60
/// Errors returned by the validating constructors and setters of
/// [`SymmetricDirichlet`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum SymmetricDirichletError {
    /// k parameter is zero
    KIsZero,
    /// alpha parameter is less than or equal to zero (this includes negative
    /// infinity, because the sign check runs before the finiteness check)
    AlphaTooLow { alpha: f64 },
    /// alpha parameter is positive infinity or NaN
    AlphaNotFinite { alpha: f64 },
}
72
73impl SymmetricDirichlet {
74    /// Create a new symmetric Dirichlet distribution
75    ///
76    /// # Arguments
77    /// - alpha: The Dirichlet weight.
78    /// - k : The number of weights. `alpha` will be replicated `k` times.
79    #[inline]
80    pub fn new(alpha: f64, k: usize) -> Result<Self, SymmetricDirichletError> {
81        if k == 0 {
82            Err(SymmetricDirichletError::KIsZero)
83        } else if alpha <= 0.0 {
84            Err(SymmetricDirichletError::AlphaTooLow { alpha })
85        } else if !alpha.is_finite() {
86            Err(SymmetricDirichletError::AlphaNotFinite { alpha })
87        } else {
88            Ok(Self {
89                alpha,
90                k,
91                ln_gamma_alpha: OnceLock::new(),
92            })
93        }
94    }
95
96    /// Create a new `SymmetricDirichlet` without checking whether the parameters
97    /// are valid.
98    #[inline]
99    #[must_use]
100    pub fn new_unchecked(alpha: f64, k: usize) -> Self {
101        Self {
102            alpha,
103            k,
104            ln_gamma_alpha: OnceLock::new(),
105        }
106    }
107
108    /// The Jeffrey's Dirichlet prior for Categorical distributions
109    ///
110    /// # Example
111    ///
112    /// ```rust
113    /// # use rv::dist::SymmetricDirichlet;
114    /// let symdir = SymmetricDirichlet::jeffreys(4).unwrap();
115    /// assert_eq!(symdir, SymmetricDirichlet::new(0.5, 4).unwrap());
116    /// ```
117    #[inline]
118    pub fn jeffreys(k: usize) -> Result<Self, SymmetricDirichletError> {
119        if k == 0 {
120            Err(SymmetricDirichletError::KIsZero)
121        } else {
122            Ok(Self {
123                alpha: 0.5,
124                k,
125                ln_gamma_alpha: OnceLock::new(),
126            })
127        }
128    }
129
130    /// Get the alpha uniform weight parameter
131    ///
132    /// # Example
133    ///
134    /// ```rust
135    /// # use rv::dist::SymmetricDirichlet;
136    /// let symdir = SymmetricDirichlet::new(1.2, 5).unwrap();
137    /// assert_eq!(symdir.alpha(), 1.2);
138    /// ```
139    #[inline]
140    pub fn alpha(&self) -> f64 {
141        self.alpha
142    }
143
144    /// Set the value of alpha
145    ///
146    /// # Example
147    /// ```rust
148    /// # use rv::dist::SymmetricDirichlet;
149    /// let mut symdir = SymmetricDirichlet::new(1.1, 5).unwrap();
150    /// assert_eq!(symdir.alpha(), 1.1);
151    ///
152    /// symdir.set_alpha(2.3).unwrap();
153    /// assert_eq!(symdir.alpha(), 2.3);
154    /// ```
155    ///
156    /// Will error for invalid parameters
157    ///
158    /// ```rust
159    /// # use rv::dist::SymmetricDirichlet;
160    /// # let mut symdir = SymmetricDirichlet::new(1.1, 5).unwrap();
161    /// assert!(symdir.set_alpha(0.5).is_ok());
162    /// assert!(symdir.set_alpha(0.0).is_err());
163    /// assert!(symdir.set_alpha(-1.0).is_err());
164    /// assert!(symdir.set_alpha(f64::INFINITY).is_err());
165    /// assert!(symdir.set_alpha(f64::NEG_INFINITY).is_err());
166    /// assert!(symdir.set_alpha(f64::NAN).is_err());
167    /// ```
168    #[inline]
169    pub fn set_alpha(
170        &mut self,
171        alpha: f64,
172    ) -> Result<(), SymmetricDirichletError> {
173        if alpha <= 0.0 {
174            Err(SymmetricDirichletError::AlphaTooLow { alpha })
175        } else if !alpha.is_finite() {
176            Err(SymmetricDirichletError::AlphaNotFinite { alpha })
177        } else {
178            self.set_alpha_unchecked(alpha);
179            self.ln_gamma_alpha = OnceLock::new();
180            Ok(())
181        }
182    }
183
184    /// Set the value of alpha without input validation
185    #[inline]
186    pub fn set_alpha_unchecked(&mut self, alpha: f64) {
187        self.alpha = alpha;
188        self.ln_gamma_alpha = OnceLock::new();
189    }
190
191    /// Get the number of weights, k
192    ///
193    /// # Example
194    ///
195    /// ```rust
196    /// # use rv::dist::SymmetricDirichlet;
197    /// let symdir = SymmetricDirichlet::new(1.2, 5).unwrap();
198    /// assert_eq!(symdir.k(), 5);
199    /// ```
200    #[inline]
201    pub fn k(&self) -> usize {
202        self.k
203    }
204
205    #[inline]
206    fn ln_gamma_alpha(&self) -> f64 {
207        *self.ln_gamma_alpha.get_or_init(|| ln_gammafn(self.alpha))
208    }
209}
210
211impl From<&SymmetricDirichlet> for String {
212    fn from(symdir: &SymmetricDirichlet) -> String {
213        format!("SymmetricDirichlet({}; α: {})", symdir.k, symdir.alpha)
214    }
215}
216
// `Display` for `SymmetricDirichlet`, generated by the crate's `impl_display!`
// macro. NOTE(review): presumably formats through the
// `From<&SymmetricDirichlet> for String` impl; confirm in the macro source.
impl_display!(SymmetricDirichlet);
218
219impl Sampleable<Vec<f64>> for SymmetricDirichlet {
220    fn draw<R: Rng>(&self, rng: &mut R) -> Vec<f64> {
221        let g = RGamma::new(self.alpha, 1.0).unwrap();
222        let mut xs: Vec<f64> = (0..self.k).map(|_| rng.sample(g)).collect();
223        let z: f64 = xs.iter().sum();
224        xs.iter_mut().for_each(|x| *x /= z);
225        xs
226    }
227}
228
229impl HasDensity<Vec<f64>> for SymmetricDirichlet {
230    fn ln_f(&self, x: &Vec<f64>) -> f64 {
231        let kf = self.k as f64;
232        let sum_ln_gamma = self.ln_gamma_alpha() * kf;
233        let ln_gamma_sum = ln_gammafn(self.alpha * kf);
234
235        let am1 = self.alpha - 1.0;
236        let term = x.iter().fold(0.0, |acc, &xi| am1.mul_add(xi.ln(), acc));
237
238        term - (sum_ln_gamma - ln_gamma_sum)
239    }
240}
241
/// Errors returned by the validating constructors of [`Dirichlet`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum DirichletError {
    /// k parameter is zero
    KIsZero,
    /// alpha vector is empty
    AlphasEmpty,
    /// alphas parameter has an entry less than or equal to zero (including
    /// negative infinity). `ix` is the index of the first invalid entry found.
    AlphaTooLow { ix: usize, alpha: f64 },
    /// alphas parameter has an entry that is positive infinity or NaN. `ix`
    /// is the index of the first invalid entry found.
    AlphaNotFinite { ix: usize, alpha: f64 },
}
255
/// [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution)
/// over points on the probability simplex.
///
/// With `k = alphas.len()` weights, the support is the set of length-`k`
/// vectors whose entries are strictly positive and sum to one. Geometrically
/// this is the (k - 1)-simplex.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Dirichlet {
    /// A `Vec` of real numbers in (0, ∞), one weight per category
    pub(crate) alphas: Vec<f64>,
}
265
/// Plain parameter bundle for [`Dirichlet`], produced by `emit_params` and
/// consumed by `from_params`.
#[derive(Debug, Clone, PartialEq)]
pub struct DirichletParameters {
    /// One weight per category
    pub alphas: Vec<f64>,
}
269
270impl Parameterized for Dirichlet {
271    type Parameters = DirichletParameters;
272
273    fn emit_params(&self) -> Self::Parameters {
274        Self::Parameters {
275            alphas: self.alphas().clone(),
276        }
277    }
278
279    fn from_params(params: Self::Parameters) -> Self {
280        Self::new_unchecked(params.alphas)
281    }
282}
283
284impl From<SymmetricDirichlet> for Dirichlet {
285    fn from(symdir: SymmetricDirichlet) -> Self {
286        Dirichlet::new_unchecked(vec![symdir.alpha; symdir.k])
287    }
288}
289
290impl From<&SymmetricDirichlet> for Dirichlet {
291    fn from(symdir: &SymmetricDirichlet) -> Self {
292        Dirichlet::new_unchecked(vec![symdir.alpha; symdir.k])
293    }
294}
295
296impl Dirichlet {
297    /// Creates a `Dirichlet` with a given `alphas` vector
298    pub fn new(alphas: Vec<f64>) -> Result<Self, DirichletError> {
299        if alphas.is_empty() {
300            return Err(DirichletError::AlphasEmpty);
301        }
302
303        alphas.iter().enumerate().try_for_each(|(ix, &alpha)| {
304            if alpha <= 0.0 {
305                Err(DirichletError::AlphaTooLow { ix, alpha })
306            } else if !alpha.is_finite() {
307                Err(DirichletError::AlphaNotFinite { ix, alpha })
308            } else {
309                Ok(())
310            }
311        })?;
312
313        Ok(Dirichlet { alphas })
314    }
315
316    /// Creates a new Dirichlet without checking whether the parameters are
317    /// valid.
318    #[inline]
319    #[must_use]
320    pub fn new_unchecked(alphas: Vec<f64>) -> Self {
321        Dirichlet { alphas }
322    }
323
324    /// Creates a `Dirichlet` where all alphas are identical.
325    ///
326    /// # Notes
327    ///
328    /// `SymmetricDirichlet` if faster and more compact, and is the preferred
329    /// way to represent a Dirichlet symmetric weights.
330    ///
331    /// # Examples
332    ///
333    /// ```
334    /// # use rv::dist::{Dirichlet, SymmetricDirichlet};
335    /// # use rv::traits::*;
336    /// let dir = Dirichlet::symmetric(1.0, 4).unwrap();
337    /// assert_eq!(*dir.alphas(), vec![1.0, 1.0, 1.0, 1.0]);
338    ///
339    /// // Equivalent to SymmetricDirichlet
340    /// let symdir = SymmetricDirichlet::new(1.0, 4).unwrap();
341    /// let x: Vec<f64> = vec![0.1, 0.4, 0.3, 0.2];
342    /// assert::close(dir.ln_f(&x), symdir.ln_f(&x), 1E-12);
343    /// ```
344    #[inline]
345    pub fn symmetric(alpha: f64, k: usize) -> Result<Self, DirichletError> {
346        if k == 0 {
347            Err(DirichletError::KIsZero)
348        } else if alpha <= 0.0 {
349            Err(DirichletError::AlphaTooLow { ix: 0, alpha })
350        } else if !alpha.is_finite() {
351            Err(DirichletError::AlphaNotFinite { ix: 0, alpha })
352        } else {
353            Ok(Dirichlet {
354                alphas: vec![alpha; k],
355            })
356        }
357    }
358
359    /// Creates a `Dirichlet` with all alphas = 0.5 (Jeffreys prior)
360    ///
361    /// # Notes
362    ///
363    /// `SymmetricDirichlet` if faster and more compact, and is the preferred
364    /// way to represent a Dirichlet symmetric weights.
365    ///
366    /// # Examples
367    ///
368    /// ```
369    /// # use rv::dist::Dirichlet;
370    /// # use rv::dist::SymmetricDirichlet;
371    /// # use rv::traits::*;
372    /// let dir = Dirichlet::jeffreys(3).unwrap();
373    /// assert_eq!(*dir.alphas(), vec![0.5, 0.5, 0.5]);
374    ///
375    /// // Equivalent to SymmetricDirichlet::jeffreys
376    /// let symdir = SymmetricDirichlet::jeffreys(3).unwrap();
377    /// let x: Vec<f64> = vec![0.1, 0.4, 0.5];
378    /// assert::close(dir.ln_f(&x), symdir.ln_f(&x), 1E-12);
379    /// ```
380    #[inline]
381    pub fn jeffreys(k: usize) -> Result<Self, DirichletError> {
382        if k == 0 {
383            Err(DirichletError::KIsZero)
384        } else {
385            Ok(Dirichlet::new_unchecked(vec![0.5; k]))
386        }
387    }
388
389    /// The length of `alphas` / the number of categories
390    #[inline]
391    #[must_use]
392    pub fn k(&self) -> usize {
393        self.alphas.len()
394    }
395
396    /// Get a reference to the weights vector, `alphas`
397    #[inline]
398    #[must_use]
399    pub fn alphas(&self) -> &Vec<f64> {
400        &self.alphas
401    }
402}
403
404impl From<&Dirichlet> for String {
405    fn from(dir: &Dirichlet) -> String {
406        format!("Dir(α: {})", vec_to_string(&dir.alphas, 5))
407    }
408}
409
// `Display` for `Dirichlet`, generated by the crate's `impl_display!` macro.
// NOTE(review): presumably formats through the `From<&Dirichlet> for String`
// impl; confirm in the macro source.
impl_display!(Dirichlet);
411
412impl ContinuousDistr<Vec<f64>> for SymmetricDirichlet {}
413
414impl Support<Vec<f64>> for SymmetricDirichlet {
415    fn supports(&self, x: &Vec<f64>) -> bool {
416        if x.len() == self.k {
417            let sum = x.iter().fold(0.0, |acc, &xi| acc + xi);
418            x.iter().all(|&xi| xi > 0.0) && (1.0 - sum).abs() < 1E-12
419        } else {
420            false
421        }
422    }
423}
424
425impl Sampleable<Vec<f64>> for Dirichlet {
426    fn draw<R: Rng>(&self, rng: &mut R) -> Vec<f64> {
427        let gammas: Vec<RGamma<f64>> = self
428            .alphas
429            .iter()
430            .map(|&alpha| RGamma::new(alpha, 1.0).unwrap())
431            .collect();
432        let mut xs: Vec<f64> = gammas.iter().map(|g| rng.sample(g)).collect();
433        let z: f64 = xs.iter().sum();
434        xs.iter_mut().for_each(|x| *x /= z);
435        xs
436    }
437}
438
439impl HasDensity<Vec<f64>> for Dirichlet {
440    fn ln_f(&self, x: &Vec<f64>) -> f64 {
441        // XXX: could cache all ln_gamma(alpha)
442        let sum_ln_gamma: f64 = self
443            .alphas
444            .iter()
445            .fold(0.0, |acc, &alpha| acc + ln_gammafn(alpha));
446
447        let ln_gamma_sum: f64 = ln_gammafn(self.alphas.iter().sum::<f64>());
448
449        let term = x
450            .iter()
451            .zip(self.alphas.iter())
452            .fold(0.0, |acc, (&xi, &alpha)| {
453                (alpha - 1.0).mul_add(xi.ln(), acc)
454            });
455
456        term - (sum_ln_gamma - ln_gamma_sum)
457    }
458}
459
460impl ContinuousDistr<Vec<f64>> for Dirichlet {}
461
462impl Support<Vec<f64>> for Dirichlet {
463    fn supports(&self, x: &Vec<f64>) -> bool {
464        if x.len() == self.alphas.len() {
465            let sum = x.iter().fold(0.0, |acc, &xi| acc + xi);
466            x.iter().all(|&xi| xi > 0.0) && (1.0 - sum).abs() < 1E-12
467        } else {
468            false
469        }
470    }
471}
472
473impl std::error::Error for SymmetricDirichletError {}
474impl std::error::Error for DirichletError {}
475
476#[cfg_attr(coverage_nightly, coverage(off))]
477impl fmt::Display for SymmetricDirichletError {
478    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
479        match self {
480            Self::AlphaTooLow { alpha } => {
481                write!(f, "alpha ({alpha}) must be greater than zero")
482            }
483            Self::AlphaNotFinite { alpha } => {
484                write!(f, "alpha ({alpha}) was non-finite")
485            }
486            Self::KIsZero => write!(f, "k must be greater than zero"),
487        }
488    }
489}
490
491impl fmt::Display for DirichletError {
492    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
493        match self {
494            Self::KIsZero => write!(f, "k must be greater than zero"),
495            Self::AlphasEmpty => write!(f, "alphas vector was empty"),
496            Self::AlphaTooLow { ix, alpha } => {
497                write!(f, "Invalid alpha at index {ix}: {alpha} <= 0.0")
498            }
499            Self::AlphaNotFinite { ix, alpha } => {
500                write!(f, "Non-finite alpha at index {ix}: {alpha}")
501            }
502        }
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_basic_impls, verify_cache_resets};

    const TOL: f64 = 1E-12;

    mod dir {
        use super::*;

        test_basic_impls!(Vec<f64>, Dirichlet, Dirichlet::jeffreys(4).unwrap());

        #[test]
        fn properly_sized_points_on_simplex_should_be_in_support() {
            let dir = Dirichlet::symmetric(1.0, 4).unwrap();
            assert!(dir.supports(&vec![0.25, 0.25, 0.25, 0.25]));
            assert!(dir.supports(&vec![0.1, 0.2, 0.3, 0.4]));
        }

        #[test]
        fn improperly_sized_points_should_not_be_in_support() {
            let dir = Dirichlet::symmetric(1.0, 3).unwrap();
            assert!(!dir.supports(&vec![0.25, 0.25, 0.25, 0.25]));
            assert!(!dir.supports(&vec![0.1, 0.2, 0.7, 0.4]));
        }

        #[test]
        fn properly_sized_points_off_simplex_should_not_be_in_support() {
            let dir = Dirichlet::symmetric(1.0, 4).unwrap();
            assert!(!dir.supports(&vec![0.25, 0.25, 0.26, 0.25]));
            assert!(!dir.supports(&vec![0.1, 0.3, 0.3, 0.4]));
        }

        #[test]
        fn draws_should_be_in_support() {
            let mut rng = rand::rng();
            // Small alphas gives us more variability in the simplex, and more
            // variability gives us a better test.
            let dir = Dirichlet::jeffreys(10).unwrap();
            for _ in 0..100 {
                let x = dir.draw(&mut rng);
                assert!(dir.supports(&x));
            }
        }

        #[test]
        fn sample_should_return_the_proper_number_of_draws() {
            let mut rng = rand::rng();
            let dir = Dirichlet::jeffreys(3).unwrap();
            let xs: Vec<Vec<f64>> = dir.sample(88, &mut rng);
            assert_eq!(xs.len(), 88);
        }

        #[test]
        fn log_pdf_symmetric() {
            let dir = Dirichlet::symmetric(1.0, 3).unwrap();
            assert::close(
                dir.ln_pdf(&vec![0.2, 0.3, 0.5]),
                std::f64::consts::LN_2,
                TOL,
            );
        }

        #[test]
        fn log_pdf_jeffreys() {
            let dir = Dirichlet::jeffreys(3).unwrap();
            assert::close(
                dir.ln_pdf(&vec![0.2, 0.3, 0.5]),
                -0.084_598_117_749_354_22,
                TOL,
            );
        }

        #[test]
        fn log_pdf() {
            let dir = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
            assert::close(
                dir.ln_pdf(&vec![0.2, 0.3, 0.5]),
                1.504_077_396_776_273_7,
                TOL,
            );
        }

        #[test]
        fn new_reports_first_invalid_alpha_with_index() {
            assert_eq!(Dirichlet::new(vec![]), Err(DirichletError::AlphasEmpty));
            assert_eq!(
                Dirichlet::new(vec![1.0, 0.0, 2.0]),
                Err(DirichletError::AlphaTooLow { ix: 1, alpha: 0.0 })
            );
            assert_eq!(
                Dirichlet::new(vec![1.0, 2.0, f64::INFINITY]),
                Err(DirichletError::AlphaNotFinite {
                    ix: 2,
                    alpha: f64::INFINITY
                })
            );
            assert!(matches!(
                Dirichlet::new(vec![f64::NAN]),
                Err(DirichletError::AlphaNotFinite { ix: 0, .. })
            ));
        }

        #[test]
        fn zero_k_is_rejected_by_symmetric_and_jeffreys() {
            assert_eq!(Dirichlet::symmetric(1.0, 0), Err(DirichletError::KIsZero));
            assert_eq!(Dirichlet::jeffreys(0), Err(DirichletError::KIsZero));
        }

        #[test]
        fn conversion_from_symmetric_dirichlet_replicates_alpha() {
            let symdir = SymmetricDirichlet::new(0.5, 3).unwrap();
            assert_eq!(Dirichlet::from(&symdir), Dirichlet::jeffreys(3).unwrap());
            assert_eq!(Dirichlet::from(symdir), Dirichlet::jeffreys(3).unwrap());
        }

        #[test]
        fn dirichlet_params_roundtrip_is_identity() {
            let dist_a = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
            let dist_b = Dirichlet::from_params(dist_a.emit_params());
            assert_eq!(dist_a, dist_b);
        }
    }

    mod symdir {
        use std::f64::consts::PI;

        use super::*;

        test_basic_impls!(
            Vec<f64>,
            SymmetricDirichlet,
            SymmetricDirichlet::jeffreys(4).unwrap()
        );

        #[test]
        fn sample_should_return_the_proper_number_of_draws() {
            let mut rng = rand::rng();
            let symdir = SymmetricDirichlet::jeffreys(3).unwrap();
            let xs: Vec<Vec<f64>> = symdir.sample(88, &mut rng);
            assert_eq!(xs.len(), 88);
        }

        #[test]
        fn log_pdf_jeffreys() {
            let symdir = SymmetricDirichlet::jeffreys(3).unwrap();
            assert::close(
                symdir.ln_pdf(&vec![0.2, 0.3, 0.5]),
                -0.084_598_117_749_354_22,
                TOL,
            );
        }

        #[test]
        fn properly_sized_points_off_simplex_should_not_be_in_support() {
            let symdir = SymmetricDirichlet::new(1.0, 4).unwrap();
            assert!(!symdir.supports(&vec![0.25, 0.25, 0.26, 0.25]));
            assert!(!symdir.supports(&vec![0.1, 0.3, 0.3, 0.4]));
        }

        #[test]
        fn draws_should_be_in_support() {
            let mut rng = rand::rng();
            // Small alphas gives us more variability in the simplex, and more
            // variability gives us a better test.
            let symdir = SymmetricDirichlet::jeffreys(10).unwrap();
            for _ in 0..100 {
                let x: Vec<f64> = symdir.draw(&mut rng);
                assert!(symdir.supports(&x));
            }
        }

        #[test]
        fn new_rejects_invalid_parameters() {
            assert_eq!(
                SymmetricDirichlet::new(1.0, 0),
                Err(SymmetricDirichletError::KIsZero)
            );
            assert_eq!(
                SymmetricDirichlet::new(0.0, 3),
                Err(SymmetricDirichletError::AlphaTooLow { alpha: 0.0 })
            );
            assert_eq!(
                SymmetricDirichlet::new(f64::NEG_INFINITY, 3),
                Err(SymmetricDirichletError::AlphaTooLow {
                    alpha: f64::NEG_INFINITY
                })
            );
            assert_eq!(
                SymmetricDirichlet::new(f64::INFINITY, 3),
                Err(SymmetricDirichletError::AlphaNotFinite {
                    alpha: f64::INFINITY
                })
            );
            assert!(matches!(
                SymmetricDirichlet::new(f64::NAN, 3),
                Err(SymmetricDirichletError::AlphaNotFinite { .. })
            ));
            assert_eq!(
                SymmetricDirichlet::jeffreys(0),
                Err(SymmetricDirichletError::KIsZero)
            );
        }

        #[test]
        fn failed_set_alpha_leaves_alpha_unchanged() {
            let mut symdir = SymmetricDirichlet::new(1.5, 3).unwrap();
            assert!(symdir.set_alpha(-1.0).is_err());
            assert!(symdir.set_alpha(f64::NAN).is_err());
            assert_eq!(symdir.alpha(), 1.5);
        }

        #[test]
        fn ln_f_matches_equivalent_dirichlet() {
            let symdir = SymmetricDirichlet::new(2.5, 4).unwrap();
            let dir = Dirichlet::from(&symdir);
            let x = vec![0.1, 0.2, 0.3, 0.4];
            assert::close(symdir.ln_f(&x), dir.ln_f(&x), TOL);
        }

        verify_cache_resets!(
            [unchecked],
            ln_f_is_same_after_reset_unchecked_alpha_identically,
            set_alpha_unchecked,
            SymmetricDirichlet::new(1.2, 2).unwrap(),
            vec![0.1_f64, 0.9_f64],
            1.2,
            PI
        );

        verify_cache_resets!(
            [checked],
            ln_f_is_same_after_reset_checked_alpha_identically,
            set_alpha,
            SymmetricDirichlet::new(1.2, 2).unwrap(),
            vec![0.1_f64, 0.9_f64],
            1.2,
            PI
        );
    }

    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = SymmetricDirichlet::new(1.5, 7).unwrap();
        let dist_b = SymmetricDirichlet::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }
}