Skip to main content

scirs2_stats/distributions/multivariate/
dirichlet.rs

1//! Dirichlet distribution functions
2//!
3//! This module provides functionality for the Dirichlet distribution.
4
5use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1};
8use scirs2_core::random::prelude::*;
9use scirs2_core::random::{Distribution, Gamma as RandGamma};
10use std::fmt::Debug;
11
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// This is a workaround for the unstable gamma function in Rust.
///
/// Evaluation strategy:
///
/// * Integer arguments up to 20 are computed exactly as `ln((x - 1)!)` by
///   summing logarithms, so small integer parameters carry no approximation
///   error.
/// * Arguments below 0.5 are mapped into `(0.5, 1)` with a single application
///   of the reflection formula `Γ(x) Γ(1 - x) = π / sin(πx)`.
/// * Everything else uses the Lanczos approximation (g = 7, 9 coefficients),
///   which needs no recursion and therefore works for arbitrarily large `x`.
///
/// NaN input yields NaN and `+∞` yields `+∞`.
///
/// # Panics
///
/// Panics if `x <= 0.0`.
#[allow(dead_code)]
fn lgamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    // Lanczos parameters (g = 7, n = 9). The first entry is the constant term
    // of the series; the remaining ones are divided by (z + k), k = 1..=8.
    const LANCZOS_G: f64 = 7.0;
    const LANCZOS_COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    if x <= 0.0 {
        panic!("lgamma requires positive input");
    }
    // NaN is not caught by the comparison above; propagate it explicitly.
    if x.is_nan() {
        return f64::NAN;
    }
    // Γ grows without bound; avoid the `∞ - ∞` the closed form would produce.
    if x.is_infinite() {
        return f64::INFINITY;
    }

    // Small integers: ln Γ(n) = ln((n - 1)!) = Σ_{k=2}^{n-1} ln k, summed exactly.
    if x.fract() == 0.0 && x <= 20.0 {
        let n = x as u32;
        return (2..n).fold(0.0, |acc, k| acc + f64::from(k).ln());
    }

    // Reflection for small arguments. `1 - x` lies in (0.5, 1], so the nested
    // call goes straight to the Lanczos (or integer) branch: depth is at most 2.
    // Written as a difference of logs so a tiny `x` cannot overflow the quotient.
    if x < 0.5 {
        return PI.ln() - (PI * x).sin().ln() - lgamma(1.0 - x);
    }

    // Lanczos: Γ(z + 1) = sqrt(2π) · t^(z + 0.5) · e^(-t) · A(z), t = z + g + 0.5.
    let z = x - 1.0;
    let t = z + LANCZOS_G + 0.5;
    let series = LANCZOS_COEFFS[1..]
        .iter()
        .zip(1u32..)
        .fold(LANCZOS_COEFFS[0], |acc, (&coef, k)| {
            acc + coef / (z + f64::from(k))
        });

    0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
75
76/// Dirichlet distribution structure
77#[derive(Debug, Clone)]
78pub struct Dirichlet {
79    /// Concentration parameters (alpha values)
80    pub alpha: Array1<f64>,
81    /// Dimension of the distribution (number of categories)
82    pub dim: usize,
83    /// Natural log of the normalization constant (cached for efficiency)
84    log_norm_const: f64,
85}
86
87impl Dirichlet {
88    /// Create a new Dirichlet distribution with given concentration parameters
89    ///
90    /// # Arguments
91    ///
92    /// * `alpha` - Concentration parameters (all values must be positive)
93    ///
94    /// # Returns
95    ///
96    /// * A new Dirichlet distribution instance
97    ///
98    /// # Examples
99    ///
100    /// ```
101    /// use scirs2_core::ndarray::array;
102    /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
103    ///
104    /// // Create a 3D Dirichlet distribution with symmetric parameters (equivalent to a uniform distribution over the simplex)
105    /// let alpha = array![1.0, 1.0, 1.0];
106    /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
107    /// ```
108    pub fn new<D>(alpha: ArrayBase<D, Ix1>) -> StatsResult<Self>
109    where
110        D: Data<Elem = f64>,
111    {
112        let alpha_owned = alpha.to_owned();
113        let dim = alpha_owned.len();
114
115        // Check that all _alpha values are positive
116        for &a in alpha_owned.iter() {
117            if a <= 0.0 {
118                return Err(StatsError::DomainError(
119                    "All concentration parameters must be positive".to_string(),
120                ));
121            }
122        }
123
124        let alpha_sum = alpha_owned.sum();
125
126        // Compute the log normalization constant:
127        // ln[B(α)] = sum(ln[Γ(αᵢ)]) - ln[Γ(sum(αᵢ))]
128        let mut log_norm_const = 0.0;
129
130        // Sum of log(Gamma(alpha_i))
131        for &a in alpha_owned.iter() {
132            log_norm_const += lgamma(a);
133        }
134
135        // Subtract log(Gamma(sum(alpha_i)))
136        log_norm_const -= lgamma(alpha_sum);
137
138        Ok(Dirichlet {
139            alpha: alpha_owned,
140            dim,
141            log_norm_const,
142        })
143    }
144
145    /// Calculate the probability density function (PDF) at a given point
146    ///
147    /// # Arguments
148    ///
149    /// * `x` - The point at which to evaluate the PDF (must sum to 1)
150    ///
151    /// # Returns
152    ///
153    /// * The value of the PDF at the given point
154    ///
155    /// # Examples
156    ///
157    /// ```
158    /// use scirs2_core::ndarray::array;
159    /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
160    ///
161    /// let alpha = array![1.0, 1.0, 1.0];
162    /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
163    ///
164    /// // PDF for a uniform Dirichlet at any point on the simplex is 2 (in 3D)
165    /// let point = array![0.3, 0.3, 0.4];
166    /// let pdf_value = dirichlet.pdf(&point);
167    /// assert!((pdf_value - 2.0).abs() < 1e-10);
168    /// ```
169    pub fn pdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
170    where
171        D: Data<Elem = f64>,
172    {
173        if x.len() != self.dim {
174            return 0.0; // Return zero for invalid dimensions
175        }
176
177        // Check if x is on the simplex (all values > 0 and sum to 1)
178        let sum: f64 = x.iter().sum();
179        if (sum - 1.0).abs() > 1e-10 {
180            return 0.0; // Point not on the simplex
181        }
182
183        for &val in x.iter() {
184            if val <= 0.0 || val >= 1.0 {
185                return 0.0; // Values must be in (0, 1)
186            }
187        }
188
189        // Calculate the PDF using the formula:
190        // p(x|α) = [∏ xᵢ^(αᵢ-1)] / B(α)
191        // where B(α) is the multivariate beta function
192
193        // We'll work in log space for numerical stability
194        let log_pdf = self.logpdf(x);
195        log_pdf.exp()
196    }
197
198    /// Calculate the log probability density function (log PDF) at a given point
199    ///
200    /// # Arguments
201    ///
202    /// * `x` - The point at which to evaluate the log PDF (must sum to 1)
203    ///
204    /// # Returns
205    ///
206    /// * The value of the log PDF at the given point
207    ///
208    /// # Examples
209    ///
210    /// ```
211    /// use scirs2_core::ndarray::array;
212    /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
213    ///
214    /// let alpha = array![1.0, 1.0, 1.0];
215    /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
216    ///
217    /// let point = array![0.3, 0.3, 0.4];
218    /// let logpdf_value = dirichlet.logpdf(&point);
219    /// assert!((logpdf_value - 0.693).abs() < 1e-3);  // ln(2) ≈ 0.693
220    /// ```
221    pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
222    where
223        D: Data<Elem = f64>,
224    {
225        if x.len() != self.dim {
226            return f64::NEG_INFINITY; // Return -∞ for invalid dimensions
227        }
228
229        // Check if x is on the simplex (all values > 0 and sum to 1)
230        let sum: f64 = x.iter().sum();
231        if (sum - 1.0).abs() > 1e-10 {
232            return f64::NEG_INFINITY; // Point not on the simplex
233        }
234
235        for &val in x.iter() {
236            if val <= 0.0 || val >= 1.0 {
237                return f64::NEG_INFINITY; // Values must be in (0, 1)
238            }
239        }
240
241        // Calculate the log PDF using the formula:
242        // log p(x|α) = sum[(αᵢ-1)log(xᵢ)] - log B(α)
243        let mut log_pdf = -self.log_norm_const;
244
245        for i in 0..self.dim {
246            log_pdf += (self.alpha[i] - 1.0) * x[i].ln();
247        }
248
249        log_pdf
250    }
251
252    /// Generate random samples from the distribution
253    ///
254    /// # Arguments
255    ///
256    /// * `size` - Number of samples to generate
257    ///
258    /// # Returns
259    ///
260    /// * Matrix where each row is a random sample
261    ///
262    /// # Examples
263    ///
264    /// ```
265    /// use scirs2_core::ndarray::array;
266    /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
267    ///
268    /// let alpha = array![1.0, 2.0, 3.0];
269    /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
270    ///
271    /// let samples = dirichlet.rvs(10).expect("Operation failed");
272    /// assert_eq!(samples.len(), 10);
273    /// assert_eq!(samples[0].len(), 3);
274    /// ```
275    pub fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
276        let mut rng = thread_rng();
277        let mut samples = Vec::with_capacity(size);
278
279        // Generate samples using the gamma method:
280        // 1. Generate independent gamma samples with shape αᵢ and scale=1
281        // 2. Normalize by their sum
282
283        for _ in 0..size {
284            let mut sample = Array1::<f64>::zeros(self.dim);
285            let mut sum = 0.0;
286
287            // Generate gamma samples
288            for i in 0..self.dim {
289                let gamma_dist = RandGamma::new(self.alpha[i], 1.0).map_err(|_| {
290                    StatsError::ComputationError("Failed to create gamma distribution".to_string())
291                })?;
292
293                let gamma_sample = gamma_dist.sample(&mut rng);
294                sample[i] = gamma_sample;
295                sum += gamma_sample;
296            }
297
298            // Normalize to get a point on the simplex
299            sample.mapv_inplace(|x| x / sum);
300            samples.push(sample);
301        }
302
303        Ok(samples)
304    }
305
306    /// Generate a single random sample from the distribution
307    ///
308    /// # Returns
309    ///
310    /// * Vector representing a single sample
311    ///
312    /// # Examples
313    ///
314    /// ```
315    /// use scirs2_core::ndarray::array;
316    /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
317    ///
318    /// let alpha = array![1.0, 2.0, 3.0];
319    /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
320    ///
321    /// let sample = dirichlet.rvs_single().expect("Operation failed");
322    /// assert_eq!(sample.len(), 3);
323    /// ```
324    pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
325        let samples = self.rvs(1)?;
326        Ok(samples[0].clone())
327    }
328}
329
330/// Create a Dirichlet distribution with the given parameters.
331///
332/// This is a convenience function to create a Dirichlet distribution with
333/// the given concentration parameters.
334///
335/// # Arguments
336///
337/// * `alpha` - Concentration parameters (all values must be positive)
338///
339/// # Returns
340///
341/// * A Dirichlet distribution object
342///
343/// # Examples
344///
345/// ```
346/// use scirs2_core::ndarray::array;
347/// use scirs2_stats::distributions::multivariate;
348///
349/// let alpha = array![1.0, 1.0, 1.0];
350/// let dirichlet = multivariate::dirichlet(&alpha).expect("Operation failed");
351/// let point = array![0.3, 0.3, 0.4];
352/// let pdf_at_point = dirichlet.pdf(&point);
353/// ```
354#[allow(dead_code)]
355pub fn dirichlet<D>(alpha: &ArrayBase<D, Ix1>) -> StatsResult<Dirichlet>
356where
357    D: Data<Elem = f64>,
358{
359    Dirichlet::new(alpha.to_owned())
360}
361
362/// Implementation of SampleableDistribution for Dirichlet
363impl SampleableDistribution<Array1<f64>> for Dirichlet {
364    fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
365        self.rvs(size)
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Construction keeps the supplied parameters and records their count.
    #[test]
    fn test_dirichlet_creation() {
        // One uniform and one non-uniform parameter vector.
        let cases = [array![1.0, 1.0, 1.0], array![2.0, 3.0, 4.0]];

        for alpha in cases.iter() {
            let dist = Dirichlet::new(alpha.clone()).expect("Operation failed");

            assert_eq!(dist.dim, 3);
            assert_eq!(dist.alpha, *alpha);
        }
    }

    /// Zero and negative concentration parameters are rejected.
    #[test]
    fn test_dirichlet_creation_errors() {
        let with_zero = array![1.0, 0.0, 1.0];
        let with_negative = array![1.0, -1.0, 1.0];

        assert!(Dirichlet::new(with_zero).is_err());
        assert!(Dirichlet::new(with_negative).is_err());
    }

    /// The uniform case has constant density 2 in 3D; a concentrated
    /// distribution peaks at the centre of the simplex.
    #[test]
    fn test_dirichlet_pdf() {
        let uniform = Dirichlet::new(array![1.0, 1.0, 1.0]).expect("Operation failed");

        let third = 1.0 / 3.0;
        let interior_points = [array![third, third, third], array![0.2, 0.3, 0.5]];
        for point in interior_points.iter() {
            assert_relative_eq!(uniform.pdf(point), 2.0, epsilon = 1e-10);
        }

        let concentrated = Dirichlet::new(array![5.0, 5.0, 5.0]).expect("Operation failed");

        let center = array![third, third, third];
        let edge = array![0.01, 0.01, 0.98];
        let density_at_center = concentrated.pdf(&center);
        let density_at_edge = concentrated.pdf(&edge);

        assert!(density_at_center > density_at_edge);
    }

    /// Points off the simplex, or touching its boundary, have zero density.
    #[test]
    fn test_dirichlet_pdf_edge_cases() {
        let dist = Dirichlet::new(array![1.0, 1.0, 1.0]).expect("Operation failed");

        let off_simplex = [
            array![0.3, 0.3, 0.3], // Sum != 1
            array![0.5, 0.6, 0.2], // Sum > 1
            array![0.0, 0.5, 0.5], // Contains 0
            array![1.0, 0.0, 0.0], // Contains 0
        ];

        for point in off_simplex.iter() {
            assert_eq!(dist.pdf(point), 0.0);
        }
    }

    /// The log-density equals ln(2) for the uniform case and is consistent
    /// with the density.
    #[test]
    fn test_dirichlet_logpdf() {
        let dist = Dirichlet::new(array![1.0, 1.0, 1.0]).expect("Operation failed");
        let point = array![0.3, 0.3, 0.4];

        let log_density = dist.logpdf(&point);

        // ln(2) ≈ 0.693
        assert_relative_eq!(log_density, 0.693, epsilon = 1e-3);

        // exp(logPDF) must reproduce the PDF
        assert_relative_eq!(log_density.exp(), dist.pdf(&point), epsilon = 1e-10);
    }

    /// Samples lie on the simplex and their mean approaches alpha / sum(alpha).
    #[test]
    fn test_dirichlet_rvs() {
        let alpha = array![1.0, 2.0, 3.0];
        let dist = Dirichlet::new(alpha.clone()).expect("Operation failed");

        let n_samples_ = 1000;
        let samples = dist.rvs(n_samples_).expect("Operation failed");

        assert_eq!(samples.len(), n_samples_);

        // Accumulate per-coordinate totals while validating each sample.
        let mut totals = [0.0; 3];
        for sample in samples.iter() {
            let coordinate_sum: f64 = sample.iter().sum();
            assert_relative_eq!(coordinate_sum, 1.0, epsilon = 1e-10);

            for (total, &val) in totals.iter_mut().zip(sample.iter()) {
                assert!((0.0..=1.0).contains(&val));
                *total += val;
            }
        }

        // E[X_i] = alpha_i / sum(alpha)
        let alpha_sum = alpha.sum();
        for (total, &a) in totals.iter().zip(alpha.iter()) {
            let observed_mean = total / n_samples_ as f64;
            let expected_mean = a / alpha_sum;
            assert_relative_eq!(observed_mean, expected_mean, epsilon = 0.05);
        }
    }

    /// A single draw has the right dimension and lies on the simplex.
    #[test]
    fn test_dirichlet_rvs_single() {
        let dist = Dirichlet::new(array![1.0, 2.0, 3.0]).expect("Operation failed");

        let sample = dist.rvs_single().expect("Operation failed");

        assert_eq!(sample.len(), 3);

        let coordinate_sum: f64 = sample.iter().sum();
        assert_relative_eq!(coordinate_sum, 1.0, epsilon = 1e-10);

        assert!(sample.iter().all(|val| (0.0..=1.0).contains(val)));
    }
}